From f83a56d31756227b5673d841a522fbad1209b10d Mon Sep 17 00:00:00 2001 From: gfyoung Date: Mon, 7 Aug 2017 21:18:35 -0700 Subject: [PATCH] Refactor out global settings usage in update logic --- testUpdateHostsFile.py | 93 +++++++++++++++++++++++++++++++++++------- updateHostsFile.py | 33 ++++++++++++--- 2 files changed, 107 insertions(+), 19 deletions(-) diff --git a/testUpdateHostsFile.py b/testUpdateHostsFile.py index 139bd8955..444c71b35 100644 --- a/testUpdateHostsFile.py +++ b/testUpdateHostsFile.py @@ -12,8 +12,8 @@ from updateHostsFile import ( normalize_rule, path_join_robust, print_failure, print_success, prompt_for_exclusions, prompt_for_move, prompt_for_flush_dns_cache, prompt_for_update, query_yes_no, recursive_glob, remove_old_hosts_file, - supports_color, strip_rule, update_readme_data, write_data, - write_opening_header) + supports_color, strip_rule, update_all_sources, update_readme_data, + write_data, write_opening_header) import updateHostsFile import unittest @@ -190,9 +190,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir): sys.stdout = StringIO() - @mock.patch("updateHostsFile.update_all_sources", return_value=0) @mock.patch("updateHostsFile.query_yes_no", return_value=False) - def test_freshen_no_update(self, _, mock_update): + def test_freshen_no_update(self, _): hosts_file = os.path.join(self.test_dir, "hosts") hosts_data = "This data should not be overwritten" @@ -204,10 +203,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir): dir_count = self.dir_count - prompt_for_update(freshen=True, update_auto=False) - - mock_update.assert_not_called() - mock_update.reset_mock() + update_sources = prompt_for_update(freshen=True, update_auto=False) + self.assertFalse(update_sources) output = sys.stdout.getvalue() expected = ("OK, we'll stick with " @@ -222,9 +219,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir): contents = f.read() self.assertEqual(contents, hosts_data) - @mock.patch("updateHostsFile.update_all_sources", return_value=0) @mock.patch("updateHostsFile.query_yes_no", return_value=True) - def test_freshen_update(self, _, mock_update): + def test_freshen_update(self, _): hosts_file = os.path.join(self.test_dir, "hosts") hosts_data = "This data should not be overwritten" @@ -237,10 +233,9 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir): dir_count = self.dir_count for update_auto in (False, True): - prompt_for_update(freshen=True, update_auto=update_auto) - - self.assert_called_once(mock_update) - mock_update.reset_mock() + update_sources = prompt_for_update(freshen=True, + update_auto=update_auto) + self.assertTrue(update_sources) output = sys.stdout.getvalue() self.assertEqual(output, "") @@ -547,6 +542,76 @@ class TestMatchesExclusions(Base): # End Exclusion Logic +# Update Logic +class TestUpdateAllSources(BaseStdout): + + def setUp(self): + BaseStdout.setUp(self) + + self.source_data_filename = "data.json" + self.host_filename = "hosts.txt" + + @mock.patch(builtins() + ".open") + @mock.patch("updateHostsFile.recursive_glob", return_value=[]) + def test_no_sources(self, _, mock_open): + update_all_sources(self.source_data_filename, self.host_filename) + mock_open.assert_not_called() + + @mock.patch(builtins() + ".open", return_value=mock.Mock()) + @mock.patch("json.load", return_value={"url": "example.com"}) + @mock.patch("updateHostsFile.recursive_glob", return_value=["foo"]) + @mock.patch("updateHostsFile.write_data", return_value=0) + @mock.patch("updateHostsFile.get_file_by_url", return_value="file_data") + def test_one_source(self, mock_get, mock_write, *_): + update_all_sources(self.source_data_filename, self.host_filename) + self.assert_called_once(mock_write) + self.assert_called_once(mock_get) + + output = sys.stdout.getvalue() + expected = "Updating source from example.com" + + self.assertIn(expected, output) + + @mock.patch(builtins() + ".open", return_value=mock.Mock()) + @mock.patch("json.load", return_value={"url": "example.com"}) + @mock.patch("updateHostsFile.recursive_glob", return_value=["foo"]) + @mock.patch("updateHostsFile.write_data", return_value=0) + @mock.patch("updateHostsFile.get_file_by_url", + return_value=Exception("fail")) + def test_source_fail(self, mock_get, mock_write, *_): + update_all_sources(self.source_data_filename, self.host_filename) + mock_write.assert_not_called() + self.assert_called_once(mock_get) + + output = sys.stdout.getvalue() + expecteds = ["Updating source from example.com", + "Error in updating source: example.com"] + for expected in expecteds: + self.assertIn(expected, output) + + @mock.patch(builtins() + ".open", return_value=mock.Mock()) + @mock.patch("json.load", side_effect=[{"url": "example.com"}, + {"url": "example2.com"}]) + @mock.patch("updateHostsFile.recursive_glob", return_value=["foo", "bar"]) + @mock.patch("updateHostsFile.write_data", return_value=0) + @mock.patch("updateHostsFile.get_file_by_url", + side_effect=[Exception("fail"), "file_data"]) + def test_sources_fail_succeed(self, mock_get, mock_write, *_): + update_all_sources(self.source_data_filename, self.host_filename) + self.assert_called_once(mock_write) + + get_calls = [mock.call("example.com"), mock.call("example2.com")] + mock_get.assert_has_calls(get_calls) + + output = sys.stdout.getvalue() + expecteds = ["Updating source from example.com", + "Error in updating source: example.com", + "Updating source from example2.com"] + for expected in expecteds: + self.assertIn(expected, output) +# End Update Logic + + # File Logic class TestNormalizeRule(BaseStdout): diff --git a/updateHostsFile.py b/updateHostsFile.py index dbcfe7696..7f429ce86 100644 --- a/updateHostsFile.py +++ b/updateHostsFile.py @@ -149,7 +149,12 @@ def main(): auto = settings["auto"] exclusion_regexes = settings["exclusionregexs"] - prompt_for_update(freshen=settings["freshen"], update_auto=auto) + update_sources = prompt_for_update(freshen=settings["freshen"], + update_auto=auto) + if update_sources: + update_all_sources(settings["sourcedatafilename"], + settings["hostfilename"]) + gather_exclusions = prompt_for_exclusions(skip_prompt=auto) if gather_exclusions: @@ -221,6 +226,11 @@ def prompt_for_update(freshen, update_auto): if it is requested that data sources not be updated. update_auto : bool Whether or not to automatically update all data sources. + + Returns + ------- + update_sources : bool + Whether or not we should update data sources for exclusion files. """ # Create a hosts file if it doesn't exist. @@ -242,10 +252,12 @@ def prompt_for_update(freshen, update_auto): prompt = "Do you want to update all data sources?" if update_auto or query_yes_no(prompt): - update_all_sources() + return True elif not update_auto: print("OK, we'll stick with what we've got locally.") + return False + def prompt_for_exclusions(skip_prompt): """ @@ -478,12 +490,23 @@ def matches_exclusions(stripped_rule, exclusion_regexes): # Update Logic -def update_all_sources(): +def update_all_sources(source_data_filename, host_filename): """ Update all host files, regardless of folder depth. + + Parameters + ---------- + source_data_filename : str + The name of the filename where information regarding updating + sources for a particular URL is stored. This filename is assumed + to be the same for all sources. + host_filename : str + The name of the file in which the updated source information + in stored for a particular URL. This filename is assumed to be + the same for all sources. """ - all_sources = recursive_glob("*", settings["sourcedatafilename"]) + all_sources = recursive_glob("*", source_data_filename) for source in all_sources: update_file = open(source, "r") @@ -502,7 +525,7 @@ def update_all_sources(): hosts_file = open(path_join_robust(BASEDIR_PATH, os.path.dirname(source), - settings["hostfilename"]), "wb") + host_filename), "wb") write_data(hosts_file, updated_file) hosts_file.close() except Exception: