Refactor out global settings usage in user prompt

This commit is contained in:
gfyoung 2017-07-09 12:00:11 -07:00
parent a819131927
commit 4b96f3f34a
2 changed files with 391 additions and 29 deletions

View File

@ -10,8 +10,10 @@ from updateHostsFile import (Colors, PY3, colorize, flush_dns_cache,
get_file_by_url, is_valid_domain_format,
move_hosts_file_into_place, normalize_rule,
path_join_robust, print_failure, print_success,
supports_color, query_yes_no, recursive_glob,
remove_old_hosts_file, strip_rule,
prompt_for_exclusions, prompt_for_move,
prompt_for_flush_dns_cache, prompt_for_update,
query_yes_no, recursive_glob,
remove_old_hosts_file, supports_color, strip_rule,
update_readme_data, write_data,
write_opening_header)
import updateHostsFile
@ -33,7 +35,7 @@ else:
import mock
# Base Test Classes
# Test Helper Objects
class Base(unittest.TestCase):
@staticmethod
@ -66,7 +68,14 @@ class BaseMockDir(Base):
def tearDown(self):
shutil.rmtree(self.test_dir)
# End Base Test Classes
def builtins():
if PY3:
return "builtins"
else:
return "__builtin__"
# End Test Helper Objects
# Project Settings
@ -107,6 +116,303 @@ class TestGetDefaults(Base):
# End Project Settings
# Prompt the User
class TestPromptForUpdate(BaseStdout, BaseMockDir):
def setUp(self):
BaseStdout.setUp(self)
BaseMockDir.setUp(self)
def test_no_freshen_no_new_file(self):
hosts_file = os.path.join(self.test_dir, "hosts")
hosts_data = "This data should not be overwritten"
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
with open(hosts_file, "w") as f:
f.write(hosts_data)
for update_auto in (False, True):
dir_count = self.dir_count
prompt_for_update(freshen=False, update_auto=update_auto)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
sys.stdout = StringIO()
self.assertEqual(self.dir_count, dir_count)
with open(hosts_file, "r") as f:
contents = f.read()
self.assertEqual(contents, hosts_data)
def test_no_freshen_new_file(self):
hosts_file = os.path.join(self.test_dir, "hosts")
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
dir_count = self.dir_count
prompt_for_update(freshen=False, update_auto=False)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
sys.stdout = StringIO()
self.assertEqual(self.dir_count, dir_count + 1)
with open(hosts_file, "r") as f:
contents = f.read()
self.assertEqual(contents, "")
@mock.patch(builtins() + ".open")
def test_no_freshen_fail_new_file(self, mock_open):
for exc in (IOError, OSError):
mock_open.side_effect = exc("failed open")
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
prompt_for_update(freshen=False, update_auto=False)
output = sys.stdout.getvalue()
expected = ("ERROR: No 'hosts' file in the folder. "
"Try creating one manually.")
self.assertIn(expected, output)
sys.stdout = StringIO()
@mock.patch("updateHostsFile.update_all_sources", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def test_freshen_no_update(self, _, mock_update):
hosts_file = os.path.join(self.test_dir, "hosts")
hosts_data = "This data should not be overwritten"
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
with open(hosts_file, "w") as f:
f.write(hosts_data)
dir_count = self.dir_count
prompt_for_update(freshen=True, update_auto=False)
mock_update.assert_not_called()
mock_update.reset_mock()
output = sys.stdout.getvalue()
expected = ("OK, we'll stick with "
"what we've got locally.")
self.assertIn(expected, output)
sys.stdout = StringIO()
self.assertEqual(self.dir_count, dir_count)
with open(hosts_file, "r") as f:
contents = f.read()
self.assertEqual(contents, hosts_data)
@mock.patch("updateHostsFile.update_all_sources", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
def test_freshen_update(self, _, mock_update):
hosts_file = os.path.join(self.test_dir, "hosts")
hosts_data = "This data should not be overwritten"
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
with open(hosts_file, "w") as f:
f.write(hosts_data)
dir_count = self.dir_count
for update_auto in (False, True):
prompt_for_update(freshen=True, update_auto=update_auto)
mock_update.assert_called_once()
mock_update.reset_mock()
output = sys.stdout.getvalue()
self.assertEqual(output, "")
sys.stdout = StringIO()
self.assertEqual(self.dir_count, dir_count)
with open(hosts_file, "r") as f:
contents = f.read()
self.assertEqual(contents, hosts_data)
def tearDown(self):
BaseStdout.tearDown(self)
BaseStdout.tearDown(self)
class TestPromptForExclusions(BaseStdout):
@mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testSkipPrompt(self, mock_query, mock_display):
prompt_for_exclusions(skip_prompt=True)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
mock_query.assert_not_called()
mock_display.assert_not_called()
@mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testNoSkipPromptNoDisplay(self, mock_query, mock_display):
prompt_for_exclusions(skip_prompt=False)
output = sys.stdout.getvalue()
expected = ("OK, we'll only exclude "
"domains in the whitelist.")
self.assertIn(expected, output)
mock_query.assert_called_once()
mock_display.assert_not_called()
@mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
def testNoSkipPromptDisplay(self, mock_query, mock_display):
prompt_for_exclusions(skip_prompt=False)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
mock_query.assert_called_once()
mock_display.assert_called_once()
class TestPromptForFlushDnsCache(Base):
@mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testFlushCache(self, mock_query, mock_flush):
for prompt_flush in (False, True):
prompt_for_flush_dns_cache(flush_cache=True,
prompt_flush=prompt_flush)
mock_query.assert_not_called()
mock_flush.assert_called_once()
mock_query.reset_mock()
mock_flush.reset_mock()
@mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testNoFlushCacheNoPrompt(self, mock_query, mock_flush):
prompt_for_flush_dns_cache(flush_cache=False,
prompt_flush=False)
mock_query.assert_not_called()
mock_flush.assert_not_called()
@mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testNoFlushCachePromptNoFlush(self, mock_query, mock_flush):
prompt_for_flush_dns_cache(flush_cache=False,
prompt_flush=True)
mock_query.assert_called_once()
mock_flush.assert_not_called()
@mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
def testNoFlushCachePromptFlush(self, mock_query, mock_flush):
prompt_for_flush_dns_cache(flush_cache=False,
prompt_flush=True)
mock_query.assert_called_once()
mock_flush.assert_called_once()
class TestPromptForMove(Base):
def setUp(self):
Base.setUp(self)
self.final_file = "final.txt"
def prompt_for_move(self, **move_params):
return prompt_for_move(self.final_file, **move_params)
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testSkipStaticHosts(self, mock_query, mock_move):
for replace in (False, True):
for auto in (False, True):
move_file = self.prompt_for_move(replace=replace, auto=auto,
skipstatichosts=True)
self.assertFalse(move_file)
mock_query.assert_not_called()
mock_move.assert_not_called()
mock_query.reset_mock()
mock_move.reset_mock()
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testReplaceNoSkipStaticHosts(self, mock_query, mock_move):
for auto in (False, True):
move_file = self.prompt_for_move(replace=True, auto=auto,
skipstatichosts=False)
self.assertTrue(move_file)
mock_query.assert_not_called()
mock_move.assert_called_once()
mock_query.reset_mock()
mock_move.reset_mock()
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testAutoNoSkipStaticHosts(self, mock_query, mock_move):
for replace in (False, True):
move_file = self.prompt_for_move(replace=replace, auto=True,
skipstatichosts=True)
self.assertFalse(move_file)
mock_query.assert_not_called()
mock_move.assert_not_called()
mock_query.reset_mock()
mock_move.reset_mock()
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testPromptNoMove(self, mock_query, mock_move):
move_file = self.prompt_for_move(replace=False, auto=False,
skipstatichosts=False)
self.assertFalse(move_file)
mock_query.assert_called_once()
mock_move.assert_not_called()
mock_query.reset_mock()
mock_move.reset_mock()
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
def testPromptMove(self, mock_query, mock_move):
move_file = self.prompt_for_move(replace=False, auto=False,
skipstatichosts=False)
self.assertTrue(move_file)
mock_query.assert_called_once()
mock_move.assert_called_once()
mock_query.reset_mock()
mock_move.reset_mock()
# End Prompt the User
# Exclusion Logic
class TestGatherCustomExclusions(BaseStdout):

View File

@ -146,8 +146,10 @@ def main():
settings["extensions"] = sorted(list(
set(options["extensions"]).intersection(settings["extensions"])))
prompt_for_update()
prompt_for_exclusions()
auto = settings["auto"]
prompt_for_update(freshen=settings["freshen"], update_auto=auto)
prompt_for_exclusions(skip_prompt=auto)
merge_file = create_initial_file()
remove_old_hosts_file(settings["backup"])
@ -158,11 +160,12 @@ def main():
final_file = remove_dups_and_excl(merge_file)
number_of_rules = settings["numberofrules"]
skip_static_hosts = settings["skipstatichosts"]
write_opening_header(final_file, extensions=extensions,
numberofrules=number_of_rules,
outputsubfolder=output_subfolder,
skipstatichosts=settings["skipstatichosts"])
skipstatichosts=skip_static_hosts)
final_file.close()
if settings["ziphosts"]:
@ -183,63 +186,101 @@ def main():
"{:,}".format(number_of_rules) +
" unique entries.")
prompt_for_move(final_file)
move_file = prompt_for_move(final_file, auto=auto,
replace=settings["replace"],
skipstatichosts=skip_static_hosts)
# We only flush the DNS cache if we have
# moved a new hosts file into place.
if move_file:
prompt_for_flush_dns_cache(flush_cache=settings["flushdnscache"],
prompt_flush=not auto)
# Prompt the User
def prompt_for_update():
def prompt_for_update(freshen, update_auto):
"""
Prompt the user to update all hosts files.
If requested, the function will update all data sources after it
checks that a hosts file does indeed exist.
Parameters
----------
freshen : bool
Whether data sources should be updated. This function will return
if it is requested that data sources not be updated.
update_auto : bool
Whether or not to automatically update all data sources.
"""
# Create hosts file if it doesn't exist.
if not os.path.isfile(path_join_robust(BASEDIR_PATH, "hosts")):
try:
open(path_join_robust(BASEDIR_PATH, "hosts"), "w+").close()
except Exception:
print_failure("ERROR: No 'hosts' file in the folder,"
"try creating one manually")
# Create a hosts file if it doesn't exist.
hosts_file = path_join_robust(BASEDIR_PATH, "hosts")
if not settings["freshen"]:
if not os.path.isfile(hosts_file):
try:
open(hosts_file, "w+").close()
except (IOError, OSError):
# Starting in Python 3.3, IOError is aliased
# OSError. However, we have to catch both for
# Python 2.x failures.
print_failure("ERROR: No 'hosts' file in the folder. "
"Try creating one manually.")
if not freshen:
return
prompt = "Do you want to update all data sources?"
if settings["auto"] or query_yes_no(prompt):
if update_auto or query_yes_no(prompt):
update_all_sources()
elif not settings["auto"]:
elif not update_auto:
print("OK, we'll stick with what we've got locally.")
def prompt_for_exclusions():
def prompt_for_exclusions(skip_prompt):
"""
Prompt the user to exclude any custom domains from being blocked.
Parameters
----------
skip_prompt : bool
Whether or not to skip prompting for custom domains to be excluded.
If true, the function returns immediately.
"""
prompt = ("Do you want to exclude any domains?\n"
"For example, hulu.com video streaming must be able to access "
"its tracking and ad servers in order to play video.")
if not settings["auto"]:
if not skip_prompt:
if query_yes_no(prompt):
display_exclusion_options()
else:
print("OK, we'll only exclude domains in the whitelist.")
def prompt_for_flush_dns_cache():
def prompt_for_flush_dns_cache(flush_cache, prompt_flush):
"""
Prompt the user to flush the DNS cache.
Parameters
----------
flush_cache : bool
Whether to flush the DNS cache without prompting.
prompt_flush : bool
If `flush_cache` is False, whether we should prompt for flushing the
cache. Otherwise, the function returns immediately.
"""
if settings["flushdnscache"]:
if flush_cache:
flush_dns_cache()
if not settings["auto"]:
elif prompt_flush:
if query_yes_no("Attempt to flush the DNS cache?"):
flush_dns_cache()
def prompt_for_move(final_file):
def prompt_for_move(final_file, **move_params):
"""
Prompt the user to move the newly created hosts file to its designated
location in the OS.
@ -248,11 +289,25 @@ def prompt_for_move(final_file):
----------
final_file : file
The file object that contains the newly created hosts data.
move_params : kwargs
Dictionary providing additional parameters for moving the hosts file
into place. Currently, those fields are:
1) auto
2) replace
3) skipstatichosts
Returns
-------
move_file : bool
Whether or not the final hosts file was moved.
"""
if settings["replace"] and not settings["skipstatichosts"]:
skip_static_hosts = move_params["skipstatichosts"]
if move_params["replace"] and not skip_static_hosts:
move_file = True
elif settings["auto"] or settings["skipstatichosts"]:
elif move_params["auto"] or skip_static_hosts:
move_file = False
else:
prompt = ("Do you want to replace your existing hosts file " +
@ -261,7 +316,8 @@ def prompt_for_move(final_file):
if move_file:
move_hosts_file_into_place(final_file)
prompt_for_flush_dns_cache()
return move_file
# End Prompt the User