From 415b89a0228f92141b590932ca40f5ae17b3550b Mon Sep 17 00:00:00 2001 From: gfyoung Date: Thu, 29 Jun 2017 20:55:41 -0700 Subject: [PATCH] Reduce dependency on global settings variable Global variables make code less modular and therefore more difficult to test. --- testUpdateHostsFile.py | 428 ++++++++++++++++++++++++++++++++++++++++- updateHostsFile.py | 155 ++++++++++----- 2 files changed, 532 insertions(+), 51 deletions(-) diff --git a/testUpdateHostsFile.py b/testUpdateHostsFile.py index e875737d8..2eeac7be4 100644 --- a/testUpdateHostsFile.py +++ b/testUpdateHostsFile.py @@ -11,10 +11,15 @@ from updateHostsFile import (Colors, PY3, colorize, flush_dns_cache, move_hosts_file_into_place, normalize_rule, path_join_robust, print_failure, print_success, supports_color, query_yes_no, recursive_glob, - strip_rule, write_data) + remove_old_hosts_file, strip_rule, + update_readme_data, write_data, + write_opening_header) import updateHostsFile import unittest +import tempfile import locale +import shutil +import json import sys import os @@ -48,6 +53,19 @@ class BaseStdout(Base): def tearDown(self): sys.stdout.close() sys.stdout = sys.__stdout__ + + +class BaseMockDir(Base): + + @property + def dir_count(self): + return len(os.listdir(self.test_dir)) + + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir) # End Base Test Classes @@ -119,12 +137,12 @@ class TestGatherCustomExclusions(BaseStdout): # File Logic class TestNormalizeRule(BaseStdout): - # Can only test non-matches because they don't - # interact with the settings global variable. def test_no_match(self): + kwargs = dict(target_ip="0.0.0.0", keep_domain_comments=False) + for rule in ["foo", "128.0.0.1", "bar.com/usa", "0.0.0 google", "0.1.2.3.4 foo/bar", "twitter.com"]: - self.assertEqual(normalize_rule(rule), (None, None)) + self.assertEqual(normalize_rule(rule, **kwargs), (None, None)) output = sys.stdout.getvalue() sys.stdout = StringIO() @@ -132,6 +150,38 @@ class TestNormalizeRule(BaseStdout): expected = "==>" + rule + "<==" self.assertIn(expected, output) + def test_no_comments(self): + for target_ip in ("0.0.0.0", "127.0.0.1", "8.8.8.8"): + rule = "127.0.0.1 google foo" + expected = ("google", str(target_ip) + " google\n") + + actual = normalize_rule(rule, target_ip=target_ip, + keep_domain_comments=False) + self.assertEqual(actual, expected) + + # Nothing gets printed if there's a match. + output = sys.stdout.getvalue() + self.assertEqual(output, "") + + sys.stdout = StringIO() + + def test_with_comments(self): + for target_ip in ("0.0.0.0", "127.0.0.1", "8.8.8.8"): + for comment in ("foo", "bar", "baz"): + rule = "127.0.0.1 google " + comment + expected = ("google", (str(target_ip) + " google # " + + comment + "\n")) + + actual = normalize_rule(rule, target_ip=target_ip, + keep_domain_comments=True) + self.assertEqual(actual, expected) + + # Nothing gets printed if there's a match. + output = sys.stdout.getvalue() + self.assertEqual(output, "") + + sys.stdout = StringIO() + class TestStripRule(Base): @@ -153,6 +203,304 @@ class TestStripRule(Base): self.assertEqual(output, line) +class TestWriteOpeningHeader(BaseMockDir): + + def setUp(self): + super(TestWriteOpeningHeader, self).setUp() + self.final_file = BytesIO() + + def test_missing_keyword(self): + kwargs = dict(extensions="", outputsubfolder="", + numberofrules=5, skipstatichosts=False) + + for k in kwargs.keys(): + bad_kwargs = kwargs.copy() + bad_kwargs.pop(k) + + self.assertRaises(KeyError, write_opening_header, + self.final_file, **bad_kwargs) + + def test_basic(self): + kwargs = dict(extensions="", outputsubfolder="", + numberofrules=5, skipstatichosts=True) + write_opening_header(self.final_file, **kwargs) + + contents = self.final_file.getvalue() + contents = contents.decode("UTF-8") + + # Expected contents. + for expected in ( + "# This hosts file is a merged collection", + "# with a dash of crowd sourcing via Github", + "# Number of unique domains: {count}".format( + count=kwargs["numberofrules"]), + "Fetch the latest version of this file:", + "Project home page: https://github.com/StevenBlack/hosts", + ): + self.assertIn(expected, contents) + + # Expected non-contents. + for expected in ( + "# Extensions added to this file:", + "127.0.0.1 localhost", + "127.0.0.1 local", + "127.0.0.53", + "127.0.1.1", + ): + self.assertNotIn(expected, contents) + + def test_basic_include_static_hosts(self): + kwargs = dict(extensions="", outputsubfolder="", + numberofrules=5, skipstatichosts=False) + with self.mock_property("platform.system") as obj: + obj.return_value = "Windows" + write_opening_header(self.final_file, **kwargs) + + contents = self.final_file.getvalue() + contents = contents.decode("UTF-8") + + # Expected contents. + for expected in ( + "127.0.0.1 local", + "127.0.0.1 localhost", + "# This hosts file is a merged collection", + "# with a dash of crowd sourcing via Github", + "# Number of unique domains: {count}".format( + count=kwargs["numberofrules"]), + "Fetch the latest version of this file:", + "Project home page: https://github.com/StevenBlack/hosts", + ): + self.assertIn(expected, contents) + + # Expected non-contents. + for expected in ( + "# Extensions added to this file:", + "127.0.0.53", + "127.0.1.1", + ): + self.assertNotIn(expected, contents) + + def test_basic_include_static_hosts_linux(self): + kwargs = dict(extensions="", outputsubfolder="", + numberofrules=5, skipstatichosts=False) + with self.mock_property("platform.system") as system: + system.return_value = "Linux" + + with self.mock_property("socket.gethostname") as hostname: + hostname.return_value = "steven-hosts" + write_opening_header(self.final_file, **kwargs) + + contents = self.final_file.getvalue() + contents = contents.decode("UTF-8") + + # Expected contents. + for expected in ( + "127.0.1.1", + "127.0.0.53", + "steven-hosts", + "127.0.0.1 local", + "127.0.0.1 localhost", + "# This hosts file is a merged collection", + "# with a dash of crowd sourcing via Github", + "# Number of unique domains: {count}".format( + count=kwargs["numberofrules"]), + "Fetch the latest version of this file:", + "Project home page: https://github.com/StevenBlack/hosts", + ): + self.assertIn(expected, contents) + + # Expected non-contents. + expected = "# Extensions added to this file:" + self.assertNotIn(expected, contents) + + def test_extensions(self): + kwargs = dict(extensions=["epsilon", "gamma", "mu", "phi"], + outputsubfolder="", numberofrules=5, + skipstatichosts=True) + write_opening_header(self.final_file, **kwargs) + + contents = self.final_file.getvalue() + contents = contents.decode("UTF-8") + + # Expected contents. + for expected in ( + ", ".join(kwargs["extensions"]), + "# Extensions added to this file:", + "# This hosts file is a merged collection", + "# with a dash of crowd sourcing via Github", + "# Number of unique domains: {count}".format( + count=kwargs["numberofrules"]), + "Fetch the latest version of this file:", + "Project home page: https://github.com/StevenBlack/hosts", + ): + self.assertIn(expected, contents) + + # Expected non-contents. + for expected in ( + "127.0.0.1 localhost", + "127.0.0.1 local", + "127.0.0.53", + "127.0.1.1", + ): + self.assertNotIn(expected, contents) + + def test_no_preamble(self): + # We should not even attempt to read this, as it is a directory. + hosts_dir = os.path.join(self.test_dir, "myhosts") + os.mkdir(hosts_dir) + + kwargs = dict(extensions="", outputsubfolder="", + numberofrules=5, skipstatichosts=True) + + with self.mock_property("updateHostsFile.BASEDIR_PATH"): + updateHostsFile.BASEDIR_PATH = self.test_dir + write_opening_header(self.final_file, **kwargs) + + contents = self.final_file.getvalue() + contents = contents.decode("UTF-8") + + # Expected contents. + for expected in ( + "# This hosts file is a merged collection", + "# with a dash of crowd sourcing via Github", + "# Number of unique domains: {count}".format( + count=kwargs["numberofrules"]), + "Fetch the latest version of this file:", + "Project home page: https://github.com/StevenBlack/hosts", + ): + self.assertIn(expected, contents) + + # Expected non-contents. + for expected in ( + "# Extensions added to this file:", + "127.0.0.1 localhost", + "127.0.0.1 local", + "127.0.0.53", + "127.0.1.1", + ): + self.assertNotIn(expected, contents) + + def test_preamble(self): + hosts_file = os.path.join(self.test_dir, "myhosts") + with open(hosts_file, "w") as f: + f.write("peter-piper-picked-a-pepper") + + kwargs = dict(extensions="", outputsubfolder="", + numberofrules=5, skipstatichosts=True) + + with self.mock_property("updateHostsFile.BASEDIR_PATH"): + updateHostsFile.BASEDIR_PATH = self.test_dir + write_opening_header(self.final_file, **kwargs) + + contents = self.final_file.getvalue() + contents = contents.decode("UTF-8") + + # Expected contents. + for expected in ( + "peter-piper-picked-a-pepper", + "# This hosts file is a merged collection", + "# with a dash of crowd sourcing via Github", + "# Number of unique domains: {count}".format( + count=kwargs["numberofrules"]), + "Fetch the latest version of this file:", + "Project home page: https://github.com/StevenBlack/hosts", + ): + self.assertIn(expected, contents) + + # Expected non-contents. + for expected in ( + "# Extensions added to this file:", + "127.0.0.1 localhost", + "127.0.0.1 local", + "127.0.0.53", + "127.0.1.1", + ): + self.assertNotIn(expected, contents) + + def tearDown(self): + super(TestWriteOpeningHeader, self).tearDown() + self.final_file.close() + + +class TestUpdateReadmeData(BaseMockDir): + + def setUp(self): + super(TestUpdateReadmeData, self).setUp() + self.readme_file = os.path.join(self.test_dir, "readmeData.json") + + def test_missing_keyword(self): + kwargs = dict(extensions="", outputsubfolder="", + numberofrules="", sourcesdata="") + + for k in kwargs.keys(): + bad_kwargs = kwargs.copy() + bad_kwargs.pop(k) + + self.assertRaises(KeyError, update_readme_data, + self.readme_file, **bad_kwargs) + + def test_add_fields(self): + with open(self.readme_file, "w") as f: + json.dump({"foo": "bar"}, f) + + kwargs = dict(extensions=None, outputsubfolder="foo", + numberofrules=5, sourcesdata="hosts") + update_readme_data(self.readme_file, **kwargs) + + expected = { + "base": { + "location": "foo" + self.sep, + "sourcesdata": "hosts", + "entries": 5, + }, + "foo": "bar" + } + + with open(self.readme_file, "r") as f: + actual = json.load(f) + self.assertEqual(actual, expected) + + def test_modify_fields(self): + with open(self.readme_file, "w") as f: + json.dump({"base": "soprano"}, f) + + kwargs = dict(extensions=None, outputsubfolder="foo", + numberofrules=5, sourcesdata="hosts") + update_readme_data(self.readme_file, **kwargs) + + expected = { + "base": { + "location": "foo" + self.sep, + "sourcesdata": "hosts", + "entries": 5, + } + } + + with open(self.readme_file, "r") as f: + actual = json.load(f) + self.assertEqual(actual, expected) + + def test_set_extensions(self): + with open(self.readme_file, "w") as f: + json.dump({}, f) + + kwargs = dict(extensions=["com", "org"], outputsubfolder="foo", + numberofrules=5, sourcesdata="hosts") + update_readme_data(self.readme_file, **kwargs) + + expected = { + "com-org": { + "location": "foo" + self.sep, + "sourcesdata": "hosts", + "entries": 5, + } + } + + with open(self.readme_file, "r") as f: + actual = json.load(f) + self.assertEqual(actual, expected) + + class TestMoveHostsFile(BaseStdout): @mock.patch("os.path.abspath", side_effect=lambda f: f) @@ -312,6 +660,78 @@ class TestFlushDnsCache(BaseStdout): ("Flushing the DNS cache by restarting " "NetworkManager.service succeeded")]: self.assertIn(expected, output) + + +def mock_path_join_robust(*args): + # We want to hard-code the backup hosts filename + # instead of parametrizing based on current time. + if len(args) == 2 and args[1].startswith("hosts-"): + return os.path.join(args[0], "hosts-new") + else: + return os.path.join(*args) + + +class TestRemoveOldHostsFile(BaseMockDir): + + def setUp(self): + super(TestRemoveOldHostsFile, self).setUp() + self.hosts_file = os.path.join(self.test_dir, "hosts") + + def test_remove_hosts_file(self): + old_dir_count = self.dir_count + + with self.mock_property("updateHostsFile.BASEDIR_PATH"): + updateHostsFile.BASEDIR_PATH = self.test_dir + remove_old_hosts_file(backup=False) + + new_dir_count = old_dir_count + 1 + self.assertEqual(self.dir_count, new_dir_count) + + with open(self.hosts_file, "r") as f: + contents = f.read() + self.assertEqual(contents, "") + + def test_remove_hosts_file_exists(self): + with open(self.hosts_file, "w") as f: + f.write("foo") + + old_dir_count = self.dir_count + + with self.mock_property("updateHostsFile.BASEDIR_PATH"): + updateHostsFile.BASEDIR_PATH = self.test_dir + remove_old_hosts_file(backup=False) + + new_dir_count = old_dir_count + self.assertEqual(self.dir_count, new_dir_count) + + with open(self.hosts_file, "r") as f: + contents = f.read() + self.assertEqual(contents, "") + + @mock.patch("updateHostsFile.path_join_robust", + side_effect=mock_path_join_robust) + def test_remove_hosts_file_backup(self, _): + with open(self.hosts_file, "w") as f: + f.write("foo") + + old_dir_count = self.dir_count + + with self.mock_property("updateHostsFile.BASEDIR_PATH"): + updateHostsFile.BASEDIR_PATH = self.test_dir + remove_old_hosts_file(backup=True) + + new_dir_count = old_dir_count + 1 + self.assertEqual(self.dir_count, new_dir_count) + + with open(self.hosts_file, "r") as f: + contents = f.read() + self.assertEqual(contents, "") + + new_hosts_file = self.hosts_file + "-new" + + with open(new_hosts_file, "r") as f: + contents = f.read() + self.assertEqual(contents, "foo") # End File Logic diff --git a/updateHostsFile.py b/updateHostsFile.py index 2fdbcd37e..9cf5ec7a7 100644 --- a/updateHostsFile.py +++ b/updateHostsFile.py @@ -146,30 +146,39 @@ def main(): settings["extensions"] = sorted(list( set(options["extensions"]).intersection(settings["extensions"]))) - with open(settings["readmedatafilename"], "r") as f: - settings["readmedata"] = json.load(f) - prompt_for_update() prompt_for_exclusions() merge_file = create_initial_file() - remove_old_hosts_file() + remove_old_hosts_file(settings["backup"]) + + extensions = settings["extensions"] + number_of_rules = settings["numberofrules"] + output_subfolder = settings["outputsubfolder"] final_file = remove_dups_and_excl(merge_file) - write_opening_header(final_file) + write_opening_header(final_file, extensions=extensions, + numberofrules=number_of_rules, + outputsubfolder=output_subfolder, + skipstatichosts=settings["skipstatichosts"]) final_file.close() if settings["ziphosts"]: - zf = zipfile.ZipFile(path_join_robust(settings["outputsubfolder"], + zf = zipfile.ZipFile(path_join_robust(output_subfolder, "hosts.zip"), mode='w') - zf.write(path_join_robust(settings["outputsubfolder"], "hosts"), + zf.write(path_join_robust(output_subfolder, "hosts"), compress_type=zipfile.ZIP_DEFLATED, arcname='hosts') zf.close() - update_readme_data() + update_readme_data(settings["readmedatafilename"], + extensions=extensions, + numberofrules=number_of_rules, + outputsubfolder=output_subfolder, + sourcesdata=settings["sourcesdata"]) + print_success("Success! The hosts file has been saved in folder " + - settings["outputsubfolder"] + "\nIt contains " + - "{:,}".format(settings["numberofrules"]) + + output_subfolder + "\nIt contains " + + "{:,}".format(number_of_rules) + " unique entries.") prompt_for_move(final_file) @@ -196,7 +205,7 @@ def prompt_for_update(): if settings["auto"] or query_yes_no(prompt): update_all_sources() elif not settings["auto"]: - print("OK, we'll stick with what we've got locally.") + print("OK, we'll stick with what we've got locally.") def prompt_for_exclusions(): @@ -469,7 +478,10 @@ def remove_dups_and_excl(merge_file): continue # Normalize rule - hostname, normalized_rule = normalize_rule(stripped_rule) + hostname, normalized_rule = normalize_rule( + stripped_rule, target_ip=settings["targetip"], + keep_domain_comments=settings["keepdomaincomments"]) + for exclude in exclusions: if exclude in line: write_line = False @@ -486,7 +498,7 @@ def remove_dups_and_excl(merge_file): return final_file -def normalize_rule(rule): +def normalize_rule(rule, target_ip, keep_domain_comments): """ Standardize and format the rule string provided. @@ -494,26 +506,34 @@ def normalize_rule(rule): ---------- rule : str The rule whose spelling and spacing we are standardizing. + target_ip : str + The target IP address for the rule. + keep_domain_comments : bool + Whether or not to keep comments regarding these domains in + the normalized rule. Returns ------- - normalized_rule : str - The rule string with spelling and spacing reformatted. + normalized_rule : tuple + A tuple of the hostname and the rule string with spelling + and spacing reformatted. """ - result = re.search(r'^\s*(\d{1,3}\.){3}\d{1,3}\s+([\w\.-]+[a-zA-Z])(.*)', - rule) + regex = r'^\s*(\d{1,3}\.){3}\d{1,3}\s+([\w\.-]+[a-zA-Z])(.*)' + result = re.search(regex, rule) + if result: hostname, suffix = result.group(2, 3) - # Explicitly lowercase and trim the hostname + # Explicitly lowercase and trim the hostname. hostname = hostname.lower().strip() - if suffix and settings["keepdomaincomments"]: - # add suffix as comment only, not as a separate host - return hostname, "%s %s #%s\n" % (settings["targetip"], - hostname, suffix) - else: - return hostname, "%s %s\n" % (settings["targetip"], hostname) + rule = "%s %s" % (target_ip, hostname) + + if suffix and keep_domain_comments: + rule += " #%s" % suffix + + return hostname, rule + "\n" + print("==>%s<==" % rule) return None, None @@ -544,7 +564,7 @@ def strip_rule(line): return split_line[0] + " " + split_line[1] -def write_opening_header(final_file): +def write_opening_header(final_file, **header_params): """ Write the header information into the newly-created hosts file. @@ -552,32 +572,45 @@ def write_opening_header(final_file): ---------- final_file : file The file object that points to the newly-created hosts file. + header_params : kwargs + Dictionary providing additional parameters for populating the header + information. Currently, those fields are: + + 1) extensions + 2) numberofrules + 3) outputsubfolder + 4) skipstatichosts """ - final_file.seek(0) # reset file pointer - file_contents = final_file.read() # save content - final_file.seek(0) # write at the top + final_file.seek(0) # Reset file pointer. + file_contents = final_file.read() # Save content. + + final_file.seek(0) # Write at the top. write_data(final_file, "# This hosts file is a merged collection " "of hosts from reputable sources,\n") write_data(final_file, "# with a dash of crowd sourcing via Github\n#\n") write_data(final_file, "# Date: " + time.strftime( "%B %d %Y", time.gmtime()) + "\n") - if settings["extensions"]: + + if header_params["extensions"]: write_data(final_file, "# Extensions added to this file: " + ", ".join( - settings["extensions"]) + "\n") - write_data(final_file, "# Number of unique domains: " + "{:,}\n#\n".format( - settings["numberofrules"])) + header_params["extensions"]) + "\n") + + write_data(final_file, ("# Number of unique domains: " + + "{:,}\n#\n".format(header_params[ + "numberofrules"]))) write_data(final_file, "# Fetch the latest version of this file: " "https://raw.githubusercontent.com/" "StevenBlack/hosts/master/" + - path_join_robust(settings["outputsubfolder"], "") + "hosts\n") + path_join_robust(header_params["outputsubfolder"], + "") + "hosts\n") write_data(final_file, "# Project home page: https://github.com/" "StevenBlack/hosts\n#\n") write_data(final_file, "# ===============================" "================================\n") write_data(final_file, "\n") - if not settings["skipstatichosts"]: + if not header_params["skipstatichosts"]: write_data(final_file, "127.0.0.1 localhost\n") write_data(final_file, "127.0.0.1 localhost.localdomain\n") write_data(final_file, "127.0.0.1 local\n") @@ -585,12 +618,15 @@ def write_opening_header(final_file): write_data(final_file, "::1 localhost\n") write_data(final_file, "fe80::1%lo0 localhost\n") write_data(final_file, "0.0.0.0 0.0.0.0\n") + if platform.system() == "Linux": write_data(final_file, "127.0.1.1 " + socket.gethostname() + "\n") write_data(final_file, "127.0.0.53 " + socket.gethostname() + "\n") + write_data(final_file, "\n") preamble = path_join_robust(BASEDIR_PATH, "myhosts") + if os.path.isfile(preamble): with open(preamble, "r") as f: write_data(final_file, f.read()) @@ -598,22 +634,41 @@ def write_opening_header(final_file): final_file.write(file_contents) -def update_readme_data(): +def update_readme_data(readme_file, **readme_updates): """ Update the host and website information provided in the README JSON data. + + Parameters + ---------- + readme_file : str + The name of the README file to update. + readme_updates : kwargs + Dictionary providing additional JSON fields to update before + saving the data. Currently, those fields are: + + 1) extensions + 2) sourcesdata + 3) numberofrules + 4) outputsubfolder """ extensions_key = "base" - if settings["extensions"]: - extensions_key = "-".join(settings["extensions"]) + extensions = readme_updates["extensions"] - generation_data = {"location": path_join_robust( - settings["outputsubfolder"], ""), - "entries": settings["numberofrules"], - "sourcesdata": settings["sourcesdata"]} - settings["readmedata"][extensions_key] = generation_data - with open(settings["readmedatafilename"], "w") as f: - json.dump(settings["readmedata"], f) + if extensions: + extensions_key = "-".join(extensions) + + output_folder = readme_updates["outputsubfolder"] + generation_data = {"location": path_join_robust(output_folder, ""), + "entries": readme_updates["numberofrules"], + "sourcesdata": readme_updates["sourcesdata"]} + + with open(readme_file, "r") as f: + readme_data = json.load(f) + readme_data[extensions_key] = generation_data + + with open(readme_file, "w") as f: + json.dump(readme_data, f) def move_hosts_file_into_place(final_file): @@ -717,19 +772,25 @@ def flush_dns_cache(): print_failure("Unable to determine DNS management tool.") -def remove_old_hosts_file(): +def remove_old_hosts_file(backup): """ Remove the old hosts file. This is a hotfix because merging with an already existing hosts file leads to artifacts and duplicates. + + Parameters + ---------- + backup : boolean, default False + Whether or not to backup the existing hosts file. """ old_file_path = path_join_robust(BASEDIR_PATH, "hosts") - # create if already removed, so remove wont raise an error + + # Create if already removed, so remove won't raise an error. open(old_file_path, "a").close() - if settings["backup"]: + if backup: backup_file_path = path_join_robust(BASEDIR_PATH, "hosts-{}".format( time.strftime("%Y-%m-%d-%H-%M-%S")))