diff --git a/.travis.yml b/.travis.yml index 3f4c0778a..6b9036e86 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,6 +19,7 @@ os: env: - PYTHON_VERSION="2.7" + - PYTHON_VERSION="3.5" - PYTHON_VERSION="3.6" before_install: diff --git a/testUpdateHostsFile.py b/testUpdateHostsFile.py index ff3133186..139bd8955 100644 --- a/testUpdateHostsFile.py +++ b/testUpdateHostsFile.py @@ -5,17 +5,16 @@ # # Python script for testing updateHostFiles.py -from updateHostsFile import (Colors, PY3, colorize, flush_dns_cache, - gather_custom_exclusions, get_defaults, - get_file_by_url, is_valid_domain_format, - move_hosts_file_into_place, normalize_rule, - path_join_robust, print_failure, print_success, - prompt_for_exclusions, prompt_for_move, - prompt_for_flush_dns_cache, prompt_for_update, - query_yes_no, recursive_glob, - remove_old_hosts_file, supports_color, strip_rule, - update_readme_data, write_data, - write_opening_header) +from updateHostsFile import ( + Colors, PY3, colorize, display_exclusion_options, exclude_domain, + flush_dns_cache, gather_custom_exclusions, get_defaults, get_file_by_url, + is_valid_domain_format, matches_exclusions, move_hosts_file_into_place, + normalize_rule, path_join_robust, print_failure, print_success, + prompt_for_exclusions, prompt_for_move, prompt_for_flush_dns_cache, + prompt_for_update, query_yes_no, recursive_glob, remove_old_hosts_file, + supports_color, strip_rule, update_readme_data, write_data, + write_opening_header) + import updateHostsFile import unittest import tempfile @@ -24,6 +23,7 @@ import shutil import json import sys import os +import re if PY3: from io import BytesIO, StringIO @@ -260,21 +260,20 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir): class TestPromptForExclusions(BaseStdout): - @mock.patch("updateHostsFile.display_exclusion_options", return_value=0) @mock.patch("updateHostsFile.query_yes_no", return_value=False) - def testSkipPrompt(self, mock_query, mock_display): - prompt_for_exclusions(skip_prompt=True) + def testSkipPrompt(self, mock_query): + gather_exclusions = prompt_for_exclusions(skip_prompt=True) + self.assertFalse(gather_exclusions) output = sys.stdout.getvalue() self.assertEqual(output, "") mock_query.assert_not_called() - mock_display.assert_not_called() - @mock.patch("updateHostsFile.display_exclusion_options", return_value=0) @mock.patch("updateHostsFile.query_yes_no", return_value=False) - def testNoSkipPromptNoDisplay(self, mock_query, mock_display): - prompt_for_exclusions(skip_prompt=False) + def testNoSkipPromptNoDisplay(self, mock_query): + gather_exclusions = prompt_for_exclusions(skip_prompt=False) + self.assertFalse(gather_exclusions) output = sys.stdout.getvalue() expected = ("OK, we'll only exclude " @@ -282,18 +281,16 @@ class TestPromptForExclusions(BaseStdout): self.assertIn(expected, output) self.assert_called_once(mock_query) - mock_display.assert_not_called() - @mock.patch("updateHostsFile.display_exclusion_options", return_value=0) @mock.patch("updateHostsFile.query_yes_no", return_value=True) - def testNoSkipPromptDisplay(self, mock_query, mock_display): - prompt_for_exclusions(skip_prompt=False) + def testNoSkipPromptDisplay(self, mock_query): + gather_exclusions = prompt_for_exclusions(skip_prompt=False) + self.assertTrue(gather_exclusions) output = sys.stdout.getvalue() self.assertEqual(output, "") self.assert_called_once(mock_query) - self.assert_called_once(mock_display) class TestPromptForFlushDnsCache(Base): @@ -420,6 +417,52 @@ class TestPromptForMove(Base): # Exclusion Logic +class TestDisplayExclusionsOptions(Base): + + @mock.patch("updateHostsFile.query_yes_no", return_value=0) + @mock.patch("updateHostsFile.exclude_domain", return_value=None) + @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None) + def test_no_exclusions(self, mock_gather, mock_exclude, _): + common_exclusions = [] + display_exclusion_options(common_exclusions, "foo", []) + + mock_gather.assert_not_called() + mock_exclude.assert_not_called() + + @mock.patch("updateHostsFile.query_yes_no", side_effect=[1, 1, 0]) + @mock.patch("updateHostsFile.exclude_domain", return_value=None) + @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None) + def test_only_common_exclusions(self, mock_gather, mock_exclude, _): + common_exclusions = ["foo", "bar"] + display_exclusion_options(common_exclusions, "foo", []) + + mock_gather.assert_not_called() + + exclude_calls = [mock.call("foo", "foo", []), + mock.call("bar", "foo", None)] + mock_exclude.assert_has_calls(exclude_calls) + + @mock.patch("updateHostsFile.query_yes_no", side_effect=[0, 0, 1]) + @mock.patch("updateHostsFile.exclude_domain", return_value=None) + @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None) + def test_gather_exclusions(self, mock_gather, mock_exclude, _): + common_exclusions = ["foo", "bar"] + display_exclusion_options(common_exclusions, "foo", []) + + mock_exclude.assert_not_called() + self.assert_called_once(mock_gather) + + @mock.patch("updateHostsFile.query_yes_no", side_effect=[1, 0, 1]) + @mock.patch("updateHostsFile.exclude_domain", return_value=None) + @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None) + def test_mixture_gather_exclusions(self, mock_gather, mock_exclude, _): + common_exclusions = ["foo", "bar"] + display_exclusion_options(common_exclusions, "foo", []) + + mock_exclude.assert_called_once_with("foo", "foo", []) + self.assert_called_once(mock_gather) + + class TestGatherCustomExclusions(BaseStdout): # Can only test in the invalid domain case @@ -427,7 +470,7 @@ class TestGatherCustomExclusions(BaseStdout): @mock.patch("updateHostsFile.raw_input", side_effect=["foo", "no"]) @mock.patch("updateHostsFile.is_valid_domain_format", return_value=False) def test_basic(self, *_): - gather_custom_exclusions() + gather_custom_exclusions("foo", []) expected = "Do you have more domains you want to enter? [Y/n]" output = sys.stdout.getvalue() @@ -437,12 +480,70 @@ class TestGatherCustomExclusions(BaseStdout): "bar", "no"]) @mock.patch("updateHostsFile.is_valid_domain_format", return_value=False) def test_multiple(self, *_): - gather_custom_exclusions() + gather_custom_exclusions("foo", []) expected = ("Do you have more domains you want to enter? [Y/n] " "Do you have more domains you want to enter? [Y/n]") output = sys.stdout.getvalue() self.assertIn(expected, output) + + +class TestExcludeDomain(Base): + + def test_invalid_exclude_domain(self): + exclusion_regexes = [] + exclusion_pattern = "*.com" + + for domain in ["google.com", "hulu.com", "adaway.org"]: + self.assertRaises(re.error, exclude_domain, domain, + exclusion_pattern, exclusion_regexes) + + self.assertListEqual(exclusion_regexes, []) + + def test_valid_exclude_domain(self): + exp_count = 0 + expected_regexes = [] + exclusion_regexes = [] + exclusion_pattern = "[a-z]\." + + for domain in ["google.com", "hulu.com", "adaway.org"]: + self.assertEqual(len(exclusion_regexes), exp_count) + + exclusion_regexes = exclude_domain(domain, exclusion_pattern, + exclusion_regexes) + expected_regex = re.compile(exclusion_pattern + domain) + + expected_regexes.append(expected_regex) + exp_count += 1 + + self.assertEqual(len(exclusion_regexes), exp_count) + self.assertListEqual(exclusion_regexes, expected_regexes) + + +class TestMatchesExclusions(Base): + + def test_no_match_empty_list(self): + exclusion_regexes = [] + + for domain in ["1.2.3.4 localhost", "5.6.7.8 hulu.com", + "9.1.2.3 yahoo.com", "4.5.6.7 cloudfront.net"]: + self.assertFalse(matches_exclusions(domain, exclusion_regexes)) + + def test_no_match_list(self): + exclusion_regexes = [".*\.org", ".*\.edu"] + exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes] + + for domain in ["1.2.3.4 localhost", "5.6.7.8 hulu.com", + "9.1.2.3 yahoo.com", "4.5.6.7 cloudfront.net"]: + self.assertFalse(matches_exclusions(domain, exclusion_regexes)) + + def test_match_list(self): + exclusion_regexes = [".*\.com", ".*\.org", ".*\.edu"] + exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes] + + for domain in ["5.6.7.8 hulu.com", "9.1.2.3 yahoo.com", + "4.5.6.7 adaway.org", "8.9.1.2 education.edu"]: + self.assertTrue(matches_exclusions(domain, exclusion_regexes)) # End Exclusion Logic diff --git a/updateHostsFile.py b/updateHostsFile.py index 96c38706a..dbcfe7696 100644 --- a/updateHostsFile.py +++ b/updateHostsFile.py @@ -147,9 +147,18 @@ def main(): set(options["extensions"]).intersection(settings["extensions"]))) auto = settings["auto"] + exclusion_regexes = settings["exclusionregexs"] prompt_for_update(freshen=settings["freshen"], update_auto=auto) - prompt_for_exclusions(skip_prompt=auto) + gather_exclusions = prompt_for_exclusions(skip_prompt=auto) + + if gather_exclusions: + common_exclusions = settings["commonexclusions"] + exclusion_pattern = settings["exclusionpattern"] + exclusion_regexes = display_exclusion_options( + common_exclusions=common_exclusions, + exclusion_pattern=exclusion_pattern, + exclusion_regexes=exclusion_regexes) merge_file = create_initial_file() remove_old_hosts_file(settings["backup"]) @@ -157,7 +166,7 @@ def main(): extensions = settings["extensions"] output_subfolder = settings["outputsubfolder"] - final_file = remove_dups_and_excl(merge_file) + final_file = remove_dups_and_excl(merge_file, exclusion_regexes) number_of_rules = settings["numberofrules"] skip_static_hosts = settings["skipstatichosts"] @@ -247,6 +256,12 @@ def prompt_for_exclusions(skip_prompt): skip_prompt : bool Whether or not to skip prompting for custom domains to be excluded. If true, the function returns immediately. + + Returns + ------- + gather_exclusions : bool + Whether or not we should proceed to prompt the user to exclude any + custom domains beyond those in the whitelist. """ prompt = ("Do you want to exclude any domains?\n" @@ -255,10 +270,12 @@ def prompt_for_exclusions(skip_prompt): if not skip_prompt: if query_yes_no(prompt): - display_exclusion_options() + return True else: print("OK, we'll only exclude domains in the whitelist.") + return False + def prompt_for_flush_dns_cache(flush_cache, prompt_flush): """ @@ -322,29 +339,65 @@ def prompt_for_move(final_file, **move_params): # Exclusion logic -def display_exclusion_options(): +def display_exclusion_options(common_exclusions, exclusion_pattern, + exclusion_regexes): """ Display the exclusion options to the user. This function checks whether a user wants to exclude particular domains, and if so, excludes them. + + Parameters + ---------- + common_exclusions : list + A list of common domains that are excluded from being blocked. One + example is Hulu. This setting is set directly in the script and cannot + be overwritten by the user. + exclusion_pattern : str + The exclusion pattern with which to create the domain regex. + exclusion_regexes : list + The list of regex patterns used to exclude domains. + + Returns + ------- + aug_exclusion_regexes : list + The original list of regex patterns potentially with additional + patterns from domains that user chooses to exclude. """ - for exclusion_option in settings["commonexclusions"]: + for exclusion_option in common_exclusions: prompt = "Do you want to exclude the domain " + exclusion_option + " ?" if query_yes_no(prompt): - exclude_domain(exclusion_option) + exclusion_regexes = exclude_domain(exclusion_option, + exclusion_pattern, + exclusion_regexes) else: continue if query_yes_no("Do you want to exclude any other domains?"): - gather_custom_exclusions() + exclusion_regexes = gather_custom_exclusions(exclusion_pattern, + exclusion_regexes) + + return exclusion_regexes -def gather_custom_exclusions(): +def gather_custom_exclusions(exclusion_pattern, exclusion_regexes): """ Gather custom exclusions from the user. + + Parameters + ---------- + exclusion_pattern : str + The exclusion pattern with which to create the domain regex. + exclusion_regexes : list + The list of regex patterns used to exclude domains. + + Returns + ------- + aug_exclusion_regexes : list + The original list of regex patterns potentially with additional + patterns from domains that user chooses to exclude. """ # We continue running this while-loop until the user @@ -355,28 +408,46 @@ def gather_custom_exclusions(): user_domain = raw_input(domain_prompt) if is_valid_domain_format(user_domain): - exclude_domain(user_domain) + exclusion_regexes = exclude_domain(user_domain, exclusion_pattern, + exclusion_regexes) continue_prompt = "Do you have more domains you want to enter?" if not query_yes_no(continue_prompt): - return + break + + return exclusion_regexes -def exclude_domain(domain): +def exclude_domain(domain, exclusion_pattern, exclusion_regexes): """ Exclude a domain from being blocked. + This create the domain regex by which to exclude this domain and appends + it a list of already-existing exclusion regexes. + Parameters ---------- domain : str The filename or regex pattern to exclude. + exclusion_pattern : str + The exclusion pattern with which to create the domain regex. + exclusion_regexes : list + The list of regex patterns used to exclude domains. + + Returns + ------- + aug_exclusion_regexes : list + The original list of regex patterns with one additional pattern from + the `domain` input. """ - settings["exclusionregexs"].append(re.compile( - settings["exclusionpattern"] + domain)) + exclusion_regex = re.compile(exclusion_pattern + domain) + exclusion_regexes.append(exclusion_regex) + + return exclusion_regexes -def matches_exclusions(stripped_rule): +def matches_exclusions(stripped_rule, exclusion_regexes): """ Check whether a rule matches an exclusion rule we already provided. @@ -387,6 +458,8 @@ def matches_exclusions(stripped_rule): ---------- stripped_rule : str The rule that we are checking. + exclusion_regexes : list + The list of regex patterns used to exclude domains. Returns ------- @@ -395,9 +468,11 @@ def matches_exclusions(stripped_rule): """ stripped_domain = stripped_rule.split()[1] - for exclusionRegex in settings["exclusionregexs"]: + + for exclusionRegex in exclusion_regexes: if exclusionRegex.search(stripped_domain): return True + return False # End Exclusion Logic @@ -479,7 +554,7 @@ def create_initial_file(): return merge_file -def remove_dups_and_excl(merge_file): +def remove_dups_and_excl(merge_file, exclusion_regexes): """ Remove duplicates and remove hosts that we are excluding. @@ -490,6 +565,8 @@ def remove_dups_and_excl(merge_file): ---------- merge_file : file The file object that contains the hostnames that we are pruning. + exclusion_regexes : list + The list of regex patterns used to exclude domains. """ number_of_rules = settings["numberofrules"] @@ -532,7 +609,8 @@ def remove_dups_and_excl(merge_file): continue stripped_rule = strip_rule(line) # strip comments - if not stripped_rule or matches_exclusions(stripped_rule): + if not stripped_rule or matches_exclusions(stripped_rule, + exclusion_regexes): continue # Normalize rule