mirror of
https://github.com/StevenBlack/hosts.git
synced 2024-07-04 19:46:02 +02:00
Refactor out global settings usage in update logic
This commit is contained in:
parent
18e89f121e
commit
f83a56d317
|
@ -12,8 +12,8 @@ from updateHostsFile import (
|
|||
normalize_rule, path_join_robust, print_failure, print_success,
|
||||
prompt_for_exclusions, prompt_for_move, prompt_for_flush_dns_cache,
|
||||
prompt_for_update, query_yes_no, recursive_glob, remove_old_hosts_file,
|
||||
supports_color, strip_rule, update_readme_data, write_data,
|
||||
write_opening_header)
|
||||
supports_color, strip_rule, update_all_sources, update_readme_data,
|
||||
write_data, write_opening_header)
|
||||
|
||||
import updateHostsFile
|
||||
import unittest
|
||||
|
@ -190,9 +190,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
|
|||
|
||||
sys.stdout = StringIO()
|
||||
|
||||
@mock.patch("updateHostsFile.update_all_sources", return_value=0)
|
||||
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
|
||||
def test_freshen_no_update(self, _, mock_update):
|
||||
def test_freshen_no_update(self, _):
|
||||
hosts_file = os.path.join(self.test_dir, "hosts")
|
||||
hosts_data = "This data should not be overwritten"
|
||||
|
||||
|
@ -204,10 +203,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
|
|||
|
||||
dir_count = self.dir_count
|
||||
|
||||
prompt_for_update(freshen=True, update_auto=False)
|
||||
|
||||
mock_update.assert_not_called()
|
||||
mock_update.reset_mock()
|
||||
update_sources = prompt_for_update(freshen=True, update_auto=False)
|
||||
self.assertFalse(update_sources)
|
||||
|
||||
output = sys.stdout.getvalue()
|
||||
expected = ("OK, we'll stick with "
|
||||
|
@ -222,9 +219,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
|
|||
contents = f.read()
|
||||
self.assertEqual(contents, hosts_data)
|
||||
|
||||
@mock.patch("updateHostsFile.update_all_sources", return_value=0)
|
||||
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
|
||||
def test_freshen_update(self, _, mock_update):
|
||||
def test_freshen_update(self, _):
|
||||
hosts_file = os.path.join(self.test_dir, "hosts")
|
||||
hosts_data = "This data should not be overwritten"
|
||||
|
||||
|
@ -237,10 +233,9 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
|
|||
dir_count = self.dir_count
|
||||
|
||||
for update_auto in (False, True):
|
||||
prompt_for_update(freshen=True, update_auto=update_auto)
|
||||
|
||||
self.assert_called_once(mock_update)
|
||||
mock_update.reset_mock()
|
||||
update_sources = prompt_for_update(freshen=True,
|
||||
update_auto=update_auto)
|
||||
self.assertTrue(update_sources)
|
||||
|
||||
output = sys.stdout.getvalue()
|
||||
self.assertEqual(output, "")
|
||||
|
@ -547,6 +542,76 @@ class TestMatchesExclusions(Base):
|
|||
# End Exclusion Logic
|
||||
|
||||
|
||||
# Update Logic
|
||||
class TestUpdateAllSources(BaseStdout):
|
||||
|
||||
def setUp(self):
|
||||
BaseStdout.setUp(self)
|
||||
|
||||
self.source_data_filename = "data.json"
|
||||
self.host_filename = "hosts.txt"
|
||||
|
||||
@mock.patch(builtins() + ".open")
|
||||
@mock.patch("updateHostsFile.recursive_glob", return_value=[])
|
||||
def test_no_sources(self, _, mock_open):
|
||||
update_all_sources(self.source_data_filename, self.host_filename)
|
||||
mock_open.assert_not_called()
|
||||
|
||||
@mock.patch(builtins() + ".open", return_value=mock.Mock())
|
||||
@mock.patch("json.load", return_value={"url": "example.com"})
|
||||
@mock.patch("updateHostsFile.recursive_glob", return_value=["foo"])
|
||||
@mock.patch("updateHostsFile.write_data", return_value=0)
|
||||
@mock.patch("updateHostsFile.get_file_by_url", return_value="file_data")
|
||||
def test_one_source(self, mock_get, mock_write, *_):
|
||||
update_all_sources(self.source_data_filename, self.host_filename)
|
||||
self.assert_called_once(mock_write)
|
||||
self.assert_called_once(mock_get)
|
||||
|
||||
output = sys.stdout.getvalue()
|
||||
expected = "Updating source from example.com"
|
||||
|
||||
self.assertIn(expected, output)
|
||||
|
||||
@mock.patch(builtins() + ".open", return_value=mock.Mock())
|
||||
@mock.patch("json.load", return_value={"url": "example.com"})
|
||||
@mock.patch("updateHostsFile.recursive_glob", return_value=["foo"])
|
||||
@mock.patch("updateHostsFile.write_data", return_value=0)
|
||||
@mock.patch("updateHostsFile.get_file_by_url",
|
||||
return_value=Exception("fail"))
|
||||
def test_source_fail(self, mock_get, mock_write, *_):
|
||||
update_all_sources(self.source_data_filename, self.host_filename)
|
||||
mock_write.assert_not_called()
|
||||
self.assert_called_once(mock_get)
|
||||
|
||||
output = sys.stdout.getvalue()
|
||||
expecteds = ["Updating source from example.com",
|
||||
"Error in updating source: example.com"]
|
||||
for expected in expecteds:
|
||||
self.assertIn(expected, output)
|
||||
|
||||
@mock.patch(builtins() + ".open", return_value=mock.Mock())
|
||||
@mock.patch("json.load", side_effect=[{"url": "example.com"},
|
||||
{"url": "example2.com"}])
|
||||
@mock.patch("updateHostsFile.recursive_glob", return_value=["foo", "bar"])
|
||||
@mock.patch("updateHostsFile.write_data", return_value=0)
|
||||
@mock.patch("updateHostsFile.get_file_by_url",
|
||||
side_effect=[Exception("fail"), "file_data"])
|
||||
def test_sources_fail_succeed(self, mock_get, mock_write, *_):
|
||||
update_all_sources(self.source_data_filename, self.host_filename)
|
||||
self.assert_called_once(mock_write)
|
||||
|
||||
get_calls = [mock.call("example.com"), mock.call("example2.com")]
|
||||
mock_get.assert_has_calls(get_calls)
|
||||
|
||||
output = sys.stdout.getvalue()
|
||||
expecteds = ["Updating source from example.com",
|
||||
"Error in updating source: example.com",
|
||||
"Updating source from example2.com"]
|
||||
for expected in expecteds:
|
||||
self.assertIn(expected, output)
|
||||
# End Update Logic
|
||||
|
||||
|
||||
# File Logic
|
||||
class TestNormalizeRule(BaseStdout):
|
||||
|
||||
|
|
|
@ -149,7 +149,12 @@ def main():
|
|||
auto = settings["auto"]
|
||||
exclusion_regexes = settings["exclusionregexs"]
|
||||
|
||||
prompt_for_update(freshen=settings["freshen"], update_auto=auto)
|
||||
update_sources = prompt_for_update(freshen=settings["freshen"],
|
||||
update_auto=auto)
|
||||
if update_sources:
|
||||
update_all_sources(settings["sourcedatafilename"],
|
||||
settings["hostfilename"])
|
||||
|
||||
gather_exclusions = prompt_for_exclusions(skip_prompt=auto)
|
||||
|
||||
if gather_exclusions:
|
||||
|
@ -221,6 +226,11 @@ def prompt_for_update(freshen, update_auto):
|
|||
if it is requested that data sources not be updated.
|
||||
update_auto : bool
|
||||
Whether or not to automatically update all data sources.
|
||||
|
||||
Returns
|
||||
-------
|
||||
update_sources : bool
|
||||
Whether or not we should update data sources for exclusion files.
|
||||
"""
|
||||
|
||||
# Create a hosts file if it doesn't exist.
|
||||
|
@ -242,10 +252,12 @@ def prompt_for_update(freshen, update_auto):
|
|||
prompt = "Do you want to update all data sources?"
|
||||
|
||||
if update_auto or query_yes_no(prompt):
|
||||
update_all_sources()
|
||||
return True
|
||||
elif not update_auto:
|
||||
print("OK, we'll stick with what we've got locally.")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def prompt_for_exclusions(skip_prompt):
|
||||
"""
|
||||
|
@ -478,12 +490,23 @@ def matches_exclusions(stripped_rule, exclusion_regexes):
|
|||
|
||||
|
||||
# Update Logic
|
||||
def update_all_sources():
|
||||
def update_all_sources(source_data_filename, host_filename):
|
||||
"""
|
||||
Update all host files, regardless of folder depth.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_data_filename : str
|
||||
The name of the filename where information regarding updating
|
||||
sources for a particular URL is stored. This filename is assumed
|
||||
to be the same for all sources.
|
||||
host_filename : str
|
||||
The name of the file in which the updated source information
|
||||
in stored for a particular URL. This filename is assumed to be
|
||||
the same for all sources.
|
||||
"""
|
||||
|
||||
all_sources = recursive_glob("*", settings["sourcedatafilename"])
|
||||
all_sources = recursive_glob("*", source_data_filename)
|
||||
|
||||
for source in all_sources:
|
||||
update_file = open(source, "r")
|
||||
|
@ -502,7 +525,7 @@ def update_all_sources():
|
|||
|
||||
hosts_file = open(path_join_robust(BASEDIR_PATH,
|
||||
os.path.dirname(source),
|
||||
settings["hostfilename"]), "wb")
|
||||
host_filename), "wb")
|
||||
write_data(hosts_file, updated_file)
|
||||
hosts_file.close()
|
||||
except Exception:
|
||||
|
|
Loading…
Reference in New Issue
Block a user