Refactor out global settings usage in user prompt

This commit is contained in:
gfyoung 2017-07-09 12:00:11 -07:00
parent a819131927
commit 4b96f3f34a
2 changed files with 391 additions and 29 deletions

View File

@ -10,8 +10,10 @@ from updateHostsFile import (Colors, PY3, colorize, flush_dns_cache,
get_file_by_url, is_valid_domain_format, get_file_by_url, is_valid_domain_format,
move_hosts_file_into_place, normalize_rule, move_hosts_file_into_place, normalize_rule,
path_join_robust, print_failure, print_success, path_join_robust, print_failure, print_success,
supports_color, query_yes_no, recursive_glob, prompt_for_exclusions, prompt_for_move,
remove_old_hosts_file, strip_rule, prompt_for_flush_dns_cache, prompt_for_update,
query_yes_no, recursive_glob,
remove_old_hosts_file, supports_color, strip_rule,
update_readme_data, write_data, update_readme_data, write_data,
write_opening_header) write_opening_header)
import updateHostsFile import updateHostsFile
@ -33,7 +35,7 @@ else:
import mock import mock
# Base Test Classes # Test Helper Objects
class Base(unittest.TestCase): class Base(unittest.TestCase):
@staticmethod @staticmethod
@ -66,7 +68,14 @@ class BaseMockDir(Base):
def tearDown(self): def tearDown(self):
shutil.rmtree(self.test_dir) shutil.rmtree(self.test_dir)
# End Base Test Classes
def builtins():
if PY3:
return "builtins"
else:
return "__builtin__"
# End Test Helper Objects
# Project Settings # Project Settings
@ -107,6 +116,303 @@ class TestGetDefaults(Base):
# End Project Settings # End Project Settings
# Prompt the User
class TestPromptForUpdate(BaseStdout, BaseMockDir):
def setUp(self):
BaseStdout.setUp(self)
BaseMockDir.setUp(self)
def test_no_freshen_no_new_file(self):
hosts_file = os.path.join(self.test_dir, "hosts")
hosts_data = "This data should not be overwritten"
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
with open(hosts_file, "w") as f:
f.write(hosts_data)
for update_auto in (False, True):
dir_count = self.dir_count
prompt_for_update(freshen=False, update_auto=update_auto)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
sys.stdout = StringIO()
self.assertEqual(self.dir_count, dir_count)
with open(hosts_file, "r") as f:
contents = f.read()
self.assertEqual(contents, hosts_data)
def test_no_freshen_new_file(self):
hosts_file = os.path.join(self.test_dir, "hosts")
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
dir_count = self.dir_count
prompt_for_update(freshen=False, update_auto=False)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
sys.stdout = StringIO()
self.assertEqual(self.dir_count, dir_count + 1)
with open(hosts_file, "r") as f:
contents = f.read()
self.assertEqual(contents, "")
@mock.patch(builtins() + ".open")
def test_no_freshen_fail_new_file(self, mock_open):
for exc in (IOError, OSError):
mock_open.side_effect = exc("failed open")
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
prompt_for_update(freshen=False, update_auto=False)
output = sys.stdout.getvalue()
expected = ("ERROR: No 'hosts' file in the folder. "
"Try creating one manually.")
self.assertIn(expected, output)
sys.stdout = StringIO()
@mock.patch("updateHostsFile.update_all_sources", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def test_freshen_no_update(self, _, mock_update):
hosts_file = os.path.join(self.test_dir, "hosts")
hosts_data = "This data should not be overwritten"
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
with open(hosts_file, "w") as f:
f.write(hosts_data)
dir_count = self.dir_count
prompt_for_update(freshen=True, update_auto=False)
mock_update.assert_not_called()
mock_update.reset_mock()
output = sys.stdout.getvalue()
expected = ("OK, we'll stick with "
"what we've got locally.")
self.assertIn(expected, output)
sys.stdout = StringIO()
self.assertEqual(self.dir_count, dir_count)
with open(hosts_file, "r") as f:
contents = f.read()
self.assertEqual(contents, hosts_data)
@mock.patch("updateHostsFile.update_all_sources", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
def test_freshen_update(self, _, mock_update):
hosts_file = os.path.join(self.test_dir, "hosts")
hosts_data = "This data should not be overwritten"
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
with open(hosts_file, "w") as f:
f.write(hosts_data)
dir_count = self.dir_count
for update_auto in (False, True):
prompt_for_update(freshen=True, update_auto=update_auto)
mock_update.assert_called_once()
mock_update.reset_mock()
output = sys.stdout.getvalue()
self.assertEqual(output, "")
sys.stdout = StringIO()
self.assertEqual(self.dir_count, dir_count)
with open(hosts_file, "r") as f:
contents = f.read()
self.assertEqual(contents, hosts_data)
def tearDown(self):
BaseStdout.tearDown(self)
BaseStdout.tearDown(self)
class TestPromptForExclusions(BaseStdout):
@mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testSkipPrompt(self, mock_query, mock_display):
prompt_for_exclusions(skip_prompt=True)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
mock_query.assert_not_called()
mock_display.assert_not_called()
@mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testNoSkipPromptNoDisplay(self, mock_query, mock_display):
prompt_for_exclusions(skip_prompt=False)
output = sys.stdout.getvalue()
expected = ("OK, we'll only exclude "
"domains in the whitelist.")
self.assertIn(expected, output)
mock_query.assert_called_once()
mock_display.assert_not_called()
@mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
def testNoSkipPromptDisplay(self, mock_query, mock_display):
prompt_for_exclusions(skip_prompt=False)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
mock_query.assert_called_once()
mock_display.assert_called_once()
class TestPromptForFlushDnsCache(Base):
@mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testFlushCache(self, mock_query, mock_flush):
for prompt_flush in (False, True):
prompt_for_flush_dns_cache(flush_cache=True,
prompt_flush=prompt_flush)
mock_query.assert_not_called()
mock_flush.assert_called_once()
mock_query.reset_mock()
mock_flush.reset_mock()
@mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testNoFlushCacheNoPrompt(self, mock_query, mock_flush):
prompt_for_flush_dns_cache(flush_cache=False,
prompt_flush=False)
mock_query.assert_not_called()
mock_flush.assert_not_called()
@mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testNoFlushCachePromptNoFlush(self, mock_query, mock_flush):
prompt_for_flush_dns_cache(flush_cache=False,
prompt_flush=True)
mock_query.assert_called_once()
mock_flush.assert_not_called()
@mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
def testNoFlushCachePromptFlush(self, mock_query, mock_flush):
prompt_for_flush_dns_cache(flush_cache=False,
prompt_flush=True)
mock_query.assert_called_once()
mock_flush.assert_called_once()
class TestPromptForMove(Base):
def setUp(self):
Base.setUp(self)
self.final_file = "final.txt"
def prompt_for_move(self, **move_params):
return prompt_for_move(self.final_file, **move_params)
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testSkipStaticHosts(self, mock_query, mock_move):
for replace in (False, True):
for auto in (False, True):
move_file = self.prompt_for_move(replace=replace, auto=auto,
skipstatichosts=True)
self.assertFalse(move_file)
mock_query.assert_not_called()
mock_move.assert_not_called()
mock_query.reset_mock()
mock_move.reset_mock()
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testReplaceNoSkipStaticHosts(self, mock_query, mock_move):
for auto in (False, True):
move_file = self.prompt_for_move(replace=True, auto=auto,
skipstatichosts=False)
self.assertTrue(move_file)
mock_query.assert_not_called()
mock_move.assert_called_once()
mock_query.reset_mock()
mock_move.reset_mock()
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testAutoNoSkipStaticHosts(self, mock_query, mock_move):
for replace in (False, True):
move_file = self.prompt_for_move(replace=replace, auto=True,
skipstatichosts=True)
self.assertFalse(move_file)
mock_query.assert_not_called()
mock_move.assert_not_called()
mock_query.reset_mock()
mock_move.reset_mock()
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testPromptNoMove(self, mock_query, mock_move):
move_file = self.prompt_for_move(replace=False, auto=False,
skipstatichosts=False)
self.assertFalse(move_file)
mock_query.assert_called_once()
mock_move.assert_not_called()
mock_query.reset_mock()
mock_move.reset_mock()
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
def testPromptMove(self, mock_query, mock_move):
move_file = self.prompt_for_move(replace=False, auto=False,
skipstatichosts=False)
self.assertTrue(move_file)
mock_query.assert_called_once()
mock_move.assert_called_once()
mock_query.reset_mock()
mock_move.reset_mock()
# End Prompt the User
# Exclusion Logic # Exclusion Logic
class TestGatherCustomExclusions(BaseStdout): class TestGatherCustomExclusions(BaseStdout):

View File

@ -146,8 +146,10 @@ def main():
settings["extensions"] = sorted(list( settings["extensions"] = sorted(list(
set(options["extensions"]).intersection(settings["extensions"]))) set(options["extensions"]).intersection(settings["extensions"])))
prompt_for_update() auto = settings["auto"]
prompt_for_exclusions()
prompt_for_update(freshen=settings["freshen"], update_auto=auto)
prompt_for_exclusions(skip_prompt=auto)
merge_file = create_initial_file() merge_file = create_initial_file()
remove_old_hosts_file(settings["backup"]) remove_old_hosts_file(settings["backup"])
@ -158,11 +160,12 @@ def main():
final_file = remove_dups_and_excl(merge_file) final_file = remove_dups_and_excl(merge_file)
number_of_rules = settings["numberofrules"] number_of_rules = settings["numberofrules"]
skip_static_hosts = settings["skipstatichosts"]
write_opening_header(final_file, extensions=extensions, write_opening_header(final_file, extensions=extensions,
numberofrules=number_of_rules, numberofrules=number_of_rules,
outputsubfolder=output_subfolder, outputsubfolder=output_subfolder,
skipstatichosts=settings["skipstatichosts"]) skipstatichosts=skip_static_hosts)
final_file.close() final_file.close()
if settings["ziphosts"]: if settings["ziphosts"]:
@ -183,63 +186,101 @@ def main():
"{:,}".format(number_of_rules) + "{:,}".format(number_of_rules) +
" unique entries.") " unique entries.")
prompt_for_move(final_file) move_file = prompt_for_move(final_file, auto=auto,
replace=settings["replace"],
skipstatichosts=skip_static_hosts)
# We only flush the DNS cache if we have
# moved a new hosts file into place.
if move_file:
prompt_for_flush_dns_cache(flush_cache=settings["flushdnscache"],
prompt_flush=not auto)
# Prompt the User # Prompt the User
def prompt_for_update(): def prompt_for_update(freshen, update_auto):
""" """
Prompt the user to update all hosts files. Prompt the user to update all hosts files.
If requested, the function will update all data sources after it
checks that a hosts file does indeed exist.
Parameters
----------
freshen : bool
Whether data sources should be updated. This function will return
if it is requested that data sources not be updated.
update_auto : bool
Whether or not to automatically update all data sources.
""" """
# Create hosts file if it doesn't exist. # Create a hosts file if it doesn't exist.
if not os.path.isfile(path_join_robust(BASEDIR_PATH, "hosts")): hosts_file = path_join_robust(BASEDIR_PATH, "hosts")
try:
open(path_join_robust(BASEDIR_PATH, "hosts"), "w+").close()
except Exception:
print_failure("ERROR: No 'hosts' file in the folder,"
"try creating one manually")
if not settings["freshen"]: if not os.path.isfile(hosts_file):
try:
open(hosts_file, "w+").close()
except (IOError, OSError):
# Starting in Python 3.3, IOError is aliased
# OSError. However, we have to catch both for
# Python 2.x failures.
print_failure("ERROR: No 'hosts' file in the folder. "
"Try creating one manually.")
if not freshen:
return return
prompt = "Do you want to update all data sources?" prompt = "Do you want to update all data sources?"
if settings["auto"] or query_yes_no(prompt):
if update_auto or query_yes_no(prompt):
update_all_sources() update_all_sources()
elif not settings["auto"]: elif not update_auto:
print("OK, we'll stick with what we've got locally.") print("OK, we'll stick with what we've got locally.")
def prompt_for_exclusions(): def prompt_for_exclusions(skip_prompt):
""" """
Prompt the user to exclude any custom domains from being blocked. Prompt the user to exclude any custom domains from being blocked.
Parameters
----------
skip_prompt : bool
Whether or not to skip prompting for custom domains to be excluded.
If true, the function returns immediately.
""" """
prompt = ("Do you want to exclude any domains?\n" prompt = ("Do you want to exclude any domains?\n"
"For example, hulu.com video streaming must be able to access " "For example, hulu.com video streaming must be able to access "
"its tracking and ad servers in order to play video.") "its tracking and ad servers in order to play video.")
if not settings["auto"]: if not skip_prompt:
if query_yes_no(prompt): if query_yes_no(prompt):
display_exclusion_options() display_exclusion_options()
else: else:
print("OK, we'll only exclude domains in the whitelist.") print("OK, we'll only exclude domains in the whitelist.")
def prompt_for_flush_dns_cache(): def prompt_for_flush_dns_cache(flush_cache, prompt_flush):
""" """
Prompt the user to flush the DNS cache. Prompt the user to flush the DNS cache.
Parameters
----------
flush_cache : bool
Whether to flush the DNS cache without prompting.
prompt_flush : bool
If `flush_cache` is False, whether we should prompt for flushing the
cache. Otherwise, the function returns immediately.
""" """
if settings["flushdnscache"]: if flush_cache:
flush_dns_cache() flush_dns_cache()
elif prompt_flush:
if not settings["auto"]:
if query_yes_no("Attempt to flush the DNS cache?"): if query_yes_no("Attempt to flush the DNS cache?"):
flush_dns_cache() flush_dns_cache()
def prompt_for_move(final_file): def prompt_for_move(final_file, **move_params):
""" """
Prompt the user to move the newly created hosts file to its designated Prompt the user to move the newly created hosts file to its designated
location in the OS. location in the OS.
@ -248,11 +289,25 @@ def prompt_for_move(final_file):
---------- ----------
final_file : file final_file : file
The file object that contains the newly created hosts data. The file object that contains the newly created hosts data.
move_params : kwargs
Dictionary providing additional parameters for moving the hosts file
into place. Currently, those fields are:
1) auto
2) replace
3) skipstatichosts
Returns
-------
move_file : bool
Whether or not the final hosts file was moved.
""" """
if settings["replace"] and not settings["skipstatichosts"]: skip_static_hosts = move_params["skipstatichosts"]
if move_params["replace"] and not skip_static_hosts:
move_file = True move_file = True
elif settings["auto"] or settings["skipstatichosts"]: elif move_params["auto"] or skip_static_hosts:
move_file = False move_file = False
else: else:
prompt = ("Do you want to replace your existing hosts file " + prompt = ("Do you want to replace your existing hosts file " +
@ -261,7 +316,8 @@ def prompt_for_move(final_file):
if move_file: if move_file:
move_hosts_file_into_place(final_file) move_hosts_file_into_place(final_file)
prompt_for_flush_dns_cache()
return move_file
# End Prompt the User # End Prompt the User