Refactor out global settings usage in update logic

This commit is contained in:
gfyoung 2017-08-07 21:18:35 -07:00
parent 18e89f121e
commit f83a56d317
2 changed files with 107 additions and 19 deletions

View File

@ -12,8 +12,8 @@ from updateHostsFile import (
normalize_rule, path_join_robust, print_failure, print_success,
prompt_for_exclusions, prompt_for_move, prompt_for_flush_dns_cache,
prompt_for_update, query_yes_no, recursive_glob, remove_old_hosts_file,
supports_color, strip_rule, update_readme_data, write_data,
write_opening_header)
supports_color, strip_rule, update_all_sources, update_readme_data,
write_data, write_opening_header)
import updateHostsFile
import unittest
@ -190,9 +190,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
sys.stdout = StringIO()
@mock.patch("updateHostsFile.update_all_sources", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def test_freshen_no_update(self, _, mock_update):
def test_freshen_no_update(self, _):
hosts_file = os.path.join(self.test_dir, "hosts")
hosts_data = "This data should not be overwritten"
@ -204,10 +203,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
dir_count = self.dir_count
prompt_for_update(freshen=True, update_auto=False)
mock_update.assert_not_called()
mock_update.reset_mock()
update_sources = prompt_for_update(freshen=True, update_auto=False)
self.assertFalse(update_sources)
output = sys.stdout.getvalue()
expected = ("OK, we'll stick with "
@ -222,9 +219,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
contents = f.read()
self.assertEqual(contents, hosts_data)
@mock.patch("updateHostsFile.update_all_sources", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
def test_freshen_update(self, _, mock_update):
def test_freshen_update(self, _):
hosts_file = os.path.join(self.test_dir, "hosts")
hosts_data = "This data should not be overwritten"
@ -237,10 +233,9 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
dir_count = self.dir_count
for update_auto in (False, True):
prompt_for_update(freshen=True, update_auto=update_auto)
self.assert_called_once(mock_update)
mock_update.reset_mock()
update_sources = prompt_for_update(freshen=True,
update_auto=update_auto)
self.assertTrue(update_sources)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
@ -547,6 +542,76 @@ class TestMatchesExclusions(Base):
# End Exclusion Logic
# Update Logic
class TestUpdateAllSources(BaseStdout):
def setUp(self):
BaseStdout.setUp(self)
self.source_data_filename = "data.json"
self.host_filename = "hosts.txt"
@mock.patch(builtins() + ".open")
@mock.patch("updateHostsFile.recursive_glob", return_value=[])
def test_no_sources(self, _, mock_open):
update_all_sources(self.source_data_filename, self.host_filename)
mock_open.assert_not_called()
@mock.patch(builtins() + ".open", return_value=mock.Mock())
@mock.patch("json.load", return_value={"url": "example.com"})
@mock.patch("updateHostsFile.recursive_glob", return_value=["foo"])
@mock.patch("updateHostsFile.write_data", return_value=0)
@mock.patch("updateHostsFile.get_file_by_url", return_value="file_data")
def test_one_source(self, mock_get, mock_write, *_):
update_all_sources(self.source_data_filename, self.host_filename)
self.assert_called_once(mock_write)
self.assert_called_once(mock_get)
output = sys.stdout.getvalue()
expected = "Updating source from example.com"
self.assertIn(expected, output)
@mock.patch(builtins() + ".open", return_value=mock.Mock())
@mock.patch("json.load", return_value={"url": "example.com"})
@mock.patch("updateHostsFile.recursive_glob", return_value=["foo"])
@mock.patch("updateHostsFile.write_data", return_value=0)
@mock.patch("updateHostsFile.get_file_by_url",
return_value=Exception("fail"))
def test_source_fail(self, mock_get, mock_write, *_):
update_all_sources(self.source_data_filename, self.host_filename)
mock_write.assert_not_called()
self.assert_called_once(mock_get)
output = sys.stdout.getvalue()
expecteds = ["Updating source from example.com",
"Error in updating source: example.com"]
for expected in expecteds:
self.assertIn(expected, output)
@mock.patch(builtins() + ".open", return_value=mock.Mock())
@mock.patch("json.load", side_effect=[{"url": "example.com"},
{"url": "example2.com"}])
@mock.patch("updateHostsFile.recursive_glob", return_value=["foo", "bar"])
@mock.patch("updateHostsFile.write_data", return_value=0)
@mock.patch("updateHostsFile.get_file_by_url",
side_effect=[Exception("fail"), "file_data"])
def test_sources_fail_succeed(self, mock_get, mock_write, *_):
update_all_sources(self.source_data_filename, self.host_filename)
self.assert_called_once(mock_write)
get_calls = [mock.call("example.com"), mock.call("example2.com")]
mock_get.assert_has_calls(get_calls)
output = sys.stdout.getvalue()
expecteds = ["Updating source from example.com",
"Error in updating source: example.com",
"Updating source from example2.com"]
for expected in expecteds:
self.assertIn(expected, output)
# End Update Logic
# File Logic
class TestNormalizeRule(BaseStdout):

View File

@ -149,7 +149,12 @@ def main():
auto = settings["auto"]
exclusion_regexes = settings["exclusionregexs"]
prompt_for_update(freshen=settings["freshen"], update_auto=auto)
update_sources = prompt_for_update(freshen=settings["freshen"],
update_auto=auto)
if update_sources:
update_all_sources(settings["sourcedatafilename"],
settings["hostfilename"])
gather_exclusions = prompt_for_exclusions(skip_prompt=auto)
if gather_exclusions:
@ -221,6 +226,11 @@ def prompt_for_update(freshen, update_auto):
if it is requested that data sources not be updated.
update_auto : bool
Whether or not to automatically update all data sources.
Returns
-------
update_sources : bool
Whether or not we should update data sources for exclusion files.
"""
# Create a hosts file if it doesn't exist.
@ -242,10 +252,12 @@ def prompt_for_update(freshen, update_auto):
prompt = "Do you want to update all data sources?"
if update_auto or query_yes_no(prompt):
update_all_sources()
return True
elif not update_auto:
print("OK, we'll stick with what we've got locally.")
return False
def prompt_for_exclusions(skip_prompt):
"""
@ -478,12 +490,23 @@ def matches_exclusions(stripped_rule, exclusion_regexes):
# Update Logic
def update_all_sources():
def update_all_sources(source_data_filename, host_filename):
"""
Update all host files, regardless of folder depth.
Parameters
----------
source_data_filename : str
The name of the filename where information regarding updating
sources for a particular URL is stored. This filename is assumed
to be the same for all sources.
host_filename : str
The name of the file in which the updated source information
in stored for a particular URL. This filename is assumed to be
the same for all sources.
"""
all_sources = recursive_glob("*", settings["sourcedatafilename"])
all_sources = recursive_glob("*", source_data_filename)
for source in all_sources:
update_file = open(source, "r")
@ -502,7 +525,7 @@ def update_all_sources():
hosts_file = open(path_join_robust(BASEDIR_PATH,
os.path.dirname(source),
settings["hostfilename"]), "wb")
host_filename), "wb")
write_data(hosts_file, updated_file)
hosts_file.close()
except Exception: