Merge pull request #366 from gfyoung/settings-refactor

Refactor out global settings usage in exclusions
This commit is contained in:
Steven Black 2017-08-07 12:26:54 -04:00 committed by GitHub
commit b7c65a10dc
3 changed files with 222 additions and 42 deletions

View File

@ -19,6 +19,7 @@ os:
env:
- PYTHON_VERSION="2.7"
- PYTHON_VERSION="3.5"
- PYTHON_VERSION="3.6"
before_install:

View File

@ -5,17 +5,16 @@
#
# Python script for testing updateHostFiles.py
from updateHostsFile import (Colors, PY3, colorize, flush_dns_cache,
gather_custom_exclusions, get_defaults,
get_file_by_url, is_valid_domain_format,
move_hosts_file_into_place, normalize_rule,
path_join_robust, print_failure, print_success,
prompt_for_exclusions, prompt_for_move,
prompt_for_flush_dns_cache, prompt_for_update,
query_yes_no, recursive_glob,
remove_old_hosts_file, supports_color, strip_rule,
update_readme_data, write_data,
write_opening_header)
from updateHostsFile import (
Colors, PY3, colorize, display_exclusion_options, exclude_domain,
flush_dns_cache, gather_custom_exclusions, get_defaults, get_file_by_url,
is_valid_domain_format, matches_exclusions, move_hosts_file_into_place,
normalize_rule, path_join_robust, print_failure, print_success,
prompt_for_exclusions, prompt_for_move, prompt_for_flush_dns_cache,
prompt_for_update, query_yes_no, recursive_glob, remove_old_hosts_file,
supports_color, strip_rule, update_readme_data, write_data,
write_opening_header)
import updateHostsFile
import unittest
import tempfile
@ -24,6 +23,7 @@ import shutil
import json
import sys
import os
import re
if PY3:
from io import BytesIO, StringIO
@ -260,21 +260,20 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
class TestPromptForExclusions(BaseStdout):
@mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testSkipPrompt(self, mock_query, mock_display):
prompt_for_exclusions(skip_prompt=True)
def testSkipPrompt(self, mock_query):
gather_exclusions = prompt_for_exclusions(skip_prompt=True)
self.assertFalse(gather_exclusions)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
mock_query.assert_not_called()
mock_display.assert_not_called()
@mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testNoSkipPromptNoDisplay(self, mock_query, mock_display):
prompt_for_exclusions(skip_prompt=False)
def testNoSkipPromptNoDisplay(self, mock_query):
gather_exclusions = prompt_for_exclusions(skip_prompt=False)
self.assertFalse(gather_exclusions)
output = sys.stdout.getvalue()
expected = ("OK, we'll only exclude "
@ -282,18 +281,16 @@ class TestPromptForExclusions(BaseStdout):
self.assertIn(expected, output)
self.assert_called_once(mock_query)
mock_display.assert_not_called()
@mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
def testNoSkipPromptDisplay(self, mock_query, mock_display):
prompt_for_exclusions(skip_prompt=False)
def testNoSkipPromptDisplay(self, mock_query):
gather_exclusions = prompt_for_exclusions(skip_prompt=False)
self.assertTrue(gather_exclusions)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
self.assert_called_once(mock_query)
self.assert_called_once(mock_display)
class TestPromptForFlushDnsCache(Base):
@ -420,6 +417,52 @@ class TestPromptForMove(Base):
# Exclusion Logic
class TestDisplayExclusionsOptions(Base):
@mock.patch("updateHostsFile.query_yes_no", return_value=0)
@mock.patch("updateHostsFile.exclude_domain", return_value=None)
@mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
def test_no_exclusions(self, mock_gather, mock_exclude, _):
common_exclusions = []
display_exclusion_options(common_exclusions, "foo", [])
mock_gather.assert_not_called()
mock_exclude.assert_not_called()
@mock.patch("updateHostsFile.query_yes_no", side_effect=[1, 1, 0])
@mock.patch("updateHostsFile.exclude_domain", return_value=None)
@mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
def test_only_common_exclusions(self, mock_gather, mock_exclude, _):
common_exclusions = ["foo", "bar"]
display_exclusion_options(common_exclusions, "foo", [])
mock_gather.assert_not_called()
exclude_calls = [mock.call("foo", "foo", []),
mock.call("bar", "foo", None)]
mock_exclude.assert_has_calls(exclude_calls)
@mock.patch("updateHostsFile.query_yes_no", side_effect=[0, 0, 1])
@mock.patch("updateHostsFile.exclude_domain", return_value=None)
@mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
def test_gather_exclusions(self, mock_gather, mock_exclude, _):
common_exclusions = ["foo", "bar"]
display_exclusion_options(common_exclusions, "foo", [])
mock_exclude.assert_not_called()
self.assert_called_once(mock_gather)
@mock.patch("updateHostsFile.query_yes_no", side_effect=[1, 0, 1])
@mock.patch("updateHostsFile.exclude_domain", return_value=None)
@mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
def test_mixture_gather_exclusions(self, mock_gather, mock_exclude, _):
common_exclusions = ["foo", "bar"]
display_exclusion_options(common_exclusions, "foo", [])
mock_exclude.assert_called_once_with("foo", "foo", [])
self.assert_called_once(mock_gather)
class TestGatherCustomExclusions(BaseStdout):
# Can only test in the invalid domain case
@ -427,7 +470,7 @@ class TestGatherCustomExclusions(BaseStdout):
@mock.patch("updateHostsFile.raw_input", side_effect=["foo", "no"])
@mock.patch("updateHostsFile.is_valid_domain_format", return_value=False)
def test_basic(self, *_):
gather_custom_exclusions()
gather_custom_exclusions("foo", [])
expected = "Do you have more domains you want to enter? [Y/n]"
output = sys.stdout.getvalue()
@ -437,12 +480,70 @@ class TestGatherCustomExclusions(BaseStdout):
"bar", "no"])
@mock.patch("updateHostsFile.is_valid_domain_format", return_value=False)
def test_multiple(self, *_):
gather_custom_exclusions()
gather_custom_exclusions("foo", [])
expected = ("Do you have more domains you want to enter? [Y/n] "
"Do you have more domains you want to enter? [Y/n]")
output = sys.stdout.getvalue()
self.assertIn(expected, output)
class TestExcludeDomain(Base):
def test_invalid_exclude_domain(self):
exclusion_regexes = []
exclusion_pattern = "*.com"
for domain in ["google.com", "hulu.com", "adaway.org"]:
self.assertRaises(re.error, exclude_domain, domain,
exclusion_pattern, exclusion_regexes)
self.assertListEqual(exclusion_regexes, [])
def test_valid_exclude_domain(self):
exp_count = 0
expected_regexes = []
exclusion_regexes = []
exclusion_pattern = "[a-z]\."
for domain in ["google.com", "hulu.com", "adaway.org"]:
self.assertEqual(len(exclusion_regexes), exp_count)
exclusion_regexes = exclude_domain(domain, exclusion_pattern,
exclusion_regexes)
expected_regex = re.compile(exclusion_pattern + domain)
expected_regexes.append(expected_regex)
exp_count += 1
self.assertEqual(len(exclusion_regexes), exp_count)
self.assertListEqual(exclusion_regexes, expected_regexes)
class TestMatchesExclusions(Base):
def test_no_match_empty_list(self):
exclusion_regexes = []
for domain in ["1.2.3.4 localhost", "5.6.7.8 hulu.com",
"9.1.2.3 yahoo.com", "4.5.6.7 cloudfront.net"]:
self.assertFalse(matches_exclusions(domain, exclusion_regexes))
def test_no_match_list(self):
exclusion_regexes = [".*\.org", ".*\.edu"]
exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes]
for domain in ["1.2.3.4 localhost", "5.6.7.8 hulu.com",
"9.1.2.3 yahoo.com", "4.5.6.7 cloudfront.net"]:
self.assertFalse(matches_exclusions(domain, exclusion_regexes))
def test_match_list(self):
exclusion_regexes = [".*\.com", ".*\.org", ".*\.edu"]
exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes]
for domain in ["5.6.7.8 hulu.com", "9.1.2.3 yahoo.com",
"4.5.6.7 adaway.org", "8.9.1.2 education.edu"]:
self.assertTrue(matches_exclusions(domain, exclusion_regexes))
# End Exclusion Logic

View File

@ -147,9 +147,18 @@ def main():
set(options["extensions"]).intersection(settings["extensions"])))
auto = settings["auto"]
exclusion_regexes = settings["exclusionregexs"]
prompt_for_update(freshen=settings["freshen"], update_auto=auto)
prompt_for_exclusions(skip_prompt=auto)
gather_exclusions = prompt_for_exclusions(skip_prompt=auto)
if gather_exclusions:
common_exclusions = settings["commonexclusions"]
exclusion_pattern = settings["exclusionpattern"]
exclusion_regexes = display_exclusion_options(
common_exclusions=common_exclusions,
exclusion_pattern=exclusion_pattern,
exclusion_regexes=exclusion_regexes)
merge_file = create_initial_file()
remove_old_hosts_file(settings["backup"])
@ -157,7 +166,7 @@ def main():
extensions = settings["extensions"]
output_subfolder = settings["outputsubfolder"]
final_file = remove_dups_and_excl(merge_file)
final_file = remove_dups_and_excl(merge_file, exclusion_regexes)
number_of_rules = settings["numberofrules"]
skip_static_hosts = settings["skipstatichosts"]
@ -247,6 +256,12 @@ def prompt_for_exclusions(skip_prompt):
skip_prompt : bool
Whether or not to skip prompting for custom domains to be excluded.
If true, the function returns immediately.
Returns
-------
gather_exclusions : bool
Whether or not we should proceed to prompt the user to exclude any
custom domains beyond those in the whitelist.
"""
prompt = ("Do you want to exclude any domains?\n"
@ -255,10 +270,12 @@ def prompt_for_exclusions(skip_prompt):
if not skip_prompt:
if query_yes_no(prompt):
display_exclusion_options()
return True
else:
print("OK, we'll only exclude domains in the whitelist.")
return False
def prompt_for_flush_dns_cache(flush_cache, prompt_flush):
"""
@ -322,29 +339,65 @@ def prompt_for_move(final_file, **move_params):
# Exclusion logic
def display_exclusion_options():
def display_exclusion_options(common_exclusions, exclusion_pattern,
exclusion_regexes):
"""
Display the exclusion options to the user.
This function checks whether a user wants to exclude particular domains,
and if so, excludes them.
Parameters
----------
common_exclusions : list
A list of common domains that are excluded from being blocked. One
example is Hulu. This setting is set directly in the script and cannot
be overwritten by the user.
exclusion_pattern : str
The exclusion pattern with which to create the domain regex.
exclusion_regexes : list
The list of regex patterns used to exclude domains.
Returns
-------
aug_exclusion_regexes : list
The original list of regex patterns potentially with additional
patterns from domains that user chooses to exclude.
"""
for exclusion_option in settings["commonexclusions"]:
for exclusion_option in common_exclusions:
prompt = "Do you want to exclude the domain " + exclusion_option + " ?"
if query_yes_no(prompt):
exclude_domain(exclusion_option)
exclusion_regexes = exclude_domain(exclusion_option,
exclusion_pattern,
exclusion_regexes)
else:
continue
if query_yes_no("Do you want to exclude any other domains?"):
gather_custom_exclusions()
exclusion_regexes = gather_custom_exclusions(exclusion_pattern,
exclusion_regexes)
return exclusion_regexes
def gather_custom_exclusions():
def gather_custom_exclusions(exclusion_pattern, exclusion_regexes):
"""
Gather custom exclusions from the user.
Parameters
----------
exclusion_pattern : str
The exclusion pattern with which to create the domain regex.
exclusion_regexes : list
The list of regex patterns used to exclude domains.
Returns
-------
aug_exclusion_regexes : list
The original list of regex patterns potentially with additional
patterns from domains that user chooses to exclude.
"""
# We continue running this while-loop until the user
@ -355,28 +408,46 @@ def gather_custom_exclusions():
user_domain = raw_input(domain_prompt)
if is_valid_domain_format(user_domain):
exclude_domain(user_domain)
exclusion_regexes = exclude_domain(user_domain, exclusion_pattern,
exclusion_regexes)
continue_prompt = "Do you have more domains you want to enter?"
if not query_yes_no(continue_prompt):
return
break
return exclusion_regexes
def exclude_domain(domain):
def exclude_domain(domain, exclusion_pattern, exclusion_regexes):
"""
Exclude a domain from being blocked.
This create the domain regex by which to exclude this domain and appends
it a list of already-existing exclusion regexes.
Parameters
----------
domain : str
The filename or regex pattern to exclude.
exclusion_pattern : str
The exclusion pattern with which to create the domain regex.
exclusion_regexes : list
The list of regex patterns used to exclude domains.
Returns
-------
aug_exclusion_regexes : list
The original list of regex patterns with one additional pattern from
the `domain` input.
"""
settings["exclusionregexs"].append(re.compile(
settings["exclusionpattern"] + domain))
exclusion_regex = re.compile(exclusion_pattern + domain)
exclusion_regexes.append(exclusion_regex)
return exclusion_regexes
def matches_exclusions(stripped_rule):
def matches_exclusions(stripped_rule, exclusion_regexes):
"""
Check whether a rule matches an exclusion rule we already provided.
@ -387,6 +458,8 @@ def matches_exclusions(stripped_rule):
----------
stripped_rule : str
The rule that we are checking.
exclusion_regexes : list
The list of regex patterns used to exclude domains.
Returns
-------
@ -395,9 +468,11 @@ def matches_exclusions(stripped_rule):
"""
stripped_domain = stripped_rule.split()[1]
for exclusionRegex in settings["exclusionregexs"]:
for exclusionRegex in exclusion_regexes:
if exclusionRegex.search(stripped_domain):
return True
return False
# End Exclusion Logic
@ -479,7 +554,7 @@ def create_initial_file():
return merge_file
def remove_dups_and_excl(merge_file):
def remove_dups_and_excl(merge_file, exclusion_regexes):
"""
Remove duplicates and remove hosts that we are excluding.
@ -490,6 +565,8 @@ def remove_dups_and_excl(merge_file):
----------
merge_file : file
The file object that contains the hostnames that we are pruning.
exclusion_regexes : list
The list of regex patterns used to exclude domains.
"""
number_of_rules = settings["numberofrules"]
@ -532,7 +609,8 @@ def remove_dups_and_excl(merge_file):
continue
stripped_rule = strip_rule(line) # strip comments
if not stripped_rule or matches_exclusions(stripped_rule):
if not stripped_rule or matches_exclusions(stripped_rule,
exclusion_regexes):
continue
# Normalize rule