From 4b96f3f34ad8dc06ec237ef4c7796283f5e07754 Mon Sep 17 00:00:00 2001 From: gfyoung Date: Sun, 9 Jul 2017 12:00:11 -0700 Subject: [PATCH] Refactor out global settings usage in user prompt --- testUpdateHostsFile.py | 314 ++++++++++++++++++++++++++++++++++++++++- updateHostsFile.py | 106 ++++++++++---- 2 files changed, 391 insertions(+), 29 deletions(-) diff --git a/testUpdateHostsFile.py b/testUpdateHostsFile.py index c1f2e07de..9e8f5e30f 100644 --- a/testUpdateHostsFile.py +++ b/testUpdateHostsFile.py @@ -10,8 +10,10 @@ from updateHostsFile import (Colors, PY3, colorize, flush_dns_cache, get_file_by_url, is_valid_domain_format, move_hosts_file_into_place, normalize_rule, path_join_robust, print_failure, print_success, - supports_color, query_yes_no, recursive_glob, - remove_old_hosts_file, strip_rule, + prompt_for_exclusions, prompt_for_move, + prompt_for_flush_dns_cache, prompt_for_update, + query_yes_no, recursive_glob, + remove_old_hosts_file, supports_color, strip_rule, update_readme_data, write_data, write_opening_header) import updateHostsFile @@ -33,7 +35,7 @@ else: import mock -# Base Test Classes +# Test Helper Objects class Base(unittest.TestCase): @staticmethod @@ -66,7 +68,14 @@ class BaseMockDir(Base): def tearDown(self): shutil.rmtree(self.test_dir) -# End Base Test Classes + + +def builtins(): + if PY3: + return "builtins" + else: + return "__builtin__" +# End Test Helper Objects # Project Settings @@ -107,6 +116,303 @@ class TestGetDefaults(Base): # End Project Settings +# Prompt the User +class TestPromptForUpdate(BaseStdout, BaseMockDir): + + def setUp(self): + BaseStdout.setUp(self) + BaseMockDir.setUp(self) + + def test_no_freshen_no_new_file(self): + hosts_file = os.path.join(self.test_dir, "hosts") + hosts_data = "This data should not be overwritten" + + with self.mock_property("updateHostsFile.BASEDIR_PATH"): + updateHostsFile.BASEDIR_PATH = self.test_dir + + with open(hosts_file, "w") as f: + f.write(hosts_data) + + for update_auto in (False, True): + dir_count = self.dir_count + prompt_for_update(freshen=False, update_auto=update_auto) + + output = sys.stdout.getvalue() + self.assertEqual(output, "") + + sys.stdout = StringIO() + + self.assertEqual(self.dir_count, dir_count) + + with open(hosts_file, "r") as f: + contents = f.read() + self.assertEqual(contents, hosts_data) + + def test_no_freshen_new_file(self): + hosts_file = os.path.join(self.test_dir, "hosts") + + with self.mock_property("updateHostsFile.BASEDIR_PATH"): + updateHostsFile.BASEDIR_PATH = self.test_dir + + dir_count = self.dir_count + prompt_for_update(freshen=False, update_auto=False) + + output = sys.stdout.getvalue() + self.assertEqual(output, "") + + sys.stdout = StringIO() + + self.assertEqual(self.dir_count, dir_count + 1) + + with open(hosts_file, "r") as f: + contents = f.read() + self.assertEqual(contents, "") + + @mock.patch(builtins() + ".open") + def test_no_freshen_fail_new_file(self, mock_open): + for exc in (IOError, OSError): + mock_open.side_effect = exc("failed open") + + with self.mock_property("updateHostsFile.BASEDIR_PATH"): + updateHostsFile.BASEDIR_PATH = self.test_dir + prompt_for_update(freshen=False, update_auto=False) + + output = sys.stdout.getvalue() + expected = ("ERROR: No 'hosts' file in the folder. " + "Try creating one manually.") + self.assertIn(expected, output) + + sys.stdout = StringIO() + + @mock.patch("updateHostsFile.update_all_sources", return_value=0) + @mock.patch("updateHostsFile.query_yes_no", return_value=False) + def test_freshen_no_update(self, _, mock_update): + hosts_file = os.path.join(self.test_dir, "hosts") + hosts_data = "This data should not be overwritten" + + with self.mock_property("updateHostsFile.BASEDIR_PATH"): + updateHostsFile.BASEDIR_PATH = self.test_dir + + with open(hosts_file, "w") as f: + f.write(hosts_data) + + dir_count = self.dir_count + + prompt_for_update(freshen=True, update_auto=False) + + mock_update.assert_not_called() + mock_update.reset_mock() + + output = sys.stdout.getvalue() + expected = ("OK, we'll stick with " + "what we've got locally.") + self.assertIn(expected, output) + + sys.stdout = StringIO() + + self.assertEqual(self.dir_count, dir_count) + + with open(hosts_file, "r") as f: + contents = f.read() + self.assertEqual(contents, hosts_data) + + @mock.patch("updateHostsFile.update_all_sources", return_value=0) + @mock.patch("updateHostsFile.query_yes_no", return_value=True) + def test_freshen_update(self, _, mock_update): + hosts_file = os.path.join(self.test_dir, "hosts") + hosts_data = "This data should not be overwritten" + + with self.mock_property("updateHostsFile.BASEDIR_PATH"): + updateHostsFile.BASEDIR_PATH = self.test_dir + + with open(hosts_file, "w") as f: + f.write(hosts_data) + + dir_count = self.dir_count + + for update_auto in (False, True): + prompt_for_update(freshen=True, update_auto=update_auto) + + mock_update.assert_called_once() + mock_update.reset_mock() + + output = sys.stdout.getvalue() + self.assertEqual(output, "") + + sys.stdout = StringIO() + + self.assertEqual(self.dir_count, dir_count) + + with open(hosts_file, "r") as f: + contents = f.read() + self.assertEqual(contents, hosts_data) + + def tearDown(self): + BaseStdout.tearDown(self) + BaseStdout.tearDown(self) + + +class TestPromptForExclusions(BaseStdout): + + @mock.patch("updateHostsFile.display_exclusion_options", return_value=0) + @mock.patch("updateHostsFile.query_yes_no", return_value=False) + def testSkipPrompt(self, mock_query, mock_display): + prompt_for_exclusions(skip_prompt=True) + + output = sys.stdout.getvalue() + self.assertEqual(output, "") + + mock_query.assert_not_called() + mock_display.assert_not_called() + + @mock.patch("updateHostsFile.display_exclusion_options", return_value=0) + @mock.patch("updateHostsFile.query_yes_no", return_value=False) + def testNoSkipPromptNoDisplay(self, mock_query, mock_display): + prompt_for_exclusions(skip_prompt=False) + + output = sys.stdout.getvalue() + expected = ("OK, we'll only exclude " + "domains in the whitelist.") + self.assertIn(expected, output) + + mock_query.assert_called_once() + mock_display.assert_not_called() + + @mock.patch("updateHostsFile.display_exclusion_options", return_value=0) + @mock.patch("updateHostsFile.query_yes_no", return_value=True) + def testNoSkipPromptDisplay(self, mock_query, mock_display): + prompt_for_exclusions(skip_prompt=False) + + output = sys.stdout.getvalue() + self.assertEqual(output, "") + + mock_query.assert_called_once() + mock_display.assert_called_once() + + +class TestPromptForFlushDnsCache(Base): + + @mock.patch("updateHostsFile.flush_dns_cache", return_value=0) + @mock.patch("updateHostsFile.query_yes_no", return_value=False) + def testFlushCache(self, mock_query, mock_flush): + for prompt_flush in (False, True): + prompt_for_flush_dns_cache(flush_cache=True, + prompt_flush=prompt_flush) + + mock_query.assert_not_called() + mock_flush.assert_called_once() + + mock_query.reset_mock() + mock_flush.reset_mock() + + @mock.patch("updateHostsFile.flush_dns_cache", return_value=0) + @mock.patch("updateHostsFile.query_yes_no", return_value=False) + def testNoFlushCacheNoPrompt(self, mock_query, mock_flush): + prompt_for_flush_dns_cache(flush_cache=False, + prompt_flush=False) + + mock_query.assert_not_called() + mock_flush.assert_not_called() + + @mock.patch("updateHostsFile.flush_dns_cache", return_value=0) + @mock.patch("updateHostsFile.query_yes_no", return_value=False) + def testNoFlushCachePromptNoFlush(self, mock_query, mock_flush): + prompt_for_flush_dns_cache(flush_cache=False, + prompt_flush=True) + + mock_query.assert_called_once() + mock_flush.assert_not_called() + + @mock.patch("updateHostsFile.flush_dns_cache", return_value=0) + @mock.patch("updateHostsFile.query_yes_no", return_value=True) + def testNoFlushCachePromptFlush(self, mock_query, mock_flush): + prompt_for_flush_dns_cache(flush_cache=False, + prompt_flush=True) + + mock_query.assert_called_once() + mock_flush.assert_called_once() + + +class TestPromptForMove(Base): + + def setUp(self): + Base.setUp(self) + self.final_file = "final.txt" + + def prompt_for_move(self, **move_params): + return prompt_for_move(self.final_file, **move_params) + + @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0) + @mock.patch("updateHostsFile.query_yes_no", return_value=False) + def testSkipStaticHosts(self, mock_query, mock_move): + for replace in (False, True): + for auto in (False, True): + move_file = self.prompt_for_move(replace=replace, auto=auto, + skipstatichosts=True) + self.assertFalse(move_file) + + mock_query.assert_not_called() + mock_move.assert_not_called() + + mock_query.reset_mock() + mock_move.reset_mock() + + @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0) + @mock.patch("updateHostsFile.query_yes_no", return_value=False) + def testReplaceNoSkipStaticHosts(self, mock_query, mock_move): + for auto in (False, True): + move_file = self.prompt_for_move(replace=True, auto=auto, + skipstatichosts=False) + self.assertTrue(move_file) + + mock_query.assert_not_called() + mock_move.assert_called_once() + + mock_query.reset_mock() + mock_move.reset_mock() + + @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0) + @mock.patch("updateHostsFile.query_yes_no", return_value=False) + def testAutoNoSkipStaticHosts(self, mock_query, mock_move): + for replace in (False, True): + move_file = self.prompt_for_move(replace=replace, auto=True, + skipstatichosts=True) + self.assertFalse(move_file) + + mock_query.assert_not_called() + mock_move.assert_not_called() + + mock_query.reset_mock() + mock_move.reset_mock() + + @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0) + @mock.patch("updateHostsFile.query_yes_no", return_value=False) + def testPromptNoMove(self, mock_query, mock_move): + move_file = self.prompt_for_move(replace=False, auto=False, + skipstatichosts=False) + self.assertFalse(move_file) + + mock_query.assert_called_once() + mock_move.assert_not_called() + + mock_query.reset_mock() + mock_move.reset_mock() + + @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0) + @mock.patch("updateHostsFile.query_yes_no", return_value=True) + def testPromptMove(self, mock_query, mock_move): + move_file = self.prompt_for_move(replace=False, auto=False, + skipstatichosts=False) + self.assertTrue(move_file) + + mock_query.assert_called_once() + mock_move.assert_called_once() + + mock_query.reset_mock() + mock_move.reset_mock() +# End Prompt the User + + # Exclusion Logic class TestGatherCustomExclusions(BaseStdout): diff --git a/updateHostsFile.py b/updateHostsFile.py index 4d795e7ec..96c38706a 100644 --- a/updateHostsFile.py +++ b/updateHostsFile.py @@ -146,8 +146,10 @@ def main(): settings["extensions"] = sorted(list( set(options["extensions"]).intersection(settings["extensions"]))) - prompt_for_update() - prompt_for_exclusions() + auto = settings["auto"] + + prompt_for_update(freshen=settings["freshen"], update_auto=auto) + prompt_for_exclusions(skip_prompt=auto) merge_file = create_initial_file() remove_old_hosts_file(settings["backup"]) @@ -158,11 +160,12 @@ def main(): final_file = remove_dups_and_excl(merge_file) number_of_rules = settings["numberofrules"] + skip_static_hosts = settings["skipstatichosts"] write_opening_header(final_file, extensions=extensions, numberofrules=number_of_rules, outputsubfolder=output_subfolder, - skipstatichosts=settings["skipstatichosts"]) + skipstatichosts=skip_static_hosts) final_file.close() if settings["ziphosts"]: @@ -183,63 +186,101 @@ def main(): "{:,}".format(number_of_rules) + " unique entries.") - prompt_for_move(final_file) + move_file = prompt_for_move(final_file, auto=auto, + replace=settings["replace"], + skipstatichosts=skip_static_hosts) + + # We only flush the DNS cache if we have + # moved a new hosts file into place. + if move_file: + prompt_for_flush_dns_cache(flush_cache=settings["flushdnscache"], + prompt_flush=not auto) # Prompt the User -def prompt_for_update(): +def prompt_for_update(freshen, update_auto): """ Prompt the user to update all hosts files. + + If requested, the function will update all data sources after it + checks that a hosts file does indeed exist. + + Parameters + ---------- + freshen : bool + Whether data sources should be updated. This function will return + if it is requested that data sources not be updated. + update_auto : bool + Whether or not to automatically update all data sources. """ - # Create hosts file if it doesn't exist. - if not os.path.isfile(path_join_robust(BASEDIR_PATH, "hosts")): - try: - open(path_join_robust(BASEDIR_PATH, "hosts"), "w+").close() - except Exception: - print_failure("ERROR: No 'hosts' file in the folder," - "try creating one manually") + # Create a hosts file if it doesn't exist. + hosts_file = path_join_robust(BASEDIR_PATH, "hosts") - if not settings["freshen"]: + if not os.path.isfile(hosts_file): + try: + open(hosts_file, "w+").close() + except (IOError, OSError): + # Starting in Python 3.3, IOError is aliased + # OSError. However, we have to catch both for + # Python 2.x failures. + print_failure("ERROR: No 'hosts' file in the folder. " + "Try creating one manually.") + + if not freshen: return prompt = "Do you want to update all data sources?" - if settings["auto"] or query_yes_no(prompt): + + if update_auto or query_yes_no(prompt): update_all_sources() - elif not settings["auto"]: + elif not update_auto: print("OK, we'll stick with what we've got locally.") -def prompt_for_exclusions(): +def prompt_for_exclusions(skip_prompt): """ Prompt the user to exclude any custom domains from being blocked. + + Parameters + ---------- + skip_prompt : bool + Whether or not to skip prompting for custom domains to be excluded. + If true, the function returns immediately. """ prompt = ("Do you want to exclude any domains?\n" "For example, hulu.com video streaming must be able to access " "its tracking and ad servers in order to play video.") - if not settings["auto"]: + if not skip_prompt: if query_yes_no(prompt): display_exclusion_options() else: print("OK, we'll only exclude domains in the whitelist.") -def prompt_for_flush_dns_cache(): +def prompt_for_flush_dns_cache(flush_cache, prompt_flush): """ Prompt the user to flush the DNS cache. + + Parameters + ---------- + flush_cache : bool + Whether to flush the DNS cache without prompting. + prompt_flush : bool + If `flush_cache` is False, whether we should prompt for flushing the + cache. Otherwise, the function returns immediately. """ - if settings["flushdnscache"]: + if flush_cache: flush_dns_cache() - - if not settings["auto"]: + elif prompt_flush: if query_yes_no("Attempt to flush the DNS cache?"): flush_dns_cache() -def prompt_for_move(final_file): +def prompt_for_move(final_file, **move_params): """ Prompt the user to move the newly created hosts file to its designated location in the OS. @@ -248,11 +289,25 @@ def prompt_for_move(final_file): ---------- final_file : file The file object that contains the newly created hosts data. + move_params : kwargs + Dictionary providing additional parameters for moving the hosts file + into place. Currently, those fields are: + + 1) auto + 2) replace + 3) skipstatichosts + + Returns + ------- + move_file : bool + Whether or not the final hosts file was moved. """ - if settings["replace"] and not settings["skipstatichosts"]: + skip_static_hosts = move_params["skipstatichosts"] + + if move_params["replace"] and not skip_static_hosts: move_file = True - elif settings["auto"] or settings["skipstatichosts"]: + elif move_params["auto"] or skip_static_hosts: move_file = False else: prompt = ("Do you want to replace your existing hosts file " + @@ -261,7 +316,8 @@ def prompt_for_move(final_file): if move_file: move_hosts_file_into_place(final_file) - prompt_for_flush_dns_cache() + + return move_file # End Prompt the User