hosts/testUpdateHostsFile.py

2134 lines
70 KiB
Python
Raw Normal View History

2017-06-23 09:25:40 +02:00
#!/usr/bin/env python
# Script by gfyoung
# https://github.com/gfyoung
#
# Python script for testing updateHostFiles.py
import json
import locale
2017-06-23 09:25:40 +02:00
import os
import platform
import re
import shutil
import sys
import tempfile
import unittest
import unittest.mock as mock
from io import BytesIO, StringIO
2020-08-26 02:40:30 +02:00
import requests
import updateHostsFile
from updateHostsFile import (
Colors,
colorize,
display_exclusion_options,
domain_to_idna,
exclude_domain,
flush_dns_cache,
gather_custom_exclusions,
get_defaults,
2020-08-26 02:40:30 +02:00
get_file_by_url,
is_valid_user_provided_domain_format,
matches_exclusions,
move_hosts_file_into_place,
normalize_rule,
path_join_robust,
print_failure,
print_success,
prompt_for_exclusions,
prompt_for_flush_dns_cache,
prompt_for_move,
prompt_for_update,
query_yes_no,
recursive_glob,
remove_old_hosts_file,
sort_sources,
strip_rule,
supports_color,
update_all_sources,
update_readme_data,
update_sources_data,
write_data,
write_opening_header,
)
2017-06-23 09:25:40 +02:00
unicode = str
2017-06-23 09:25:40 +02:00
# Test Helper Objects
2017-06-23 09:25:40 +02:00
class Base(unittest.TestCase):
@staticmethod
def mock_property(name):
return mock.patch(name, new_callable=mock.PropertyMock)
@property
def sep(self):
if platform.system().lower() == "windows":
return "\\"
return os.sep
2017-06-23 09:25:40 +02:00
2017-08-03 05:50:43 +02:00
def assert_called_once(self, mock_method):
self.assertEqual(mock_method.call_count, 1)
2017-08-03 05:50:43 +02:00
2017-06-23 09:25:40 +02:00
class BaseStdout(Base):
def setUp(self):
sys.stdout = StringIO()
def tearDown(self):
sys.stdout.close()
sys.stdout = sys.__stdout__
class BaseMockDir(Base):
@property
def dir_count(self):
return len(os.listdir(self.test_dir))
def setUp(self):
self.test_dir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.test_dir)
# End Test Helper Objects
2017-06-23 09:25:40 +02:00
# Project Settings
class TestGetDefaults(Base):
def test_get_defaults(self):
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = "foo"
actual = get_defaults()
expected = {
"numberofrules": 0,
"datapath": "foo" + self.sep + "data",
"freshen": True,
"replace": False,
"backup": False,
"skipstatichosts": False,
"keepdomaincomments": True,
"extensionspath": "foo" + self.sep + "extensions",
"extensions": [],
"nounifiedhosts": False,
"compress": False,
"minimise": False,
"outputsubfolder": "",
"hostfilename": "hosts",
"targetip": "0.0.0.0",
"sourcedatafilename": "update.json",
"sourcesdata": [],
"readmefilename": "readme.md",
"readmetemplate": ("foo" + self.sep + "readme_template.md"),
"readmedata": {},
"readmedatafilename": ("foo" + self.sep + "readmeData.json"),
"exclusionpattern": r"([a-zA-Z\d-]+\.){0,}",
"exclusionregexes": [],
"exclusions": [],
"commonexclusions": ["hulu.com"],
"blacklistfile": "foo" + self.sep + "blacklist",
"whitelistfile": "foo" + self.sep + "whitelist",
}
2017-06-23 09:25:40 +02:00
self.assertDictEqual(actual, expected)
2017-06-23 09:25:40 +02:00
# End Project Settings
class TestSortSources(Base):
def test_sort_sources_simple(self):
given = [
"sbc.io",
"example.com",
"github.com",
]
expected = ["example.com", "github.com", "sbc.io"]
actual = sort_sources(given)
self.assertEqual(actual, expected)
def test_live_data(self):
given = [
"data/KADhosts/update.json",
"data/someonewhocares.org/update.json",
"data/StevenBlack/update.json",
"data/adaway.org/update.json",
"data/URLHaus/update.json",
"data/UncheckyAds/update.json",
"data/add.2o7Net/update.json",
"data/mvps.org/update.json",
"data/add.Spam/update.json",
"data/add.Dead/update.json",
"data/malwaredomainlist.com/update.json",
"data/Badd-Boyz-Hosts/update.json",
"data/hostsVN/update.json",
"data/yoyo.org/update.json",
"data/add.Risk/update.json",
"data/tiuxo/update.json",
"extensions/gambling/update.json",
"extensions/porn/clefspeare13/update.json",
"extensions/porn/sinfonietta-snuff/update.json",
"extensions/porn/tiuxo/update.json",
"extensions/porn/sinfonietta/update.json",
"extensions/fakenews/update.json",
"extensions/social/tiuxo/update.json",
"extensions/social/sinfonietta/update.json",
]
expected = [
"data/StevenBlack/update.json",
"data/adaway.org/update.json",
"data/add.2o7Net/update.json",
"data/add.Dead/update.json",
"data/add.Risk/update.json",
"data/add.Spam/update.json",
"data/Badd-Boyz-Hosts/update.json",
"data/hostsVN/update.json",
"data/KADhosts/update.json",
"data/malwaredomainlist.com/update.json",
"data/mvps.org/update.json",
"data/someonewhocares.org/update.json",
"data/tiuxo/update.json",
"data/UncheckyAds/update.json",
"data/URLHaus/update.json",
"data/yoyo.org/update.json",
"extensions/fakenews/update.json",
"extensions/gambling/update.json",
"extensions/porn/clefspeare13/update.json",
"extensions/porn/sinfonietta/update.json",
"extensions/porn/sinfonietta-snuff/update.json",
"extensions/porn/tiuxo/update.json",
"extensions/social/sinfonietta/update.json",
"extensions/social/tiuxo/update.json",
]
actual = sort_sources(given)
self.assertEqual(actual, expected)
# Prompt the User
class TestPromptForUpdate(BaseStdout, BaseMockDir):
def setUp(self):
BaseStdout.setUp(self)
BaseMockDir.setUp(self)
def test_no_freshen_no_new_file(self):
hosts_file = os.path.join(self.test_dir, "hosts")
hosts_data = "This data should not be overwritten"
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
with open(hosts_file, "w") as f:
f.write(hosts_data)
for update_auto in (False, True):
dir_count = self.dir_count
prompt_for_update(freshen=False, update_auto=update_auto)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
sys.stdout = StringIO()
self.assertEqual(self.dir_count, dir_count)
with open(hosts_file, "r") as f:
contents = f.read()
self.assertEqual(contents, hosts_data)
def test_no_freshen_new_file(self):
hosts_file = os.path.join(self.test_dir, "hosts")
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
dir_count = self.dir_count
prompt_for_update(freshen=False, update_auto=False)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
sys.stdout = StringIO()
self.assertEqual(self.dir_count, dir_count + 1)
with open(hosts_file, "r") as f:
contents = f.read()
self.assertEqual(contents, "")
@mock.patch("builtins.open")
def test_no_freshen_fail_new_file(self, mock_open):
for exc in (IOError, OSError):
mock_open.side_effect = exc("failed open")
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
prompt_for_update(freshen=False, update_auto=False)
output = sys.stdout.getvalue()
expected = (
"ERROR: No 'hosts' file in the folder. "
"Try creating one manually."
)
self.assertIn(expected, output)
sys.stdout = StringIO()
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def test_freshen_no_update(self, _):
hosts_file = os.path.join(self.test_dir, "hosts")
hosts_data = "This data should not be overwritten"
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
with open(hosts_file, "w") as f:
f.write(hosts_data)
dir_count = self.dir_count
update_sources = prompt_for_update(freshen=True, update_auto=False)
self.assertFalse(update_sources)
output = sys.stdout.getvalue()
expected = "OK, we'll stick with what we've got locally."
self.assertIn(expected, output)
sys.stdout = StringIO()
self.assertEqual(self.dir_count, dir_count)
with open(hosts_file, "r") as f:
contents = f.read()
self.assertEqual(contents, hosts_data)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
def test_freshen_update(self, _):
hosts_file = os.path.join(self.test_dir, "hosts")
hosts_data = "This data should not be overwritten"
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
with open(hosts_file, "w") as f:
f.write(hosts_data)
dir_count = self.dir_count
for update_auto in (False, True):
update_sources = prompt_for_update(
freshen=True, update_auto=update_auto
)
self.assertTrue(update_sources)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
sys.stdout = StringIO()
self.assertEqual(self.dir_count, dir_count)
with open(hosts_file, "r") as f:
contents = f.read()
self.assertEqual(contents, hosts_data)
def tearDown(self):
BaseStdout.tearDown(self)
class TestPromptForExclusions(BaseStdout):
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testSkipPrompt(self, mock_query):
gather_exclusions = prompt_for_exclusions(skip_prompt=True)
self.assertFalse(gather_exclusions)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
mock_query.assert_not_called()
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testNoSkipPromptNoDisplay(self, mock_query):
gather_exclusions = prompt_for_exclusions(skip_prompt=False)
self.assertFalse(gather_exclusions)
output = sys.stdout.getvalue()
expected = "OK, we'll only exclude domains in the whitelist."
self.assertIn(expected, output)
2017-08-03 05:50:43 +02:00
self.assert_called_once(mock_query)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
def testNoSkipPromptDisplay(self, mock_query):
gather_exclusions = prompt_for_exclusions(skip_prompt=False)
self.assertTrue(gather_exclusions)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
2017-08-03 05:50:43 +02:00
self.assert_called_once(mock_query)
class TestPromptForFlushDnsCache(Base):
@mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testFlushCache(self, mock_query, mock_flush):
for prompt_flush in (False, True):
prompt_for_flush_dns_cache(flush_cache=True, prompt_flush=prompt_flush)
mock_query.assert_not_called()
2017-08-03 05:50:43 +02:00
self.assert_called_once(mock_flush)
mock_query.reset_mock()
mock_flush.reset_mock()
@mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testNoFlushCacheNoPrompt(self, mock_query, mock_flush):
prompt_for_flush_dns_cache(flush_cache=False, prompt_flush=False)
mock_query.assert_not_called()
mock_flush.assert_not_called()
@mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testNoFlushCachePromptNoFlush(self, mock_query, mock_flush):
prompt_for_flush_dns_cache(flush_cache=False, prompt_flush=True)
2017-08-03 05:50:43 +02:00
self.assert_called_once(mock_query)
mock_flush.assert_not_called()
@mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
def testNoFlushCachePromptFlush(self, mock_query, mock_flush):
prompt_for_flush_dns_cache(flush_cache=False, prompt_flush=True)
2017-08-03 05:50:43 +02:00
self.assert_called_once(mock_query)
self.assert_called_once(mock_flush)
class TestPromptForMove(Base):
def setUp(self):
Base.setUp(self)
self.final_file = "final.txt"
def prompt_for_move(self, **move_params):
return prompt_for_move(self.final_file, **move_params)
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testSkipStaticHosts(self, mock_query, mock_move):
for replace in (False, True):
for auto in (False, True):
move_file = self.prompt_for_move(
replace=replace, auto=auto, skipstatichosts=True
)
self.assertFalse(move_file)
mock_query.assert_not_called()
mock_move.assert_not_called()
mock_query.reset_mock()
mock_move.reset_mock()
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testReplaceNoSkipStaticHosts(self, mock_query, mock_move):
for auto in (False, True):
move_file = self.prompt_for_move(
replace=True, auto=auto, skipstatichosts=False
)
2022-07-06 18:47:38 +02:00
self.assertFalse(move_file)
mock_query.assert_not_called()
2017-08-03 05:50:43 +02:00
self.assert_called_once(mock_move)
mock_query.reset_mock()
mock_move.reset_mock()
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testAutoNoSkipStaticHosts(self, mock_query, mock_move):
for replace in (False, True):
move_file = self.prompt_for_move(
replace=replace, auto=True, skipstatichosts=True
)
self.assertFalse(move_file)
mock_query.assert_not_called()
mock_move.assert_not_called()
mock_query.reset_mock()
mock_move.reset_mock()
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
def testPromptNoMove(self, mock_query, mock_move):
move_file = self.prompt_for_move(
replace=False, auto=False, skipstatichosts=False
)
self.assertFalse(move_file)
2017-08-03 05:50:43 +02:00
self.assert_called_once(mock_query)
mock_move.assert_not_called()
@mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
def testPromptMove(self, mock_query, mock_move):
move_file = self.prompt_for_move(
replace=False, auto=False, skipstatichosts=False
)
2022-07-06 18:47:38 +02:00
self.assertFalse(move_file)
2017-08-03 05:50:43 +02:00
self.assert_called_once(mock_query)
self.assert_called_once(mock_move)
# End Prompt the User
2017-06-23 09:25:40 +02:00
# Exclusion Logic
class TestDisplayExclusionsOptions(Base):
@mock.patch("updateHostsFile.query_yes_no", return_value=0)
@mock.patch("updateHostsFile.exclude_domain", return_value=None)
@mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
def test_no_exclusions(self, mock_gather, mock_exclude, _):
common_exclusions = []
display_exclusion_options(common_exclusions, "foo", [])
mock_gather.assert_not_called()
mock_exclude.assert_not_called()
@mock.patch("updateHostsFile.query_yes_no", side_effect=[1, 1, 0])
@mock.patch("updateHostsFile.exclude_domain", return_value=None)
@mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
def test_only_common_exclusions(self, mock_gather, mock_exclude, _):
common_exclusions = ["foo", "bar"]
display_exclusion_options(common_exclusions, "foo", [])
mock_gather.assert_not_called()
exclude_calls = [mock.call("foo", "foo", []), mock.call("bar", "foo", None)]
mock_exclude.assert_has_calls(exclude_calls)
@mock.patch("updateHostsFile.query_yes_no", side_effect=[0, 0, 1])
@mock.patch("updateHostsFile.exclude_domain", return_value=None)
@mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
def test_gather_exclusions(self, mock_gather, mock_exclude, _):
common_exclusions = ["foo", "bar"]
display_exclusion_options(common_exclusions, "foo", [])
mock_exclude.assert_not_called()
self.assert_called_once(mock_gather)
@mock.patch("updateHostsFile.query_yes_no", side_effect=[1, 0, 1])
@mock.patch("updateHostsFile.exclude_domain", return_value=None)
@mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
def test_mixture_gather_exclusions(self, mock_gather, mock_exclude, _):
common_exclusions = ["foo", "bar"]
display_exclusion_options(common_exclusions, "foo", [])
mock_exclude.assert_called_once_with("foo", "foo", [])
self.assert_called_once(mock_gather)
2017-06-23 09:25:40 +02:00
class TestGatherCustomExclusions(BaseStdout):
# Can only test in the invalid domain case
# because of the settings global variable.
@mock.patch("updateHostsFile.input", side_effect=["foo", "no"])
2021-07-13 20:14:25 +02:00
@mock.patch(
"updateHostsFile.is_valid_user_provided_domain_format", return_value=False
)
2017-06-23 09:25:40 +02:00
def test_basic(self, *_):
gather_custom_exclusions("foo", [])
2017-06-23 09:25:40 +02:00
expected = "Do you have more domains you want to enter? [Y/n]"
output = sys.stdout.getvalue()
self.assertIn(expected, output)
2018-08-10 17:32:58 +02:00
@mock.patch("updateHostsFile.input", side_effect=["foo", "yes", "bar", "no"])
2021-07-13 20:14:25 +02:00
@mock.patch(
"updateHostsFile.is_valid_user_provided_domain_format", return_value=False
)
2017-06-23 09:25:40 +02:00
def test_multiple(self, *_):
gather_custom_exclusions("foo", [])
2017-06-23 09:25:40 +02:00
expected = (
"Do you have more domains you want to enter? [Y/n] "
"Do you have more domains you want to enter? [Y/n]"
)
2017-06-23 09:25:40 +02:00
output = sys.stdout.getvalue()
self.assertIn(expected, output)
class TestExcludeDomain(Base):
def test_invalid_exclude_domain(self):
exclusion_regexes = []
exclusion_pattern = "*.com"
for domain in ["google.com", "hulu.com", "adaway.org"]:
self.assertRaises(
re.error, exclude_domain, domain, exclusion_pattern, exclusion_regexes
)
self.assertListEqual(exclusion_regexes, [])
def test_valid_exclude_domain(self):
exp_count = 0
expected_regexes = []
exclusion_regexes = []
2018-08-10 17:32:58 +02:00
exclusion_pattern = r"[a-z]\."
for domain in ["google.com", "hulu.com", "adaway.org"]:
self.assertEqual(len(exclusion_regexes), exp_count)
exclusion_regexes = exclude_domain(
domain, exclusion_pattern, exclusion_regexes
)
expected_regex = re.compile(exclusion_pattern + domain)
expected_regexes.append(expected_regex)
exp_count += 1
self.assertEqual(len(exclusion_regexes), exp_count)
self.assertListEqual(exclusion_regexes, expected_regexes)
class TestMatchesExclusions(Base):
def test_no_match_empty_list(self):
exclusion_regexes = []
for domain in [
"1.2.3.4 localhost",
"5.6.7.8 hulu.com",
"9.1.2.3 yahoo.com",
"4.5.6.7 cloudfront.net",
]:
self.assertFalse(matches_exclusions(domain, exclusion_regexes))
def test_no_match_list(self):
2018-08-10 17:32:58 +02:00
exclusion_regexes = [r".*\.org", r".*\.edu"]
exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes]
for domain in [
"1.2.3.4 localhost",
"5.6.7.8 hulu.com",
"9.1.2.3 yahoo.com",
"4.5.6.7 cloudfront.net",
]:
self.assertFalse(matches_exclusions(domain, exclusion_regexes))
def test_match_list(self):
2018-08-10 17:32:58 +02:00
exclusion_regexes = [r".*\.com", r".*\.org", r".*\.edu"]
exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes]
for domain in [
"5.6.7.8 hulu.com",
"9.1.2.3 yahoo.com",
"4.5.6.7 adaway.org",
"8.9.1.2 education.edu",
]:
self.assertTrue(matches_exclusions(domain, exclusion_regexes))
def test_match_raw_list(self):
exclusion_regexes = [r".*\.com", r".*\.org", r".*\.edu", r".*@.*"]
exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes]
for domain in [
"hulu.com",
"yahoo.com",
"adaway.org",
"education.edu",
"a.stro.lo.gy@45.144.225.135",
]:
self.assertTrue(matches_exclusions(domain, exclusion_regexes))
def test_no_match_raw_list(self):
exclusion_regexes = [r".*\.org", r".*\.edu"]
exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes]
for domain in [
"localhost",
"hulu.com",
"yahoo.com",
"cloudfront.net",
]:
self.assertFalse(matches_exclusions(domain, exclusion_regexes))
2017-06-23 09:25:40 +02:00
# End Exclusion Logic
# Update Logic
2017-08-13 13:22:42 +02:00
class TestUpdateSourcesData(Base):
def setUp(self):
Base.setUp(self)
self.data_path = "data"
self.extensions_path = "extensions"
self.source_data_filename = "update.json"
self.update_kwargs = dict(
datapath=self.data_path,
extensionspath=self.extensions_path,
sourcedatafilename=self.source_data_filename,
nounifiedhosts=False,
)
2017-08-13 13:22:42 +02:00
def update_sources_data(self, sources_data, extensions):
return update_sources_data(
sources_data[:], extensions=extensions, **self.update_kwargs
)
2017-08-13 13:22:42 +02:00
@mock.patch("updateHostsFile.recursive_glob", return_value=[])
@mock.patch("updateHostsFile.path_join_robust", return_value="dirpath")
@mock.patch("builtins.open", return_value=mock.Mock())
2017-08-13 13:22:42 +02:00
def test_no_update(self, mock_open, mock_join_robust, _):
extensions = []
sources_data = [{"source": "source1.txt"}, {"source": "source2.txt"}]
new_sources_data = self.update_sources_data(sources_data, extensions)
self.assertEqual(new_sources_data, sources_data)
mock_join_robust.assert_not_called()
mock_open.assert_not_called()
extensions = [".json", ".txt"]
new_sources_data = self.update_sources_data(sources_data, extensions)
self.assertEqual(new_sources_data, sources_data)
join_calls = [
mock.call(self.extensions_path, ".json"),
mock.call(self.extensions_path, ".txt"),
]
2017-08-13 13:22:42 +02:00
mock_join_robust.assert_has_calls(join_calls)
mock_open.assert_not_called()
@mock.patch(
"updateHostsFile.recursive_glob",
side_effect=[[], ["update1.txt", "update2.txt"]],
)
2017-08-13 13:22:42 +02:00
@mock.patch("json.load", return_value={"mock_source": "mock_source.ext"})
@mock.patch("builtins.open", return_value=mock.Mock())
2017-08-13 13:22:42 +02:00
@mock.patch("updateHostsFile.path_join_robust", return_value="dirpath")
def test_update_only_extensions(self, mock_join_robust, *_):
extensions = [".json"]
sources_data = [{"source": "source1.txt"}, {"source": "source2.txt"}]
new_sources_data = self.update_sources_data(sources_data, extensions)
expected = sources_data + [{"mock_source": "mock_source.ext"}] * 2
self.assertEqual(new_sources_data, expected)
self.assert_called_once(mock_join_robust)
@mock.patch(
"updateHostsFile.recursive_glob",
side_effect=[["update1.txt", "update2.txt"], ["update3.txt", "update4.txt"]],
)
@mock.patch(
"json.load",
side_effect=[
{"mock_source": "mock_source.txt"},
{"mock_source": "mock_source2.txt"},
{"mock_source": "mock_source3.txt"},
{"mock_source": "mock_source4.txt"},
],
)
@mock.patch("builtins.open", return_value=mock.Mock())
2017-08-13 13:22:42 +02:00
@mock.patch("updateHostsFile.path_join_robust", return_value="dirpath")
def test_update_both_pathways(self, mock_join_robust, *_):
extensions = [".json"]
sources_data = [{"source": "source1.txt"}, {"source": "source2.txt"}]
new_sources_data = self.update_sources_data(sources_data, extensions)
expected = sources_data + [
{"mock_source": "mock_source.txt"},
{"mock_source": "mock_source2.txt"},
{"mock_source": "mock_source3.txt"},
{"mock_source": "mock_source4.txt"},
]
2017-08-13 13:22:42 +02:00
self.assertEqual(new_sources_data, expected)
self.assert_called_once(mock_join_robust)
class TestUpdateAllSources(BaseStdout):
def setUp(self):
BaseStdout.setUp(self)
self.source_data_filename = "data.json"
self.host_filename = "hosts.txt"
@mock.patch("builtins.open")
@mock.patch("updateHostsFile.recursive_glob", return_value=[])
def test_no_sources(self, _, mock_open):
update_all_sources(self.source_data_filename, self.host_filename)
mock_open.assert_not_called()
@mock.patch("builtins.open", return_value=mock.Mock())
@mock.patch("json.load", return_value={"url": "example.com"})
@mock.patch("updateHostsFile.recursive_glob", return_value=["foo"])
@mock.patch("updateHostsFile.write_data", return_value=0)
@mock.patch("updateHostsFile.get_file_by_url", return_value="file_data")
def test_one_source(self, mock_get, mock_write, *_):
update_all_sources(self.source_data_filename, self.host_filename)
self.assert_called_once(mock_write)
self.assert_called_once(mock_get)
output = sys.stdout.getvalue()
expected = "Updating source from example.com"
self.assertIn(expected, output)
@mock.patch("builtins.open", return_value=mock.Mock())
@mock.patch("json.load", return_value={"url": "example.com"})
@mock.patch("updateHostsFile.recursive_glob", return_value=["foo"])
@mock.patch("updateHostsFile.write_data", return_value=0)
@mock.patch("updateHostsFile.get_file_by_url", return_value=Exception("fail"))
def test_source_fail(self, mock_get, mock_write, *_):
update_all_sources(self.source_data_filename, self.host_filename)
mock_write.assert_not_called()
self.assert_called_once(mock_get)
output = sys.stdout.getvalue()
expecteds = [
"Updating source from example.com",
"Error in updating source: example.com",
]
for expected in expecteds:
self.assertIn(expected, output)
@mock.patch("builtins.open", return_value=mock.Mock())
@mock.patch(
"json.load", side_effect=[{"url": "example.com"}, {"url": "example2.com"}]
)
@mock.patch("updateHostsFile.recursive_glob", return_value=["foo", "bar"])
@mock.patch("updateHostsFile.write_data", return_value=0)
@mock.patch(
"updateHostsFile.get_file_by_url", side_effect=[Exception("fail"), "file_data"]
)
def test_sources_fail_succeed(self, mock_get, mock_write, *_):
update_all_sources(self.source_data_filename, self.host_filename)
self.assert_called_once(mock_write)
get_calls = [mock.call("example.com"), mock.call("example2.com")]
mock_get.assert_has_calls(get_calls)
output = sys.stdout.getvalue()
expecteds = [
"Updating source from example.com",
"Error in updating source: example.com",
"Updating source from example2.com",
]
for expected in expecteds:
self.assertIn(expected, output)
# End Update Logic
2017-06-23 09:25:40 +02:00
# File Logic
class TestNormalizeRule(BaseStdout):
def test_no_match(self):
kwargs = dict(target_ip="0.0.0.0", keep_domain_comments=False)
# Note: "Bare"- Domains are accepted. IP are excluded.
for rule in [
"128.0.0.1",
"::1",
"0.0.0.0 128.0.0.2",
"0.1.2.3 foo/bar",
"0.3.4.5 example.org/hello/world",
"0.0.0.0 https",
"0.0.0.0 https..",
]:
self.assertEqual(normalize_rule(rule, **kwargs), (None, None))
2017-06-23 09:25:40 +02:00
output = sys.stdout.getvalue()
sys.stdout = StringIO()
expected = "==>" + rule + "<=="
self.assertIn(expected, output)
def test_mixed_cases(self):
for rule, expected_target in (
("tWiTTer.cOM", "twitter.com"),
("goOgLe.Com", "google.com"),
("FoO.bAR.edu", "foo.bar.edu"),
):
expected = (expected_target, "0.0.0.0 " + expected_target + "\n")
actual = normalize_rule(
rule, target_ip="0.0.0.0", keep_domain_comments=False
)
self.assertEqual(actual, expected)
# Nothing gets printed if there's a match.
output = sys.stdout.getvalue()
self.assertEqual(output, "")
sys.stdout = StringIO()
def test_no_comments(self):
for target_ip in ("0.0.0.0", "127.0.0.1", "8.8.8.8"):
rule = "127.0.0.1 1.google.com foo"
expected = ("1.google.com", str(target_ip) + " 1.google.com\n")
actual = normalize_rule(
rule, target_ip=target_ip, keep_domain_comments=False
)
self.assertEqual(actual, expected)
# Nothing gets printed if there's a match.
output = sys.stdout.getvalue()
self.assertEqual(output, "")
sys.stdout = StringIO()
def test_with_comments(self):
for target_ip in ("0.0.0.0", "127.0.0.1", "8.8.8.8"):
for comment in ("foo", "bar", "baz"):
rule = "127.0.0.1 1.google.co.uk " + comment
expected = (
"1.google.co.uk",
(str(target_ip) + " 1.google.co.uk # " + comment + "\n"),
)
actual = normalize_rule(
rule, target_ip=target_ip, keep_domain_comments=True
)
self.assertEqual(actual, expected)
# Nothing gets printed if there's a match.
output = sys.stdout.getvalue()
self.assertEqual(output, "")
sys.stdout = StringIO()
def test_two_ips(self):
for target_ip in ("0.0.0.0", "127.0.0.1", "8.8.8.8"):
rule = "127.0.0.1 11.22.33.44 foo"
actual = normalize_rule(
rule, target_ip=target_ip, keep_domain_comments=False
)
self.assertEqual(actual, (None, None))
output = sys.stdout.getvalue()
expected = "==>" + rule + "<=="
self.assertIn(expected, output)
sys.stdout = StringIO()
def test_no_comment_raw(self):
2023-09-03 13:17:24 +02:00
for rule in (
"twitter.com",
"google.com",
"foo.bar.edu",
"www.example-foo.bar.edu",
"www.example-3045.foobar.com",
"www.example.xn--p1ai"
):
expected = (rule, "0.0.0.0 " + rule + "\n")
2017-06-23 09:25:40 +02:00
actual = normalize_rule(
rule, target_ip="0.0.0.0", keep_domain_comments=False
)
self.assertEqual(actual, expected)
# Nothing gets printed if there's a match.
output = sys.stdout.getvalue()
2017-06-23 09:25:40 +02:00
self.assertEqual(output, "")
sys.stdout = StringIO()
def test_with_comments_raw(self):
for target_ip in ("0.0.0.0", "127.0.0.1", "8.8.8.8"):
for comment in ("foo", "bar", "baz"):
rule = "1.google.co.uk " + comment
expected = (
"1.google.co.uk",
(str(target_ip) + " 1.google.co.uk # " + comment + "\n"),
)
actual = normalize_rule(
rule, target_ip=target_ip, keep_domain_comments=True
)
self.assertEqual(actual, expected)
# Nothing gets printed if there's a match.
output = sys.stdout.getvalue()
self.assertEqual(output, "")
sys.stdout = StringIO()
class TestStripRule(Base):
2017-06-23 09:25:40 +02:00
def test_strip_exactly_two(self):
for line in [
"0.0.0.0 twitter.com",
"127.0.0.1 facebook.com",
"8.8.8.8 google.com",
"1.2.3.4 foo.bar.edu",
]:
2017-06-23 09:25:40 +02:00
output = strip_rule(line)
self.assertEqual(output, line)
def test_strip_more_than_two(self):
comment = " # comments here galore"
for line in [
"0.0.0.0 twitter.com",
"127.0.0.1 facebook.com",
"8.8.8.8 google.com",
"1.2.3.4 foo.bar.edu",
]:
output = strip_rule(line + comment)
self.assertEqual(output, line + comment)
2017-06-23 09:25:40 +02:00
def test_strip_raw(self):
for line in [
"twitter.com",
"facebook.com",
"google.com",
"foo.bar.edu",
]:
output = strip_rule(line)
self.assertEqual(output, line)
def test_strip_raw_with_comment(self):
comment = " # comments here galore"
for line in [
"twitter.com",
"facebook.com",
"google.com",
"foo.bar.edu",
]:
output = strip_rule(line + comment)
self.assertEqual(output, line + comment)
2017-06-23 09:25:40 +02:00
class TestWriteOpeningHeader(BaseMockDir):
def setUp(self):
super(TestWriteOpeningHeader, self).setUp()
self.final_file = BytesIO()
def test_missing_keyword(self):
kwargs = dict(
extensions="", outputsubfolder="", numberofrules=5, skipstatichosts=False, nounifiedhosts=False
)
for k in kwargs.keys():
bad_kwargs = kwargs.copy()
bad_kwargs.pop(k)
self.assertRaises(
KeyError, write_opening_header, self.final_file, **bad_kwargs
)
def test_basic(self):
kwargs = dict(
extensions="", outputsubfolder="", numberofrules=5, skipstatichosts=True, nounifiedhosts=False
)
write_opening_header(self.final_file, **kwargs)
contents = self.final_file.getvalue()
contents = contents.decode("UTF-8")
# Expected contents.
for expected in (
"# This hosts file is a merged collection",
2020-02-22 14:24:28 +01:00
"# with a dash of crowd sourcing via GitHub",
"# Number of unique domains: {count}".format(count=kwargs["numberofrules"]),
"Fetch the latest version of this file:",
"Project home page: https://github.com/StevenBlack/hosts",
):
self.assertIn(expected, contents)
# Expected non-contents.
for expected in (
"# Extensions added to this file:",
"127.0.0.1 localhost",
"127.0.0.1 local",
"127.0.0.53",
"127.0.1.1",
):
self.assertNotIn(expected, contents)
def test_basic_include_static_hosts(self):
kwargs = dict(
extensions="", outputsubfolder="", numberofrules=5, skipstatichosts=False, nounifiedhosts=False
)
with self.mock_property("platform.system") as obj:
obj.return_value = "Windows"
write_opening_header(self.final_file, **kwargs)
contents = self.final_file.getvalue()
contents = contents.decode("UTF-8")
# Expected contents.
for expected in (
"127.0.0.1 local",
"127.0.0.1 localhost",
"# This hosts file is a merged collection",
2020-02-22 14:24:28 +01:00
"# with a dash of crowd sourcing via GitHub",
"# Number of unique domains: {count}".format(count=kwargs["numberofrules"]),
"Fetch the latest version of this file:",
"Project home page: https://github.com/StevenBlack/hosts",
):
self.assertIn(expected, contents)
# Expected non-contents.
for expected in ("# Extensions added to this file:", "127.0.0.53", "127.0.1.1"):
self.assertNotIn(expected, contents)
def test_basic_include_static_hosts_linux(self):
kwargs = dict(
extensions="", outputsubfolder="", numberofrules=5, skipstatichosts=False, nounifiedhosts=False
)
with self.mock_property("platform.system") as system:
system.return_value = "Linux"
with self.mock_property("socket.gethostname") as hostname:
hostname.return_value = "steven-hosts"
write_opening_header(self.final_file, **kwargs)
contents = self.final_file.getvalue()
contents = contents.decode("UTF-8")
# Expected contents.
for expected in (
"127.0.1.1",
"127.0.0.53",
"steven-hosts",
"127.0.0.1 local",
"127.0.0.1 localhost",
"# This hosts file is a merged collection",
2020-02-22 14:24:28 +01:00
"# with a dash of crowd sourcing via GitHub",
"# Number of unique domains: {count}".format(count=kwargs["numberofrules"]),
"Fetch the latest version of this file:",
"Project home page: https://github.com/StevenBlack/hosts",
):
self.assertIn(expected, contents)
# Expected non-contents.
expected = "# Extensions added to this file:"
self.assertNotIn(expected, contents)
def test_extensions(self):
kwargs = dict(
extensions=["epsilon", "gamma", "mu", "phi"],
outputsubfolder="",
numberofrules=5,
skipstatichosts=True,
nounifiedhosts=False,
)
write_opening_header(self.final_file, **kwargs)
contents = self.final_file.getvalue()
contents = contents.decode("UTF-8")
# Expected contents.
for expected in (
", ".join(kwargs["extensions"]),
"# Extensions added to this file:",
"# This hosts file is a merged collection",
2020-02-22 14:24:28 +01:00
"# with a dash of crowd sourcing via GitHub",
"# Number of unique domains: {count}".format(count=kwargs["numberofrules"]),
"Fetch the latest version of this file:",
"Project home page: https://github.com/StevenBlack/hosts",
):
self.assertIn(expected, contents)
# Expected non-contents.
for expected in (
"127.0.0.1 localhost",
"127.0.0.1 local",
"127.0.0.53",
"127.0.1.1",
):
self.assertNotIn(expected, contents)
def test_no_unified_hosts(self):
kwargs = dict(
extensions=["epsilon", "gamma"],
outputsubfolder="",
numberofrules=5,
skipstatichosts=True,
nounifiedhosts=True,
)
write_opening_header(self.final_file, **kwargs)
contents = self.final_file.getvalue()
contents = contents.decode("UTF-8")
# Expected contents.
for expected in (
", ".join(kwargs["extensions"]),
"# The unified hosts file was not used while generating this file.",
"# Extensions used to generate this file:",
"# This hosts file is a merged collection",
"# with a dash of crowd sourcing via GitHub",
"# Number of unique domains: {count}".format(count=kwargs["numberofrules"]),
"Fetch the latest version of this file:",
"Project home page: https://github.com/StevenBlack/hosts",
):
self.assertIn(expected, contents)
# Expected non-contents.
for expected in (
"127.0.0.1 localhost",
"127.0.0.1 local",
"127.0.0.53",
"127.0.1.1",
):
self.assertNotIn(expected, contents)
def _check_preamble(self, check_copy):
hosts_file = os.path.join(self.test_dir, "myhosts")
hosts_file += ".example" if check_copy else ""
with open(hosts_file, "w") as f:
f.write("peter-piper-picked-a-pepper")
kwargs = dict(
extensions="", outputsubfolder="", numberofrules=5, skipstatichosts=True, nounifiedhosts=False
)
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
updateHostsFile.BASEDIR_PATH = self.test_dir
write_opening_header(self.final_file, **kwargs)
contents = self.final_file.getvalue()
contents = contents.decode("UTF-8")
# Expected contents.
for expected in (
"peter-piper-picked-a-pepper",
"# This hosts file is a merged collection",
2020-02-22 14:24:28 +01:00
"# with a dash of crowd sourcing via GitHub",
"# Number of unique domains: {count}".format(count=kwargs["numberofrules"]),
"Fetch the latest version of this file:",
"Project home page: https://github.com/StevenBlack/hosts",
):
self.assertIn(expected, contents)
# Expected non-contents.
for expected in (
"# Extensions added to this file:",
"127.0.0.1 localhost",
"127.0.0.1 local",
"127.0.0.53",
"127.0.1.1",
):
self.assertNotIn(expected, contents)
def test_preamble_exists(self):
self._check_preamble(True)
def test_preamble_copy(self):
self._check_preamble(False)
def tearDown(self):
super(TestWriteOpeningHeader, self).tearDown()
self.final_file.close()
class TestUpdateReadmeData(BaseMockDir):
def setUp(self):
super(TestUpdateReadmeData, self).setUp()
self.readme_file = os.path.join(self.test_dir, "readmeData.json")
def test_missing_keyword(self):
kwargs = dict(
extensions="", outputsubfolder="", numberofrules="", sourcesdata="", nounifiedhosts=False
)
for k in kwargs.keys():
bad_kwargs = kwargs.copy()
bad_kwargs.pop(k)
self.assertRaises(
KeyError, update_readme_data, self.readme_file, **bad_kwargs
)
def test_add_fields(self):
with open(self.readme_file, "w") as f:
json.dump({"foo": "bar"}, f)
kwargs = dict(
extensions=None, outputsubfolder="foo", numberofrules=5, sourcesdata="hosts", nounifiedhosts=False
)
update_readme_data(self.readme_file, **kwargs)
if platform.system().lower() == "windows":
sep = "/"
else:
sep = self.sep
expected = {
"base": {"location": "foo" + sep, 'no_unified_hosts': False, "sourcesdata": "hosts", "entries": 5},
"foo": "bar",
}
with open(self.readme_file, "r") as f:
actual = json.load(f)
self.assertEqual(actual, expected)
def test_modify_fields(self):
with open(self.readme_file, "w") as f:
json.dump({"base": "soprano"}, f)
kwargs = dict(
extensions=None, outputsubfolder="foo", numberofrules=5, sourcesdata="hosts", nounifiedhosts=False
)
update_readme_data(self.readme_file, **kwargs)
if platform.system().lower() == "windows":
sep = "/"
else:
sep = self.sep
expected = {
"base": {"location": "foo" + sep, 'no_unified_hosts': False, "sourcesdata": "hosts", "entries": 5},
}
with open(self.readme_file, "r") as f:
actual = json.load(f)
self.assertEqual(actual, expected)
def test_set_extensions(self):
with open(self.readme_file, "w") as f:
json.dump({}, f)
kwargs = dict(
extensions=["com", "org"],
outputsubfolder="foo",
numberofrules=5,
sourcesdata="hosts",
nounifiedhosts=False,
)
update_readme_data(self.readme_file, **kwargs)
if platform.system().lower() == "windows":
sep = "/"
else:
sep = self.sep
expected = {
"com-org": {"location": "foo" + sep, 'no_unified_hosts': False, "sourcesdata": "hosts", "entries": 5}
}
with open(self.readme_file, "r") as f:
actual = json.load(f)
self.assertEqual(actual, expected)
def test_set_no_unified_hosts(self):
with open(self.readme_file, "w") as f:
json.dump({}, f)
kwargs = dict(
extensions=["com", "org"],
outputsubfolder="foo",
numberofrules=5,
sourcesdata="hosts",
nounifiedhosts=True,
)
update_readme_data(self.readme_file, **kwargs)
if platform.system().lower() == "windows":
sep = "/"
else:
sep = self.sep
expected = {
"com-org-only": {"location": "foo" + sep, 'no_unified_hosts': True, "sourcesdata": "hosts", "entries": 5}
}
with open(self.readme_file, "r") as f:
actual = json.load(f)
self.assertEqual(actual, expected)
2017-06-23 09:25:40 +02:00
class TestMoveHostsFile(BaseStdout):
@mock.patch("os.path.abspath", side_effect=lambda f: f)
def test_move_hosts_no_name(self, _): # TODO: Create test which tries to move actual file
2022-07-06 18:47:38 +02:00
with self.mock_property("platform.system") as obj:
obj.return_value = "foo"
2017-06-23 09:25:40 +02:00
mock_file = mock.Mock(name="foo")
move_hosts_file_into_place(mock_file)
expected = "does not exist"
2017-06-23 09:25:40 +02:00
output = sys.stdout.getvalue()
self.assertIn(expected, output)
2017-06-23 09:25:40 +02:00
@mock.patch("os.path.abspath", side_effect=lambda f: f)
def test_move_hosts_windows(self, _):
2022-07-06 18:47:38 +02:00
with self.mock_property("platform.system") as obj:
obj.return_value = "Windows"
2017-06-23 09:25:40 +02:00
mock_file = mock.Mock(name="foo")
move_hosts_file_into_place(mock_file)
2022-07-06 18:47:38 +02:00
expected = ""
2017-06-23 09:25:40 +02:00
output = sys.stdout.getvalue()
self.assertIn(expected, output)
@mock.patch("os.path.abspath", side_effect=lambda f: f)
@mock.patch("subprocess.call", return_value=0)
def test_move_hosts_posix(self, *_): # TODO: create test which tries to move an actual file
2022-07-06 18:47:38 +02:00
with self.mock_property("platform.system") as obj:
obj.return_value = "Linux"
2017-06-23 09:25:40 +02:00
mock_file = mock.Mock(name="foo")
move_hosts_file_into_place(mock_file)
expected = "does not exist."
2017-06-23 09:25:40 +02:00
output = sys.stdout.getvalue()
self.assertIn(expected, output)
@mock.patch("os.path.abspath", side_effect=lambda f: f)
@mock.patch("subprocess.call", return_value=1)
def test_move_hosts_posix_fail(self, *_):
2022-07-06 18:47:38 +02:00
with self.mock_property("platform.system") as obj:
obj.return_value = "Linux"
2017-06-23 09:25:40 +02:00
mock_file = mock.Mock(name="foo")
move_hosts_file_into_place(mock_file)
expected = "does not exist."
2017-06-23 09:25:40 +02:00
output = sys.stdout.getvalue()
self.assertIn(expected, output)
class TestFlushDnsCache(BaseStdout):
@mock.patch("subprocess.call", return_value=0)
def test_flush_darwin(self, _):
with self.mock_property("platform.system") as obj:
obj.return_value = "Darwin"
flush_dns_cache()
expected = (
"Flushing the DNS cache to utilize new hosts "
"file...\nFlushing the DNS cache requires "
"administrative privileges. You might need to "
"enter your password."
)
2017-06-23 09:25:40 +02:00
output = sys.stdout.getvalue()
self.assertIn(expected, output)
@mock.patch("subprocess.call", return_value=1)
def test_flush_darwin_fail(self, _):
with self.mock_property("platform.system") as obj:
obj.return_value = "Darwin"
flush_dns_cache()
expected = "Flushing the DNS cache failed."
output = sys.stdout.getvalue()
self.assertIn(expected, output)
def test_flush_windows(self):
with self.mock_property("platform.system") as obj:
obj.return_value = "win32"
with self.mock_property("os.name"):
os.name = "nt"
flush_dns_cache()
expected = (
"Automatically flushing the DNS cache is "
"not yet supported.\nPlease copy and paste "
"the command 'ipconfig /flushdns' in "
"administrator command prompt after running "
"this script."
)
2017-06-23 09:25:40 +02:00
output = sys.stdout.getvalue()
self.assertIn(expected, output)
@mock.patch("os.path.isfile", return_value=False)
def test_flush_no_tool(self, _):
with self.mock_property("platform.system") as obj:
obj.return_value = "Linux"
with self.mock_property("os.name"):
os.name = "posix"
flush_dns_cache()
expected = "Unable to determine DNS management tool."
output = sys.stdout.getvalue()
self.assertIn(expected, output)
@mock.patch("os.path.isfile", side_effect=[True] + [False] * 11)
2017-06-23 09:25:40 +02:00
@mock.patch("subprocess.call", return_value=0)
def test_flush_posix(self, *_):
with self.mock_property("platform.system") as obj:
obj.return_value = "Linux"
with self.mock_property("os.name"):
os.name = "posix"
flush_dns_cache()
expected = "Flushing the DNS cache by restarting nscd succeeded"
2017-06-23 09:25:40 +02:00
output = sys.stdout.getvalue()
self.assertIn(expected, output)
@mock.patch("os.path.isfile", side_effect=[True] + [False] * 11)
2017-06-23 09:25:40 +02:00
@mock.patch("subprocess.call", return_value=1)
def test_flush_posix_fail(self, *_):
with self.mock_property("platform.system") as obj:
obj.return_value = "Linux"
with self.mock_property("os.name"):
os.name = "posix"
flush_dns_cache()
expected = "Flushing the DNS cache by restarting nscd failed"
2017-06-23 09:25:40 +02:00
output = sys.stdout.getvalue()
self.assertIn(expected, output)
@mock.patch("os.path.isfile", side_effect=[True, False, False, True] + [False] * 10)
2020-07-10 22:43:09 +02:00
@mock.patch("subprocess.call", side_effect=[1, 0, 0])
2017-06-23 09:25:40 +02:00
def test_flush_posix_fail_then_succeed(self, *_):
with self.mock_property("platform.system") as obj:
obj.return_value = "Linux"
with self.mock_property("os.name"):
os.name = "posix"
flush_dns_cache()
output = sys.stdout.getvalue()
for expected in [
("Flushing the DNS cache by restarting nscd failed"),
(
"Flushing the DNS cache by restarting "
"NetworkManager.service succeeded"
),
]:
2017-06-23 09:25:40 +02:00
self.assertIn(expected, output)
2020-04-08 08:07:40 +02:00
class TestRemoveOldHostsFile(BaseMockDir):
def setUp(self):
super(TestRemoveOldHostsFile, self).setUp()
self.hosts_file = "hosts"
self.full_hosts_path = os.path.join(self.test_dir, "hosts")
def test_remove_hosts_file(self):
old_dir_count = self.dir_count
remove_old_hosts_file(self.test_dir, self.hosts_file, backup=False)
2020-04-08 07:18:35 +02:00
new_dir_count = old_dir_count + 1
self.assertEqual(self.dir_count, new_dir_count)
with open(self.full_hosts_path, "r") as f:
2020-04-08 07:18:35 +02:00
contents = f.read()
self.assertEqual(contents, "")
def test_remove_hosts_file_exists(self):
with open(self.full_hosts_path, "w") as f:
f.write("foo")
old_dir_count = self.dir_count
remove_old_hosts_file(self.test_dir, self.hosts_file, backup=False)
2020-04-08 07:18:35 +02:00
new_dir_count = old_dir_count
self.assertEqual(self.dir_count, new_dir_count)
with open(self.full_hosts_path, "r") as f:
2020-04-08 07:18:35 +02:00
contents = f.read()
self.assertEqual(contents, "")
2020-04-24 21:13:41 +02:00
@mock.patch("time.strftime", return_value="new")
def test_remove_hosts_file_backup(self, _):
with open(self.full_hosts_path, "w") as f:
f.write("foo")
old_dir_count = self.dir_count
remove_old_hosts_file(self.test_dir, self.hosts_file, backup=True)
2020-04-08 07:18:35 +02:00
new_dir_count = old_dir_count + 1
self.assertEqual(self.dir_count, new_dir_count)
with open(self.full_hosts_path, "r") as f:
2020-04-08 07:18:35 +02:00
contents = f.read()
self.assertEqual(contents, "")
new_hosts_file = self.full_hosts_path + "-new"
2020-04-08 07:18:35 +02:00
with open(new_hosts_file, "r") as f:
contents = f.read()
self.assertEqual(contents, "foo")
2017-06-23 09:25:40 +02:00
# End File Logic
class DomainToIDNA(Base):
def __init__(self, *args, **kwargs):
super(DomainToIDNA, self).__init__(*args, **kwargs)
self.domains = [b"\xc9\xa2oogle.com", b"www.huala\xc3\xb1e.cl"]
self.expected_domains = ["xn--oogle-wmc.com", "www.xn--hualae-0wa.cl"]
def test_empty_line(self):
data = ["", "\r", "\n"]
for empty in data:
expected = empty
actual = domain_to_idna(empty)
self.assertEqual(actual, expected)
def test_commented_line(self):
data = "# Hello World"
expected = data
actual = domain_to_idna(data)
self.assertEqual(actual, expected)
def test_simple_line(self):
# Test with a space as separator.
for i in range(len(self.domains)):
data = (b"0.0.0.0 " + self.domains[i]).decode("utf-8")
expected = "0.0.0.0 " + self.expected_domains[i]
actual = domain_to_idna(data)
self.assertEqual(actual, expected)
# Test with a tabulation as separator.
for i in range(len(self.domains)):
data = (b"0.0.0.0\t" + self.domains[i]).decode("utf-8")
expected = "0.0.0.0\t" + self.expected_domains[i]
actual = domain_to_idna(data)
self.assertEqual(actual, expected)
def test_multiple_space_as_separator(self):
# Test with multiple space as separator.
for i in range(len(self.domains)):
data = (b"0.0.0.0 " + self.domains[i]).decode("utf-8")
expected = "0.0.0.0 " + self.expected_domains[i]
actual = domain_to_idna(data)
self.assertEqual(actual, expected)
def test_multiple_tabs_as_separator(self):
# Test with multiple tabls as separator.
for i in range(len(self.domains)):
data = (b"0.0.0.0\t\t\t\t\t\t" + self.domains[i]).decode("utf-8")
expected = "0.0.0.0\t\t\t\t\t\t" + self.expected_domains[i]
actual = domain_to_idna(data)
self.assertEqual(actual, expected)
def test_line_with_comment_at_the_end(self):
# Test with a space as separator.
for i in range(len(self.domains)):
data = (b"0.0.0.0 " + self.domains[i] + b" # Hello World").decode("utf-8")
expected = "0.0.0.0 " + self.expected_domains[i] + " # Hello World"
actual = domain_to_idna(data)
self.assertEqual(actual, expected)
# Test with a tabulation as separator.
for i in range(len(self.domains)):
data = (b"0.0.0.0\t" + self.domains[i] + b" # Hello World").decode("utf-8")
expected = "0.0.0.0\t" + self.expected_domains[i] + " # Hello World"
actual = domain_to_idna(data)
self.assertEqual(actual, expected)
# Test with tabulation as separator of domain and comment.
for i in range(len(self.domains)):
data = (b"0.0.0.0\t" + self.domains[i] + b"\t # Hello World").decode(
"utf-8"
)
expected = "0.0.0.0\t" + self.expected_domains[i] + "\t # Hello World"
actual = domain_to_idna(data)
self.assertEqual(actual, expected)
# Test with space as separator of domain and tabulation as separator
# of comments.
for i in range(len(self.domains)):
data = (b"0.0.0.0 " + self.domains[i] + b" \t # Hello World").decode(
"utf-8"
)
expected = "0.0.0.0 " + self.expected_domains[i] + " \t # Hello World"
actual = domain_to_idna(data)
self.assertEqual(actual, expected)
# Test with multiple space as separator of domain and space and
# tabulation as separator or comments.
for i in range(len(self.domains)):
data = (b"0.0.0.0 " + self.domains[i] + b" \t # Hello World").decode(
"utf-8"
)
expected = "0.0.0.0 " + self.expected_domains[i] + " \t # Hello World"
actual = domain_to_idna(data)
self.assertEqual(actual, expected)
# Test with multiple tabulations as separator of domain and space and
# tabulation as separator or comments.
for i, domain in enumerate(self.domains):
data = (b"0.0.0.0\t\t\t" + domain + b" \t # Hello World").decode(
"utf-8"
)
expected = "0.0.0.0\t\t\t" + self.expected_domains[i] + " \t # Hello World"
actual = domain_to_idna(data)
self.assertEqual(actual, expected)
def test_line_without_prefix(self):
for i in range(len(self.domains)):
data = self.domains[i].decode("utf-8")
expected = self.expected_domains[i]
actual = domain_to_idna(data)
self.assertEqual(actual, expected)
2020-08-26 02:40:30 +02:00
class GetFileByUrl(BaseStdout):
def test_basic(self):
raw_resp_content = "hello, ".encode("ascii") + "world".encode("utf-8")
resp_obj = requests.Response()
resp_obj.__setstate__({"_content": raw_resp_content})
expected = "hello, world"
with mock.patch("requests.get", return_value=resp_obj):
actual = get_file_by_url("www.test-url.com")
self.assertEqual(expected, actual)
def test_with_idna(self):
raw_resp_content = b"www.huala\xc3\xb1e.cl"
resp_obj = requests.Response()
resp_obj.__setstate__({"_content": raw_resp_content})
expected = "www.xn--hualae-0wa.cl"
with mock.patch("requests.get", return_value=resp_obj):
actual = get_file_by_url("www.test-url.com")
self.assertEqual(expected, actual)
2020-08-28 09:17:06 +02:00
def test_connect_unknown_domain(self):
2021-03-31 14:57:56 +02:00
test_url = (
"http://doesnotexist.google.com" # leads to exception: ConnectionError
)
with mock.patch(
"requests.get", side_effect=requests.exceptions.ConnectionError
):
2020-09-03 02:25:01 +02:00
return_value = get_file_by_url(test_url)
2020-08-28 09:17:06 +02:00
self.assertIsNone(return_value)
printed_output = sys.stdout.getvalue()
2021-03-31 14:57:56 +02:00
self.assertEqual(
printed_output, "Error retrieving data from {}\n".format(test_url)
)
2020-08-28 09:17:06 +02:00
def test_invalid_url(self):
2020-09-03 02:25:01 +02:00
test_url = "http://fe80::5054:ff:fe5a:fc0" # leads to exception: InvalidURL
2021-03-31 14:57:56 +02:00
with mock.patch(
"requests.get", side_effect=requests.exceptions.ConnectionError
):
2020-09-03 02:25:01 +02:00
return_value = get_file_by_url(test_url)
2020-08-28 09:17:06 +02:00
self.assertIsNone(return_value)
printed_output = sys.stdout.getvalue()
2021-03-31 14:57:56 +02:00
self.assertEqual(
printed_output, "Error retrieving data from {}\n".format(test_url)
)
2020-08-28 09:17:06 +02:00
2020-08-26 02:40:30 +02:00
2017-06-23 09:25:40 +02:00
class TestWriteData(Base):
def test_write_basic(self):
f = BytesIO()
data = "foo"
write_data(f, data)
expected = b"foo"
actual = f.getvalue()
self.assertEqual(actual, expected)
def test_write_unicode(self):
f = BytesIO()
data = u"foo"
write_data(f, data)
expected = b"foo"
actual = f.getvalue()
self.assertEqual(actual, expected)
class TestQueryYesOrNo(BaseStdout):
def test_invalid_default(self):
for invalid_default in ["foo", "bar", "baz", 1, 2, 3]:
self.assertRaises(ValueError, query_yes_no, "?", invalid_default)
@mock.patch("updateHostsFile.input", side_effect=["yes"] * 3)
2017-06-23 09:25:40 +02:00
def test_valid_default(self, _):
for valid_default, expected in [
(None, "[y/n]"),
("yes", "[Y/n]"),
("no", "[y/N]"),
]:
2017-06-23 09:25:40 +02:00
self.assertTrue(query_yes_no("?", valid_default))
output = sys.stdout.getvalue()
sys.stdout = StringIO()
self.assertIn(expected, output)
@mock.patch("updateHostsFile.input", side_effect=([""] * 2))
2017-06-23 09:25:40 +02:00
def test_use_valid_default(self, _):
for valid_default in ["yes", "no"]:
expected = valid_default == "yes"
2017-06-23 09:25:40 +02:00
actual = query_yes_no("?", valid_default)
self.assertEqual(actual, expected)
2018-08-10 17:32:58 +02:00
@mock.patch("updateHostsFile.input", side_effect=["no", "NO", "N", "n", "No", "nO"])
2017-06-23 09:25:40 +02:00
def test_valid_no(self, _):
self.assertFalse(query_yes_no("?", None))
@mock.patch(
"updateHostsFile.input",
side_effect=["yes", "YES", "Y", "yeS", "y", "YeS", "yES", "YEs"],
)
2017-06-23 09:25:40 +02:00
def test_valid_yes(self, _):
self.assertTrue(query_yes_no("?", None))
2018-08-10 17:32:58 +02:00
@mock.patch("updateHostsFile.input", side_effect=["foo", "yes", "foo", "no"])
2017-06-23 09:25:40 +02:00
def test_invalid_then_valid(self, _):
expected = "Please respond with 'yes' or 'no'"
# The first time, we respond "yes"
self.assertTrue(query_yes_no("?", None))
output = sys.stdout.getvalue()
self.assertIn(expected, output)
sys.stdout = StringIO()
# The second time, we respond "no"
self.assertFalse(query_yes_no("?", None))
output = sys.stdout.getvalue()
self.assertIn(expected, output)
class TestIsValidUserProvidedDomainFormat(BaseStdout):
2017-06-23 09:25:40 +02:00
def test_empty_domain(self):
self.assertFalse(is_valid_user_provided_domain_format(""))
2017-06-23 09:25:40 +02:00
output = sys.stdout.getvalue()
expected = "You didn't enter a domain. Try again."
2022-03-07 08:08:23 +01:00
self.assertIn(expected, output)
2017-06-23 09:25:40 +02:00
def test_invalid_domain(self):
2020-05-27 21:59:34 +02:00
expected = "Do not include www.domain.com or http(s)://domain.com. Try again."
for invalid_domain in [
"www.subdomain.domain",
"https://github.com",
"http://www.google.com",
]:
self.assertFalse(is_valid_user_provided_domain_format(invalid_domain))
2017-06-23 09:25:40 +02:00
output = sys.stdout.getvalue()
sys.stdout = StringIO()
self.assertIn(expected, output)
def test_valid_domain(self):
for valid_domain in ["github.com", "travis.org", "twitter.com"]:
self.assertTrue(is_valid_user_provided_domain_format(valid_domain))
2017-06-23 09:25:40 +02:00
output = sys.stdout.getvalue()
sys.stdout = StringIO()
self.assertEqual(output, "")
def mock_walk(stem):
"""
Mock method for `os.walk`.
Please refer to the documentation of `os.walk` for information about
the provided parameters.
"""
files = [
"foo.txt",
"bar.bat",
"baz.py",
"foo/foo.c",
"foo/bar.doc",
"foo/baz/foo.py",
"bar/foo/baz.c",
"bar/bar/foo.bat",
]
2017-06-23 09:25:40 +02:00
if stem == ".":
stem = ""
matches = []
for f in files:
if not stem or f.startswith(stem + "/"):
matches.append(("", "_", [f]))
return matches
class TestRecursiveGlob(Base):
@staticmethod
def sorted_recursive_glob(stem, file_pattern):
actual = recursive_glob(stem, file_pattern)
actual.sort()
return actual
@mock.patch("os.walk", side_effect=mock_walk)
def test_all_match(self, _):
with self.mock_property("sys.version_info"):
sys.version_info = (2, 6)
expected = [
"bar.bat",
"bar/bar/foo.bat",
"bar/foo/baz.c",
"baz.py",
"foo.txt",
"foo/bar.doc",
"foo/baz/foo.py",
"foo/foo.c",
]
2017-06-23 09:25:40 +02:00
actual = self.sorted_recursive_glob("*", "*")
self.assertListEqual(actual, expected)
expected = ["bar/bar/foo.bat", "bar/foo/baz.c"]
actual = self.sorted_recursive_glob("bar", "*")
self.assertListEqual(actual, expected)
expected = ["foo/bar.doc", "foo/baz/foo.py", "foo/foo.c"]
actual = self.sorted_recursive_glob("foo", "*")
self.assertListEqual(actual, expected)
@mock.patch("os.walk", side_effect=mock_walk)
def test_file_ending(self, _):
with self.mock_property("sys.version_info"):
sys.version_info = (2, 6)
expected = ["foo/baz/foo.py"]
actual = self.sorted_recursive_glob("foo", "*.py")
self.assertListEqual(actual, expected)
expected = ["bar/foo/baz.c", "foo/foo.c"]
actual = self.sorted_recursive_glob("*", "*.c")
self.assertListEqual(actual, expected)
expected = []
actual = self.sorted_recursive_glob("*", ".xlsx")
self.assertListEqual(actual, expected)
def mock_path_join(*_):
"""
Mock method for `os.path.join`.
Please refer to the documentation of `os.path.join` for information about
the provided parameters.
"""
raise UnicodeDecodeError("foo", b"", 1, 5, "foo")
class TestPathJoinRobust(Base):
def test_basic(self):
expected = "path1"
actual = path_join_robust("path1")
self.assertEqual(actual, expected)
actual = path_join_robust(u"path1")
self.assertEqual(actual, expected)
def test_join(self):
for i in range(1, 4):
paths = ["pathNew"] * i
expected = "path1" + (self.sep + "pathNew") * i
actual = path_join_robust("path1", *paths)
self.assertEqual(actual, expected)
def test_join_unicode(self):
for i in range(1, 4):
paths = [u"pathNew"] * i
expected = "path1" + (self.sep + "pathNew") * i
actual = path_join_robust("path1", *paths)
self.assertEqual(actual, expected)
@mock.patch("os.path.join", side_effect=mock_path_join)
def test_join_error(self, _):
self.assertRaises(locale.Error, path_join_robust, "path")
# Colors
class TestSupportsColor(BaseStdout):
def test_posix(self):
with self.mock_property("sys.platform"):
sys.platform = "Linux"
with self.mock_property("sys.stdout.isatty") as obj:
obj.return_value = True
self.assertTrue(supports_color())
def test_pocket_pc(self):
with self.mock_property("sys.platform"):
sys.platform = "Pocket PC"
self.assertFalse(supports_color())
def test_windows_no_ansicon(self):
with self.mock_property("sys.platform"):
sys.platform = "win32"
with self.mock_property("os.environ"):
os.environ = []
self.assertFalse(supports_color())
def test_windows_ansicon(self):
with self.mock_property("sys.platform"):
sys.platform = "win32"
with self.mock_property("os.environ"):
os.environ = ["ANSICON"]
with self.mock_property("sys.stdout.isatty") as obj:
obj.return_value = True
self.assertTrue(supports_color())
def test_no_isatty_attribute(self):
with self.mock_property("sys.platform"):
sys.platform = "Linux"
with self.mock_property("sys.stdout"):
sys.stdout = list()
self.assertFalse(supports_color())
def test_no_isatty(self):
with self.mock_property("sys.platform"):
sys.platform = "Linux"
with self.mock_property("sys.stdout.isatty") as obj:
obj.return_value = False
self.assertFalse(supports_color())
class TestColorize(Base):
def setUp(self):
self.text = "house"
self.colors = ["red", "orange", "yellow", "green", "blue", "purple"]
2017-06-23 09:25:40 +02:00
@mock.patch("updateHostsFile.supports_color", return_value=False)
def test_colorize_no_support(self, _):
for color in self.colors:
expected = self.text
actual = colorize(self.text, color)
self.assertEqual(actual, expected)
@mock.patch("updateHostsFile.supports_color", return_value=True)
def test_colorize_support(self, _):
for color in self.colors:
expected = color + self.text + Colors.ENDC
actual = colorize(self.text, color)
self.assertEqual(actual, expected)
class TestPrintSuccess(BaseStdout):
def setUp(self):
super(TestPrintSuccess, self).setUp()
self.text = "house"
@mock.patch("updateHostsFile.supports_color", return_value=False)
def test_print_success_no_support(self, _):
print_success(self.text)
expected = self.text + "\n"
actual = sys.stdout.getvalue()
self.assertEqual(actual, expected)
@mock.patch("updateHostsFile.supports_color", return_value=True)
def test_print_success_support(self, _):
print_success(self.text)
expected = Colors.SUCCESS + self.text + Colors.ENDC + "\n"
actual = sys.stdout.getvalue()
self.assertEqual(actual, expected)
class TestPrintFailure(BaseStdout):
def setUp(self):
super(TestPrintFailure, self).setUp()
self.text = "house"
@mock.patch("updateHostsFile.supports_color", return_value=False)
def test_print_failure_no_support(self, _):
print_failure(self.text)
expected = self.text + "\n"
actual = sys.stdout.getvalue()
self.assertEqual(actual, expected)
@mock.patch("updateHostsFile.supports_color", return_value=True)
def test_print_failure_support(self, _):
print_failure(self.text)
expected = Colors.FAIL + self.text + Colors.ENDC + "\n"
actual = sys.stdout.getvalue()
self.assertEqual(actual, expected)
2017-06-23 09:25:40 +02:00
# End Helper Functions
if __name__ == "__main__":
unittest.main()