From d0ed86eeadd5c7a95b5398b52950286cb7b0892d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicki=20K=C5=99=C3=AD=C5=BEek?= <nicki@isc.org> Date: Wed, 29 Jan 2025 15:28:55 +0100 Subject: [PATCH 1/5] Improve handling of ConnectionError If a ConnectionError happens, try to capture the information about which resolver it occurred on. While it may be just a regular one-off network blip, it may also indicate a failure mode in one of the target resolvers. That is especially true if the same connection error happens repeatedly for a single resolver. Prior to this change, the information about the resolver was lacking. This made it hard to assess whether the connection errors were caused by some network issue, or whether it is likely that one of the target resolvers is at fault. --- respdiff/sendrecv.py | 100 +++++++++++++++++++++++++++++++++---------- 1 file changed, 77 insertions(+), 23 deletions(-) diff --git a/respdiff/sendrecv.py b/respdiff/sendrecv.py index 5920ecb..862f038 100644 --- a/respdiff/sendrecv.py +++ b/respdiff/sendrecv.py @@ -10,6 +10,7 @@ threads or processes. Make sure not to break this compatibility. from argparse import Namespace +import logging import random import signal import selectors @@ -18,7 +19,7 @@ import ssl import struct import time import threading -from typing import Any, Dict, List, Mapping, Sequence, Tuple # noqa: type hints +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple # noqa: type hints import dns.inet import dns.message @@ -56,6 +57,15 @@ class TcpDnsLengthError(ConnectionError): pass +class ResolverConnectionError(ConnectionError): + def __init__(self, resolver: ResolverID, message: str): + super().__init__(message) + self.resolver = resolver + + def __str__(self): + return f"[{self.resolver}] {super().__str__()}" + + def module_init(args: Namespace) -> None: global __resolvers global __max_timeouts @@ -261,27 +271,39 @@ def _recv_msg(sock: Socket, isstream: IsStreamFlag) -> WireFormat: return sock.recv(length) -def _send_recv_parallel( - dgram: WireFormat, # DNS message suitable for UDP transport +def _create_sendbuf(dnsdata: WireFormat, isstream: IsStreamFlag) -> bytes: + if isstream: # prepend length, RFC 1035 section 4.2.2 + length = len(dnsdata) + return struct.pack('!H', length) + dnsdata + return dnsdata + + +def _get_resolver_from_sock(sockets: ResolverSockets, sock: Socket) -> Optional[ResolverID]: + for resolver, resolver_sock, _ in sockets: + if sock == resolver_sock: + return resolver + return None + + +def _recv_from_resolvers( selector: Selector, sockets: ResolverSockets, + msgid: bytes, timeout: float - ) -> Tuple[Mapping[ResolverID, DNSReply], ReinitFlag]: - replies = {} # type: Dict[ResolverID, DNSReply] - streammsg = None + ) -> Tuple[Dict[ResolverID, DNSReply], bool]: + + def raise_resolver_exc(sock, exc): + resolver = _get_resolver_from_sock(sockets, sock) + if resolver is not None: + raise ResolverConnectionError(resolver, str(exc)) from exc + raise exc + + start_time = time.perf_counter() end_time = start_time + timeout - for _, sock, isstream in sockets: - if isstream: # prepend length, RFC 1035 section 4.2.2 - if not streammsg: - length = len(dgram) - streammsg = struct.pack('!H', length) + dgram - sock.sendall(streammsg) - else: - sock.sendall(dgram) - - # receive replies + replies = {} # type: Dict[ResolverID, DNSReply] reinit = False + while len(replies) != len(sockets): remaining_time = end_time - time.perf_counter() if remaining_time <= 0: @@ -293,17 +315,40 @@ def _send_recv_parallel( sock = key.fileobj try: wire = 
_recv_msg(sock, isstream) - except TcpDnsLengthError: + except TcpDnsLengthError as exc: if name in replies: # we have a reply already, most likely TCP FIN reinit = True selector.unregister(sock) continue # receive answers from other parties - raise # no reply -> raise error - # assert len(wire) > 14 - if dgram[0:2] != wire[0:2]: + # no reply -> raise error + raise_resolver_exc(sock, exc) + except ConnectionError as exc: + raise_resolver_exc(sock, exc) + if msgid != wire[0:2]: continue # wrong msgid, this might be a delayed answer - ignore it replies[name] = DNSReply(wire, time.perf_counter() - start_time) + return replies, reinit + + +def _send_recv_parallel( + dgram: WireFormat, # DNS message suitable for UDP transport + selector: Selector, + sockets: ResolverSockets, + timeout: float + ) -> Tuple[Mapping[ResolverID, DNSReply], ReinitFlag]: + # send queries + for resolver, sock, isstream in sockets: + sendbuf = _create_sendbuf(dgram, isstream) + try: + sock.sendall(sendbuf) + except ConnectionError as exc: + raise ResolverConnectionError(resolver, str(exc)) from exc + + # receive replies + msgid = dgram[0:2] + replies, reinit = _recv_from_resolvers(selector, sockets, msgid, timeout) + # set missing replies as timeout for resolver, *_ in sockets: # type: ignore # python/mypy#465 if resolver not in replies: @@ -317,6 +362,7 @@ def send_recv_parallel( timeout: float, reinit_on_tcpfin: bool = True ) -> Mapping[ResolverID, DNSReply]: + problematic = [] for _ in range(CONN_RESET_RETRIES + 1): try: # get sockets and selector selector = __worker_state.selector @@ -331,9 +377,17 @@ def send_recv_parallel( worker_deinit() worker_reinit() return replies - except (TcpDnsLengthError, ConnectionError): # most likely TCP RST + # The following exception handling is typically triggered by TCP RST, + # but it could also indicate some issue with one of the resolvers. + except ResolverConnectionError as exc: + problematic.append(exc.resolver) + logging.debug(exc) + worker_deinit() # re-establish connection + worker_reinit() + except ConnectionError as exc: # most likely TCP RST + logging.debug(exc) worker_deinit() # re-establish connection worker_reinit() raise RuntimeError( - 'ConnectionError received {} times in a row, exiting!'.format( - CONN_RESET_RETRIES + 1)) + 'ConnectionError received {} times in a row ({}), exiting!'.format( + CONN_RESET_RETRIES + 1, ', '.join(problematic))) -- GitLab From 33f0e4af6a79969a95ec4db6a925e1a93f1256ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicki=20K=C5=99=C3=AD=C5=BEek?= <nicki@isc.org> Date: Wed, 5 Feb 2025 15:42:10 +0100 Subject: [PATCH 2/5] Format code with black This entire commit is just the result of running the `black` formatter on the Python files - no manual changes were made.
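Note for reviewers: the effect of PATCH 1/5 can be tried in isolation. The sketch below mirrors the ResolverConnectionError class added to respdiff/sendrecv.py, but everything around it is simplified for demonstration: ResolverID is reduced to a plain str, and the send_query() helper and the "kresd" resolver name are made up here and do not exist in the codebase.

    import logging

    logging.basicConfig(level=logging.DEBUG)

    ResolverID = str  # simplified stand-in for respdiff's ResolverID alias


    class ResolverConnectionError(ConnectionError):
        """ConnectionError that remembers which resolver it came from."""

        def __init__(self, resolver: ResolverID, message: str):
            super().__init__(message)
            self.resolver = resolver

        def __str__(self):
            return f"[{self.resolver}] {super().__str__()}"


    def send_query(resolver: ResolverID) -> None:
        """Hypothetical helper that always fails; stands in for sock.sendall()."""
        raise ConnectionResetError("Connection reset by peer")


    try:
        send_query("kresd")
    except ConnectionError as exc:
        # Wrap the generic error so the log line names the failing resolver,
        # e.g. "[kresd] Connection reset by peer" instead of a bare message.
        logging.debug(ResolverConnectionError("kresd", str(exc)))

With repeated failures, the offending resolver names accumulate (the problematic list in send_recv_parallel()), so the final RuntimeError can point at the likely culprit instead of reporting an anonymous ConnectionError.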
--- contrib/job_manager/create.py | 238 +++++++++++--------- contrib/job_manager/submit.py | 169 ++++++++------ diffrepro.py | 27 ++- diffsum.py | 102 ++++++--- distrcmp.py | 46 ++-- dnsviz.py | 37 ++- histogram.py | 129 ++++++----- msgdiff.py | 54 ++--- orchestrator.py | 27 ++- qexport.py | 78 ++++--- qprep.py | 86 ++++--- respdiff/__init__.py | 4 +- respdiff/blacklist.py | 49 ++-- respdiff/cfg.py | 137 +++++++----- respdiff/cli.py | 395 +++++++++++++++++++-------------- respdiff/database.py | 195 +++++++++------- respdiff/dataformat.py | 238 +++++++++++--------- respdiff/dnsviz.py | 35 +-- respdiff/match.py | 159 +++++++------ respdiff/qstats.py | 33 +-- respdiff/query.py | 68 +++--- respdiff/repro.py | 96 ++++---- respdiff/sendrecv.py | 111 +++++---- respdiff/stats.py | 94 ++++---- statcmp.py | 153 ++++++++----- sumcmp.py | 68 +++--- sumstat.py | 18 +- tests/lmdb/create_test_lmdb.py | 65 +++--- tests/test_cli.py | 27 ++- tests/test_data.py | 139 +++++++----- tests/test_database.py | 159 +++++++------ tests/test_qprep_pcap.py | 91 +++++--- tests/test_qprep_text.py | 51 +++-- utils/dns2txt.py | 2 +- utils/normalize_names.py | 16 +- 35 files changed, 1988 insertions(+), 1408 deletions(-) diff --git a/contrib/job_manager/create.py b/contrib/job_manager/create.py index fb2970f..08d809b 100755 --- a/contrib/job_manager/create.py +++ b/contrib/job_manager/create.py @@ -14,8 +14,8 @@ import yaml DIR_PATH = os.path.dirname(os.path.realpath(__file__)) -TEST_CASE_DIR = os.path.join(DIR_PATH, 'test_cases') -FILES_DIR = os.path.join(DIR_PATH, 'files') +TEST_CASE_DIR = os.path.join(DIR_PATH, "test_cases") +FILES_DIR = os.path.join(DIR_PATH, "files") def prepare_dir(directory: str, clean: bool = False) -> None: @@ -29,33 +29,32 @@ def prepare_dir(directory: str, clean: bool = False) -> None: except FileExistsError as e: raise RuntimeError( 'Directory "{}" already exists! 
Use -l label / --clean or (re)move the ' - 'directory manually to resolve the issue.'.format(directory)) from e + "directory manually to resolve the issue.".format(directory) + ) from e -def copy_file(name: str, destdir: str, destname: str = ''): +def copy_file(name: str, destdir: str, destname: str = ""): if not destname: destname = name - shutil.copy( - os.path.join(FILES_DIR, name), - os.path.join(destdir, destname)) + shutil.copy(os.path.join(FILES_DIR, name), os.path.join(destdir, destname)) def create_file_from_template( - name: str, - data: Mapping[str, Any], - destdir: str, - destname: str = '', - executable=False - ) -> None: + name: str, + data: Mapping[str, Any], + destdir: str, + destname: str = "", + executable=False, +) -> None: env = jinja2.Environment(loader=jinja2.FileSystemLoader(FILES_DIR)) template = env.get_template(name) rendered = template.render(**data) if not destname: - assert name[-3:] == '.j2' + assert name[-3:] == ".j2" destname = os.path.basename(name)[:-3] dest = os.path.join(destdir, destname) - with open(dest, 'w', encoding='UTF-8') as fh: + with open(dest, "w", encoding="UTF-8") as fh: fh.write(rendered) if executable: @@ -64,86 +63,99 @@ def create_file_from_template( def load_test_case_config(test_case: str) -> Dict[str, Any]: - path = os.path.join(TEST_CASE_DIR, test_case + '.yaml') - with open(path, 'r', encoding='UTF-8') as f: + path = os.path.join(TEST_CASE_DIR, test_case + ".yaml") + with open(path, "r", encoding="UTF-8") as f: return yaml.safe_load(f) def create_resolver_configs(directory: str, config: Dict[str, Any]): - for name, resolver in config['resolvers'].items(): - resolver['name'] = name - resolver['verbose'] = config['verbose'] - if resolver['type'] == 'knot-resolver': - dockerfile_dir = os.path.join(directory, 'docker-knot-resolver') + for name, resolver in config["resolvers"].items(): + resolver["name"] = name + resolver["verbose"] = config["verbose"] + if resolver["type"] == "knot-resolver": + dockerfile_dir = os.path.join(directory, "docker-knot-resolver") if not os.path.exists(dockerfile_dir): os.makedirs(dockerfile_dir) - copy_file('Dockerfile.knot-resolver', dockerfile_dir, 'Dockerfile') - copy_file('kresd.entrypoint.sh', dockerfile_dir) + copy_file("Dockerfile.knot-resolver", dockerfile_dir, "Dockerfile") + copy_file("kresd.entrypoint.sh", dockerfile_dir) create_file_from_template( - 'kresd.conf.j2', resolver, directory, name + '.conf') - elif resolver['type'] == 'unbound': + "kresd.conf.j2", resolver, directory, name + ".conf" + ) + elif resolver["type"] == "unbound": create_file_from_template( - 'unbound.conf.j2', resolver, directory, name + '.conf') - copy_file('cert.pem', directory) - copy_file('key.pem', directory) - copy_file('root.keys', directory) - elif resolver['type'] == 'bind': + "unbound.conf.j2", resolver, directory, name + ".conf" + ) + copy_file("cert.pem", directory) + copy_file("key.pem", directory) + copy_file("root.keys", directory) + elif resolver["type"] == "bind": create_file_from_template( - 'named.conf.j2', resolver, directory, name + '.conf') - copy_file('rfc1912.zones', directory) - copy_file('bind.keys', directory) + "named.conf.j2", resolver, directory, name + ".conf" + ) + copy_file("rfc1912.zones", directory) + copy_file("bind.keys", directory) else: raise NotImplementedError( - "unknown resolver type: '{}'".format(resolver['type'])) + "unknown resolver type: '{}'".format(resolver["type"]) + ) def create_resperf_files(directory: str, config: Dict[str, Any]): - 
create_file_from_template('run_resperf.sh.j2', config, directory, executable=True) - create_file_from_template('docker-compose.yaml.j2', config, directory) + create_file_from_template("run_resperf.sh.j2", config, directory, executable=True) + create_file_from_template("docker-compose.yaml.j2", config, directory) create_resolver_configs(directory, config) def create_distrotest_files(directory: str, config: Dict[str, Any]): - create_file_from_template('run_distrotest.sh.j2', config, directory, executable=True) + create_file_from_template( + "run_distrotest.sh.j2", config, directory, executable=True + ) def create_respdiff_files(directory: str, config: Dict[str, Any]): - create_file_from_template('run_respdiff.sh.j2', config, directory, executable=True) - create_file_from_template('restart-all.sh.j2', config, directory, executable=True) - create_file_from_template('docker-compose.yaml.j2', config, directory) + create_file_from_template("run_respdiff.sh.j2", config, directory, executable=True) + create_file_from_template("restart-all.sh.j2", config, directory, executable=True) + create_file_from_template("docker-compose.yaml.j2", config, directory) create_resolver_configs(directory, config) # omit resolvers without respdiff section from respdiff.cfg - config['resolvers'] = { - name: res for name, res - in config['resolvers'].items() - if 'respdiff' in res} - create_file_from_template('respdiff.cfg.j2', config, directory) + config["resolvers"] = { + name: res for name, res in config["resolvers"].items() if "respdiff" in res + } + create_file_from_template("respdiff.cfg.j2", config, directory) - if config['respdiff_stats']: # copy optional stats file + if config["respdiff_stats"]: # copy optional stats file try: - shutil.copyfile(config['respdiff_stats'], os.path.join(directory, 'stats.json')) + shutil.copyfile( + config["respdiff_stats"], os.path.join(directory, "stats.json") + ) except FileNotFoundError as e: raise RuntimeError( - "Statistics file not found: {}".format(config['respdiff_stats'])) from e + "Statistics file not found: {}".format(config["respdiff_stats"]) + ) from e def create_template_files(directory: str, config: Dict[str, Any]): - if 'respdiff' in config: + if "respdiff" in config: create_respdiff_files(directory, config) - elif 'resperf' in config: + elif "resperf" in config: create_resperf_files(directory, config) - elif 'distrotest' in config: + elif "distrotest" in config: create_distrotest_files(directory, config) -def get_test_case_list(nameglob: str = '') -> List[str]: +def get_test_case_list(nameglob: str = "") -> List[str]: # test cases end with '*.jXXX' implying a number of jobs (less -> longer runtime) # return them in ascending order, so more time consuming test cases run first - return sorted([ - os.path.splitext(os.path.basename(fname))[0] - for fname in glob.glob(os.path.join(TEST_CASE_DIR, '{}*.yaml'.format(nameglob)))], - key=lambda x: x.split('.')[-1]) # sort by job count + return sorted( + [ + os.path.splitext(os.path.basename(fname))[0] + for fname in glob.glob( + os.path.join(TEST_CASE_DIR, "{}*.yaml".format(nameglob)) + ) + ], + key=lambda x: x.split(".")[-1], + ) # sort by job count def create_jobs(args: argparse.Namespace) -> None: @@ -158,20 +170,20 @@ def create_jobs(args: argparse.Namespace) -> None: git_sha = args.sha_or_tag commit_dir = git_sha if args.label is not None: - if ' ' in args.label: - raise RuntimeError('Label may not contain spaces.') - commit_dir += '-' + args.label + if " " in args.label: + raise RuntimeError("Label may not contain 
spaces.") + commit_dir += "-" + args.label for test_case in test_cases: config = load_test_case_config(test_case) - config['git_sha'] = git_sha - config['knot_branch'] = args.knot_branch - config['verbose'] = args.verbose - config['asan'] = args.asan - config['log_keys'] = args.log_keys - config['respdiff_stats'] = args.respdiff_stats - config['obs_repo'] = args.obs_repo - config['package'] = args.package + config["git_sha"] = git_sha + config["knot_branch"] = args.knot_branch + config["verbose"] = args.verbose + config["asan"] = args.asan + config["log_keys"] = args.log_keys + config["respdiff_stats"] = args.respdiff_stats + config["obs_repo"] = args.obs_repo + config["package"] = args.package directory = os.path.join(args.jobs_dir, commit_dir, test_case) prepare_dir(directory, clean=args.clean) @@ -184,55 +196,83 @@ def create_jobs(args: argparse.Namespace) -> None: def main() -> None: parser = argparse.ArgumentParser( - description="Prepare files for docker respdiff job") + description="Prepare files for docker respdiff job" + ) parser.add_argument( - 'sha_or_tag', type=str, - help="Knot Resolver git commit or tag to use (don't use branch!)") + "sha_or_tag", + type=str, + help="Knot Resolver git commit or tag to use (don't use branch!)", + ) parser.add_argument( - '-a', '--all', default='shortlist', - help="Create all test cases which start with expression (default: shortlist)") + "-a", + "--all", + default="shortlist", + help="Create all test cases which start with expression (default: shortlist)", + ) parser.add_argument( - '-t', choices=get_test_case_list(), - help="Create only the specified test case") + "-t", choices=get_test_case_list(), help="Create only the specified test case" + ) parser.add_argument( - '-l', '--label', - help="Assign label for easier job identification and isolation") + "-l", "--label", help="Assign label for easier job identification and isolation" + ) parser.add_argument( - '--clean', action='store_true', - help="Remove target directory if it already exists (use with caution!)") + "--clean", + action="store_true", + help="Remove target directory if it already exists (use with caution!)", + ) parser.add_argument( - '--jobs-dir', default='/var/tmp/respdiff-jobs', - help="Directory with job collections (default: /var/tmp/respdiff-jobs)") + "--jobs-dir", + default="/var/tmp/respdiff-jobs", + help="Directory with job collections (default: /var/tmp/respdiff-jobs)", + ) parser.add_argument( - '--knot-branch', type=str, default='3.1', - help="Build knot-resolver against selected Knot DNS branch") + "--knot-branch", + type=str, + default="3.1", + help="Build knot-resolver against selected Knot DNS branch", + ) parser.add_argument( - '-v', '--verbose', action='store_true', - help="Capture verbose logs (kresd, unbound)") + "-v", + "--verbose", + action="store_true", + help="Capture verbose logs (kresd, unbound)", + ) parser.add_argument( - '--asan', action='store_true', - help="Build with Address Sanitizer") + "--asan", action="store_true", help="Build with Address Sanitizer" + ) parser.add_argument( - '--log-keys', action='store_true', - help="Log TLS session keys (kresd)") + "--log-keys", action="store_true", help="Log TLS session keys (kresd)" + ) parser.add_argument( - '--respdiff-stats', type=str, default='', - help=("Statistics file to generate extra respdiff report(s) with omitted " - "unstable/failing queries")) + "--respdiff-stats", + type=str, + default="", + help=( + "Statistics file to generate extra respdiff report(s) with omitted " + "unstable/failing 
queries" + ), + ) parser.add_argument( - '--obs-repo', type=str, default='knot-resolver-devel', - help=("OBS repository for distrotests (default: knot-resolver-devel)")) + "--obs-repo", + type=str, + default="knot-resolver-devel", + help=("OBS repository for distrotests (default: knot-resolver-devel)"), + ) parser.add_argument( - '--package', type=str, default='knot-resolver', - help=("package for distrotests (default: knot-resolver)")) + "--package", + type=str, + default="knot-resolver", + help=("package for distrotests (default: knot-resolver)"), + ) args = parser.parse_args() create_jobs(args) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig( - format='%(asctime)s %(levelname)8s %(message)s', level=logging.DEBUG) + format="%(asctime)s %(levelname)8s %(message)s", level=logging.DEBUG + ) try: main() except RuntimeError as exc: @@ -241,5 +281,5 @@ if __name__ == '__main__': sys.exit(1) except Exception as exc: logging.debug(traceback.format_exc()) - logging.critical('Unhandled code exception: %s', str(exc)) + logging.critical("Unhandled code exception: %s", str(exc)) sys.exit(2) diff --git a/contrib/job_manager/submit.py b/contrib/job_manager/submit.py index c9a75f7..44db1f5 100755 --- a/contrib/job_manager/submit.py +++ b/contrib/job_manager/submit.py @@ -19,9 +19,9 @@ JOB_STATUS_RUNNING = 2 def get_all_files(directory: str) -> List[str]: files = [] - for filename in glob.iglob('{}/**'.format(directory), recursive=True): + for filename in glob.iglob("{}/**".format(directory), recursive=True): # omit job artifacts (begins with j*) - if os.path.isfile(filename) and not os.path.basename(filename).startswith('j'): + if os.path.isfile(filename) and not os.path.basename(filename).startswith("j"): files.append(os.path.relpath(filename, directory)) return files @@ -30,61 +30,69 @@ def condor_submit(txn, priority: int) -> int: directory = os.getcwd() input_files = get_all_files(directory) - if 'run_respdiff.sh' in input_files: - executable = 'run_respdiff.sh' + if "run_respdiff.sh" in input_files: + executable = "run_respdiff.sh" output_files = [ - 'j$(Cluster).$(Process)_docker.txt', - 'j$(Cluster).$(Process)_report.json', - 'j$(Cluster).$(Process)_report.diffrepro.json', - 'j$(Cluster).$(Process)_report.txt', - 'j$(Cluster).$(Process)_report.diffrepro.txt', - 'j$(Cluster).$(Process)_histogram.tar.gz', - 'j$(Cluster).$(Process)_logs.tar.gz'] - if 'stats.json' in input_files: - output_files.extend([ - 'j$(Cluster).$(Process)_report.noref.json', - 'j$(Cluster).$(Process)_report.noref.txt', - 'j$(Cluster).$(Process)_report.diffrepro.noref.json', - 'j$(Cluster).$(Process)_report.diffrepro.noref.txt', - # 'j$(Cluster).$(Process)_dnsviz.json.gz', - # 'j$(Cluster).$(Process)_report.noref.dnsviz.json', - # 'j$(Cluster).$(Process)_report.noref.dnsviz.txt', - ]) - elif 'run_resperf.sh' in input_files: - executable = 'run_resperf.sh' + "j$(Cluster).$(Process)_docker.txt", + "j$(Cluster).$(Process)_report.json", + "j$(Cluster).$(Process)_report.diffrepro.json", + "j$(Cluster).$(Process)_report.txt", + "j$(Cluster).$(Process)_report.diffrepro.txt", + "j$(Cluster).$(Process)_histogram.tar.gz", + "j$(Cluster).$(Process)_logs.tar.gz", + ] + if "stats.json" in input_files: + output_files.extend( + [ + "j$(Cluster).$(Process)_report.noref.json", + "j$(Cluster).$(Process)_report.noref.txt", + "j$(Cluster).$(Process)_report.diffrepro.noref.json", + "j$(Cluster).$(Process)_report.diffrepro.noref.txt", + # 'j$(Cluster).$(Process)_dnsviz.json.gz', + # 
'j$(Cluster).$(Process)_report.noref.dnsviz.json', + # 'j$(Cluster).$(Process)_report.noref.dnsviz.txt', + ] + ) + elif "run_resperf.sh" in input_files: + executable = "run_resperf.sh" output_files = [ - 'j$(Cluster).$(Process)_exitcode', - 'j$(Cluster).$(Process)_docker.txt', - 'j$(Cluster).$(Process)_resperf.txt', - 'j$(Cluster).$(Process)_logs.tar.gz'] - elif 'run_distrotest.sh' in input_files: - executable = 'run_distrotest.sh' + "j$(Cluster).$(Process)_exitcode", + "j$(Cluster).$(Process)_docker.txt", + "j$(Cluster).$(Process)_resperf.txt", + "j$(Cluster).$(Process)_logs.tar.gz", + ] + elif "run_distrotest.sh" in input_files: + executable = "run_distrotest.sh" output_files = [ - 'j$(Cluster).$(Process)_exitcode', - 'j$(Cluster).$(Process)_vagrant.log.txt'] + "j$(Cluster).$(Process)_exitcode", + "j$(Cluster).$(Process)_vagrant.log.txt", + ] else: raise RuntimeError( "The provided directory doesn't look like a respdiff/resperf job. " - "{}/run_*.sh is missing!".format(directory)) + "{}/run_*.sh is missing!".format(directory) + ) # create batch name from dir structure commit_dir_path, test_case = os.path.split(directory) _, commit_dir = os.path.split(commit_dir_path) - batch_name = commit_dir + '_' + test_case - - submit = Submit({ - 'priority': str(priority), - 'executable': executable, - 'arguments': '$(Cluster) $(Process)', - 'error': 'j$(Cluster).$(Process)_stderr.txt', - 'output': 'j$(Cluster).$(Process)_stdout.txt', - 'log': 'j$(Cluster).$(Process)_log.txt', - 'jobbatchname': batch_name, - 'should_transfer_files': 'YES', - 'when_to_transfer_output': 'ON_EXIT', - 'transfer_input_files': ', '.join(input_files), - 'transfer_output_files': ', '.join(output_files), - }) + batch_name = commit_dir + "_" + test_case + + submit = Submit( + { + "priority": str(priority), + "executable": executable, + "arguments": "$(Cluster) $(Process)", + "error": "j$(Cluster).$(Process)_stderr.txt", + "output": "j$(Cluster).$(Process)_stdout.txt", + "log": "j$(Cluster).$(Process)_log.txt", + "jobbatchname": batch_name, + "should_transfer_files": "YES", + "when_to_transfer_output": "ON_EXIT", + "transfer_input_files": ", ".join(input_files), + "transfer_output_files": ", ".join(output_files), + } + ) return submit.queue(txn) @@ -107,12 +115,17 @@ def condor_wait_for(schedd, job_ids: Sequence[int]) -> None: break # log only status changes - if (remaining != prev_remaining or - running != prev_running or - worst_pos != prev_worst_pos): + if ( + remaining != prev_remaining + or running != prev_running + or worst_pos != prev_worst_pos + ): logging.info( " remaning: %2d (running: %2d) worst queue position: %2d", - remaining, running, worst_pos + 1) + remaining, + running, + worst_pos + 1, + ) prev_remaining = remaining prev_running = running @@ -122,19 +135,19 @@ def condor_wait_for(schedd, job_ids: Sequence[int]) -> None: def condor_check_status(schedd, job_ids: Sequence[int]) -> Tuple[int, int, int]: - all_jobs = schedd.query(True, ['JobPrio', 'ClusterId', 'ProcId', 'JobStatus']) + all_jobs = schedd.query(True, ["JobPrio", "ClusterId", "ProcId", "JobStatus"]) all_jobs = sorted( - all_jobs, - key=lambda x: (-x['JobPrio'], x['ClusterId'], x['ProcId'])) + all_jobs, key=lambda x: (-x["JobPrio"], x["ClusterId"], x["ProcId"]) + ) worst_pos = 0 running = 0 remaining = 0 for i, job in enumerate(all_jobs): - if int(job['ClusterId']) in job_ids: + if int(job["ClusterId"]) in job_ids: remaining += 1 - if int(job['JobStatus']) == JOB_STATUS_RUNNING: + if int(job["JobStatus"]) == JOB_STATUS_RUNNING: running += 1 worst_pos 
= i @@ -143,19 +156,31 @@ def condor_check_status(schedd, job_ids: Sequence[int]) -> Tuple[int, int, int]: def main() -> None: parser = argparse.ArgumentParser( - description="Submit prepared jobs to HTCondor cluster") + description="Submit prepared jobs to HTCondor cluster" + ) parser.add_argument( - '-c', '--count', type=int, default=1, - help="How many times to submit job (default: 1)") + "-c", + "--count", + type=int, + default=1, + help="How many times to submit job (default: 1)", + ) parser.add_argument( - '-p', '--priority', type=int, default=5, - help="Set condor job priority, higher means sooner execution (default: 5)") + "-p", + "--priority", + type=int, + default=5, + help="Set condor job priority, higher means sooner execution (default: 5)", + ) parser.add_argument( - '-w', '--wait', action='store_true', - help="Wait until all submitted jobs are finished") + "-w", + "--wait", + action="store_true", + help="Wait until all submitted jobs are finished", + ) parser.add_argument( - 'job_dir', nargs='+', - help="Path to the job directory to be submitted") + "job_dir", nargs="+", help="Path to the job directory to be submitted" + ) args = parser.parse_args() @@ -170,14 +195,15 @@ def main() -> None: job_ids[directory].append(condor_submit(txn, args.priority)) for directory, jobs in job_ids.items(): - logging.info("%s JOB ID(s): %s", directory, ', '.join(str(j) for j in jobs)) + logging.info("%s JOB ID(s): %s", directory, ", ".join(str(j) for j in jobs)) job_count = sum(len(jobs) for jobs in job_ids.values()) logging.info("%d job(s) successfully submitted!", job_count) if args.wait: logging.info( - 'WAITING for jobs to complete. This can be safely interrupted with Ctl+C...') + "WAITING for jobs to complete. This can be safely interrupted with Ctl+C..." + ) try: condor_wait_for(schedd, list(itertools.chain(*job_ids.values()))) except KeyboardInterrupt: @@ -186,16 +212,17 @@ def main() -> None: logging.info("All jobs done!") -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig( - format='%(asctime)s %(levelname)8s %(message)s', level=logging.DEBUG) + format="%(asctime)s %(levelname)8s %(message)s", level=logging.DEBUG + ) with warnings.catch_warnings(): warnings.simplefilter("error") # trigger UserWarning which causes ImportError try: from htcondor import Submit, Schedd except (ImportError, UserWarning): - logging.error('HTCondor not detected. Use this script on a submit machine.') + logging.error("HTCondor not detected. 
Use this script on a submit machine.") sys.exit(1) try: @@ -206,5 +233,5 @@ if __name__ == '__main__': sys.exit(1) except Exception as exc: logging.debug(traceback.format_exc()) - logging.critical('Unhandled code exception: %s', str(exc)) + logging.critical("Unhandled code exception: %s", str(exc)) sys.exit(2) diff --git a/diffrepro.py b/diffrepro.py index 92ef28a..127616e 100755 --- a/diffrepro.py +++ b/diffrepro.py @@ -10,25 +10,30 @@ from respdiff.dataformat import DiffReport, ReproData def main(): cli.setup_logging() parser = argparse.ArgumentParser( - description='attempt to reproduce original diffs from JSON report') + description="attempt to reproduce original diffs from JSON report" + ) cli.add_arg_envdir(parser) cli.add_arg_config(parser) cli.add_arg_datafile(parser) - parser.add_argument('--sequential', action='store_true', default=False, - help='send one query at a time (slower, but more reliable)') + parser.add_argument( + "--sequential", + action="store_true", + default=False, + help="send one query at a time (slower, but more reliable)", + ) args = parser.parse_args() sendrecv.module_init(args) datafile = cli.get_datafile(args) report = DiffReport.from_json(datafile) restart_scripts = repro.get_restart_scripts(args.cfg) - servers = args.cfg['servers']['names'] + servers = args.cfg["servers"]["names"] dnsreplies_factory = DNSRepliesFactory(servers) if args.sequential: nproc = 1 else: - nproc = args.cfg['sendrecv']['jobs'] + nproc = args.cfg["sendrecv"]["jobs"] if report.reprodata is None: report.reprodata = ReproData() @@ -40,12 +45,18 @@ def main(): dstream = repro.query_stream_from_disagreements(lmdb, report) try: repro.reproduce_queries( - dstream, report, dnsreplies_factory, args.cfg['diff']['criteria'], - args.cfg['diff']['target'], restart_scripts, nproc) + dstream, + report, + dnsreplies_factory, + args.cfg["diff"]["criteria"], + args.cfg["diff"]["target"], + restart_scripts, + nproc, + ) finally: # make sure data is saved in case of interrupt report.export_json(datafile) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/diffsum.py b/diffsum.py index 6d375aa..4f53c1f 100755 --- a/diffsum.py +++ b/diffsum.py @@ -4,8 +4,17 @@ import argparse import logging import sys from typing import ( # noqa - Any, Callable, Iterable, Iterator, ItemsView, List, Set, Sequence, Tuple, - Union) + Any, + Callable, + Iterable, + Iterator, + ItemsView, + List, + Set, + Sequence, + Tuple, + Union, +) import dns.message from respdiff import cli @@ -13,41 +22,61 @@ from respdiff.database import LMDB from respdiff.dataformat import DiffReport, Summary from respdiff.dnsviz import DnsvizGrok from respdiff.query import ( - convert_queries, get_printable_queries_format, get_query_iterator, qwire_to_msgid_qname_qtype) + convert_queries, + get_printable_queries_format, + get_query_iterator, + qwire_to_msgid_qname_qtype, +) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description='create a summary report from gathered data stored in LMDB ' - 'and JSON datafile') + description="create a summary report from gathered data stored in LMDB " + "and JSON datafile" + ) cli.add_arg_envdir(parser) cli.add_arg_config(parser) cli.add_arg_datafile(parser) cli.add_arg_limit(parser) - cli.add_arg_stats_filename(parser, default='') - cli.add_arg_dnsviz(parser, default='') - parser.add_argument('--without-dnsviz-errors', action='store_true', - help='omit domains that have any errors in DNSViz results') - parser.add_argument('--without-diffrepro', 
action='store_true', - help='omit reproducibility data from summary') - parser.add_argument('--without-ref-unstable', action='store_true', - help='omit unstable reference queries from summary') - parser.add_argument('--without-ref-failing', action='store_true', - help='omit failing reference queries from summary') + cli.add_arg_stats_filename(parser, default="") + cli.add_arg_dnsviz(parser, default="") + parser.add_argument( + "--without-dnsviz-errors", + action="store_true", + help="omit domains that have any errors in DNSViz results", + ) + parser.add_argument( + "--without-diffrepro", + action="store_true", + help="omit reproducibility data from summary", + ) + parser.add_argument( + "--without-ref-unstable", + action="store_true", + help="omit unstable reference queries from summary", + ) + parser.add_argument( + "--without-ref-failing", + action="store_true", + help="omit failing reference queries from summary", + ) return parser.parse_args() def check_args(args: argparse.Namespace, report: DiffReport): - if (args.without_ref_unstable or args.without_ref_failing) \ - and not args.stats_filename: + if ( + args.without_ref_unstable or args.without_ref_failing + ) and not args.stats_filename: logging.critical("Statistics file must be provided as a reference.") sys.exit(1) if not report.total_answers: - logging.error('No answers in DB!') + logging.error("No answers in DB!") sys.exit(1) if report.target_disagreements is None: - logging.error('JSON report is missing diff data! Did you forget to run msgdiff?') + logging.error( + "JSON report is missing diff data! Did you forget to run msgdiff?" + ) sys.exit(1) @@ -56,7 +85,7 @@ def main(): args = parse_args() datafile = cli.get_datafile(args) report = DiffReport.from_json(datafile) - field_weights = args.cfg['report']['field_weights'] + field_weights = args.cfg["report"]["field_weights"] check_args(args, report) @@ -74,16 +103,18 @@ def main(): report = DiffReport.from_json(datafile) report.summary = Summary.from_report( - report, field_weights, + report, + field_weights, without_diffrepro=args.without_diffrepro, - ignore_qids=ignore_qids) + ignore_qids=ignore_qids, + ) # dnsviz filter: by domain -> need to iterate over disagreements to get QIDs if args.without_dnsviz_errors: try: dnsviz_grok = DnsvizGrok.from_json(args.dnsviz) except (FileNotFoundError, RuntimeError) as exc: - logging.critical('Failed to load dnsviz data: %s', exc) + logging.critical("Failed to load dnsviz data: %s", exc) sys.exit(1) error_domains = dnsviz_grok.error_domains() @@ -93,14 +124,18 @@ def main(): for qid, wire in get_query_iterator(lmdb, report.summary.keys()): msg = dns.message.from_wire(wire) if msg.question: - if any(msg.question[0].name.is_subdomain(name) - for name in error_domains): + if any( + msg.question[0].name.is_subdomain(name) + for name in error_domains + ): ignore_qids.add(qid) report.summary = Summary.from_report( - report, field_weights, + report, + field_weights, without_diffrepro=args.without_diffrepro, - ignore_qids=ignore_qids) + ignore_qids=ignore_qids, + ) cli.print_global_stats(report) cli.print_differences_stats(report) @@ -111,7 +146,8 @@ def main(): for field in field_weights: if field in report.summary.field_labels: cli.print_field_mismatch_stats( - field, field_counters[field], len(report.summary)) + field, field_counters[field], len(report.summary) + ) # query details with LMDB(args.envdir, readonly=True) as lmdb: @@ -120,16 +156,18 @@ def main(): for field in field_weights: if field in report.summary.field_labels: for mismatch, 
qids in report.summary.get_field_mismatches(field): - queries = convert_queries(get_query_iterator(lmdb, qids), - qwire_to_msgid_qname_qtype) + queries = convert_queries( + get_query_iterator(lmdb, qids), qwire_to_msgid_qname_qtype + ) cli.print_mismatch_queries( field, mismatch, get_printable_queries_format(queries), - args.limit) + args.limit, + ) report.export_json(datafile) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/distrcmp.py b/distrcmp.py index 012360d..21e0dff 100755 --- a/distrcmp.py +++ b/distrcmp.py @@ -12,7 +12,9 @@ from respdiff.stats import Stats, SummaryStatistics DEFAULT_COEF = 2 -def belongs_to_distribution(ref: Optional[Stats], new: Optional[Stats], coef: float) -> bool: +def belongs_to_distribution( + ref: Optional[Stats], new: Optional[Stats], coef: float +) -> bool: if ref is None or new is None: return False median = statistics.median(new.samples) @@ -24,7 +26,9 @@ def belongs_to_distribution(ref: Optional[Stats], new: Optional[Stats], coef: fl def belongs_to_all(ref: SummaryStatistics, new: SummaryStatistics, coef: float) -> bool: - if not belongs_to_distribution(ref.target_disagreements, new.target_disagreements, coef): + if not belongs_to_distribution( + ref.target_disagreements, new.target_disagreements, coef + ): return False if not belongs_to_distribution(ref.upstream_unstable, new.upstream_unstable, coef): return False @@ -32,24 +36,38 @@ def belongs_to_all(ref: SummaryStatistics, new: SummaryStatistics, coef: float) return False if ref.fields is not None and new.fields is not None: for field in ref.fields: - if not belongs_to_distribution(ref.fields[field].total, new.fields[field].total, coef): + if not belongs_to_distribution( + ref.fields[field].total, new.fields[field].total, coef + ): return False return True def main(): cli.setup_logging() - parser = argparse.ArgumentParser(description='Check if new samples belong to reference ' - 'distribution. Ends with exitcode 0 if belong, ' - '1 if not,') - parser.add_argument('-r', '--reference', type=cli.read_stats, - help='json statistics file with reference data') + parser = argparse.ArgumentParser( + description="Check if new samples belong to reference " + "distribution. 
Ends with exitcode 0 if belong, " + "1 if not," + ) + parser.add_argument( + "-r", + "--reference", + type=cli.read_stats, + help="json statistics file with reference data", + ) cli.add_arg_report_filename(parser) - parser.add_argument('-c', '--coef', type=float, - default=DEFAULT_COEF, - help=('coeficient for comparation (new belongs to refference if ' - 'its median is closer than COEF * standart deviation of reference ' - 'from reference mean) (default: {})'.format(DEFAULT_COEF))) + parser.add_argument( + "-c", + "--coef", + type=float, + default=DEFAULT_COEF, + help=( + "coeficient for comparation (new belongs to refference if " + "its median is closer than COEF * standart deviation of reference " + "from reference mean) (default: {})".format(DEFAULT_COEF) + ), + ) args = parser.parse_args() reports = cli.get_reports_from_filenames(args) @@ -64,5 +82,5 @@ def main(): sys.exit(0) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/dnsviz.py b/dnsviz.py index a5658cd..337496e 100755 --- a/dnsviz.py +++ b/dnsviz.py @@ -11,17 +11,32 @@ import respdiff.dnsviz def main(): cli.setup_logging() parser = argparse.ArgumentParser( - description="use dnsviz to categorize domains (perfect, warnings, errors)") + description="use dnsviz to categorize domains (perfect, warnings, errors)" + ) cli.add_arg_config(parser) cli.add_arg_dnsviz(parser) - parser.add_argument('input', type=str, help='input file with domains (one qname per line)') + parser.add_argument( + "input", type=str, help="input file with domains (one qname per line)" + ) args = parser.parse_args() - njobs = args.cfg['sendrecv']['jobs'] + njobs = args.cfg["sendrecv"]["jobs"] try: - probe = subprocess.run([ - 'dnsviz', 'probe', '-A', '-R', respdiff.dnsviz.TYPES, '-f', - args.input, '-t', str(njobs)], check=True, stdout=subprocess.PIPE) + probe = subprocess.run( + [ + "dnsviz", + "probe", + "-A", + "-R", + respdiff.dnsviz.TYPES, + "-f", + args.input, + "-t", + str(njobs), + ], + check=True, + stdout=subprocess.PIPE, + ) except subprocess.CalledProcessError as exc: logging.critical("dnsviz probe failed: %s", exc) sys.exit(1) @@ -30,12 +45,16 @@ def main(): sys.exit(1) try: - subprocess.run(['dnsviz', 'grok', '-o', args.dnsviz], input=probe.stdout, - check=True, stdout=subprocess.PIPE) + subprocess.run( + ["dnsviz", "grok", "-o", args.dnsviz], + input=probe.stdout, + check=True, + stdout=subprocess.PIPE, + ) except subprocess.CalledProcessError as exc: logging.critical("dnsviz grok failed: %s", exc) sys.exit(1) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/histogram.py b/histogram.py index 5ec82c5..0bd1e0d 100755 --- a/histogram.py +++ b/histogram.py @@ -25,7 +25,8 @@ from respdiff.typing import ResolverID # Force matplotlib to use a different backend to handle machines without a display import matplotlib import matplotlib.ticker as mtick -matplotlib.use('Agg') + +matplotlib.use("Agg") import matplotlib.pyplot as plt # noqa @@ -33,9 +34,8 @@ HISTOGRAM_RCODE_MAX = 23 def load_data( - txn: lmdb.Transaction, - dnsreplies_factory: DNSRepliesFactory - ) -> Dict[ResolverID, List[Tuple[float, Optional[int]]]]: + txn: lmdb.Transaction, dnsreplies_factory: DNSRepliesFactory +) -> Dict[ResolverID, List[Tuple[float, Optional[int]]]]: data = {} # type: Dict[ResolverID, List[Tuple[float, Optional[int]]]] cursor = txn.cursor() for value in cursor.iternext(keys=False, values=True): @@ -45,17 +45,15 @@ def load_data( # 12 is chosen to be consistent with dnspython's ShortHeader exception rcode = None else: - 
(flags, ) = struct.unpack('!H', reply.wire[2:4]) - rcode = flags & 0x000f + (flags,) = struct.unpack("!H", reply.wire[2:4]) + rcode = flags & 0x000F data.setdefault(resolver, []).append((reply.time, rcode)) return data def plot_log_percentile_histogram( - data: Dict[str, List[float]], - title: str, - config=None - ) -> None: + data: Dict[str, List[float]], title: str, config=None +) -> None: """ For graph explanation, see https://blog.powerdns.com/2017/11/02/dns-performance-metrics-the-logarithmic-percentile-histogram/ @@ -66,55 +64,59 @@ def plot_log_percentile_histogram( # Distribute sample points along logarithmic X axis percentiles = np.logspace(-3, 2, num=100) - ax.set_xscale('log') - ax.xaxis.set_major_formatter(mtick.FormatStrFormatter('%s')) - ax.set_yscale('log') - ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%s')) + ax.set_xscale("log") + ax.xaxis.set_major_formatter(mtick.FormatStrFormatter("%s")) + ax.set_yscale("log") + ax.yaxis.set_major_formatter(mtick.FormatStrFormatter("%s")) - ax.grid(True, which='major') - ax.grid(True, which='minor', linestyle='dotted', color='#DDDDDD') + ax.grid(True, which="major") + ax.grid(True, which="minor", linestyle="dotted", color="#DDDDDD") - ax.set_xlabel('Slowest percentile') - ax.set_ylabel('Response time [ms]') - ax.set_title('Resolver Response Time' + " - " + title) + ax.set_xlabel("Slowest percentile") + ax.set_ylabel("Response time [ms]") + ax.set_title("Resolver Response Time" + " - " + title) # plot data for server in sorted(data): if data[server]: try: - color = config[server]['graph_color'] + color = config[server]["graph_color"] except (KeyError, TypeError): color = None # convert to ms and sort values = sorted([1000 * x for x in data[server]], reverse=True) - ax.plot(percentiles, - [values[math.ceil(pctl * len(values) / 100) - 1] for pctl in percentiles], lw=2, - label='{:<10}'.format(server) + " " + '{:9d}'.format(len(values)), color=color) + ax.plot( + percentiles, + [ + values[math.ceil(pctl * len(values) / 100) - 1] + for pctl in percentiles + ], + lw=2, + label="{:<10}".format(server) + " " + "{:9d}".format(len(values)), + color=color, + ) plt.legend() def create_histogram( - data: Dict[str, List[float]], - filename: str, - title: str, - config=None - ) -> None: + data: Dict[str, List[float]], filename: str, title: str, config=None +) -> None: # don't plot graphs which don't contain any finite time - if any(any(time < float('+inf') for time in d) for d in data.values()): + if any(any(time < float("+inf") for time in d) for d in data.values()): plot_log_percentile_histogram(data, title, config) # save to file plt.savefig(filename, dpi=300) def histogram_by_rcode( - data: Dict[ResolverID, List[Tuple[float, Optional[int]]]], - filename: str, - title: str, - config=None, - rcode: Optional[int] = None - ) -> None: + data: Dict[ResolverID, List[Tuple[float, Optional[int]]]], + filename: str, + title: str, + config=None, + rcode: Optional[int] = None, +) -> None: def same_rcode(value: Tuple[float, Optional[int]]) -> bool: if rcode is None: if value[1] is None: @@ -125,26 +127,43 @@ def histogram_by_rcode( filtered_by_rcode = { resolver: [time for (time, rc) in filter(same_rcode, values)] - for (resolver, values) in data.items()} + for (resolver, values) in data.items() + } create_histogram(filtered_by_rcode, filename, title, config) def main(): cli.setup_logging() parser = argparse.ArgumentParser( - description='Plot query response time histogram from answers stored ' - 'in LMDB') - parser.add_argument('-o', '--output', 
type=str, default='histogram', - help='output directory for image files (default: histogram)') - parser.add_argument('-f', '--format', type=str, default='png', - help='output image format (default: png)') - parser.add_argument('-c', '--config', default='respdiff.cfg', dest='cfgpath', - help='config file (default: respdiff.cfg)') - parser.add_argument('envdir', type=str, - help='LMDB environment to read answers from') + description="Plot query response time histogram from answers stored " "in LMDB" + ) + parser.add_argument( + "-o", + "--output", + type=str, + default="histogram", + help="output directory for image files (default: histogram)", + ) + parser.add_argument( + "-f", + "--format", + type=str, + default="png", + help="output image format (default: png)", + ) + parser.add_argument( + "-c", + "--config", + default="respdiff.cfg", + dest="cfgpath", + help="config file (default: respdiff.cfg)", + ) + parser.add_argument( + "envdir", type=str, help="LMDB environment to read answers from" + ) args = parser.parse_args() config = cfg.read_cfg(args.cfgpath) - servers = config['servers']['names'] + servers = config["servers"]["names"] dnsreplies_factory = DNSRepliesFactory(servers) with LMDB(args.envdir, readonly=True) as lmdb_: @@ -160,12 +179,16 @@ def main(): data = load_data(txn, dnsreplies_factory) def get_filepath(filename) -> str: - return os.path.join(args.output, filename + '.' + args.format) + return os.path.join(args.output, filename + "." + args.format) if not os.path.exists(args.output): os.makedirs(args.output) - create_histogram({k: [tup[0] for tup in d] for (k, d) in data.items()}, - get_filepath('all'), 'all', config) + create_histogram( + {k: [tup[0] for tup in d] for (k, d) in data.items()}, + get_filepath("all"), + "all", + config, + ) # rcode-specific queries with pool.Pool() as p: @@ -175,9 +198,9 @@ def main(): filepath = get_filepath(rcode_text) fargs.append((data, filepath, rcode_text, config, rcode)) p.starmap(histogram_by_rcode, fargs) - filepath = get_filepath('unparsed') - histogram_by_rcode(data, filepath, 'unparsed queries', config, None) + filepath = get_filepath("unparsed") + histogram_by_rcode(data, filepath, "unparsed queries", config, None) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/msgdiff.py b/msgdiff.py index 10d7b68..43d95a2 100755 --- a/msgdiff.py +++ b/msgdiff.py @@ -10,7 +10,12 @@ from typing import Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple # no from respdiff import cli from respdiff.dataformat import ( - DiffReport, Disagreements, DisagreementsCounter, FieldLabel, QID) + DiffReport, + Disagreements, + DisagreementsCounter, + FieldLabel, + QID, +) from respdiff.database import DNSRepliesFactory, DNSReply, key2qid, LMDB, MetaDatabase from respdiff.match import compare from respdiff.typing import ResolverID @@ -20,9 +25,8 @@ lmdb = None def read_answers_lmdb( - dnsreplies_factory: DNSRepliesFactory, - qid: QID - ) -> Mapping[ResolverID, DNSReply]: + dnsreplies_factory: DNSRepliesFactory, qid: QID +) -> Mapping[ResolverID, DNSReply]: assert lmdb is not None, "LMDB wasn't initialized!" adb = lmdb.get_db(LMDB.ANSWERS) with lmdb.env.begin(adb) as txn: @@ -32,11 +36,11 @@ def read_answers_lmdb( def compare_lmdb_wrapper( - criteria: Sequence[FieldLabel], - target: ResolverID, - dnsreplies_factory: DNSRepliesFactory, - qid: QID - ) -> None: + criteria: Sequence[FieldLabel], + target: ResolverID, + dnsreplies_factory: DNSRepliesFactory, + qid: QID, +) -> None: assert lmdb is not None, "LMDB wasn't initialized!" 
answers = read_answers_lmdb(dnsreplies_factory, qid) others_agree, target_diffs = compare(answers, criteria, target) @@ -69,11 +73,13 @@ def export_json(filename: str, report: DiffReport): # NOTE: msgdiff is the first tool in the toolchain to generate report.json # thus it doesn't make sense to re-use existing report.json file if os.path.exists(filename): - backup_filename = filename + '.bak' + backup_filename = filename + ".bak" os.rename(filename, backup_filename) logging.warning( - 'JSON report already exists, overwriting file. Original ' - 'file backed up as %s', backup_filename) + "JSON report already exists, overwriting file. Original " + "file backed up as %s", + backup_filename, + ) report.export_json(filename) @@ -81,18 +87,14 @@ def prepare_report(lmdb_, servers: Sequence[ResolverID]) -> DiffReport: qdb = lmdb_.open_db(LMDB.QUERIES) adb = lmdb_.open_db(LMDB.ANSWERS) with lmdb_.env.begin() as txn: - total_queries = txn.stat(qdb)['entries'] - total_answers = txn.stat(adb)['entries'] + total_queries = txn.stat(qdb)["entries"] + total_answers = txn.stat(adb)["entries"] meta = MetaDatabase(lmdb_, servers) start_time = meta.read_start_time() end_time = meta.read_end_time() - return DiffReport( - start_time, - end_time, - total_queries, - total_answers) + return DiffReport(start_time, end_time, total_queries, total_answers) def main(): @@ -100,16 +102,17 @@ def main(): cli.setup_logging() parser = argparse.ArgumentParser( - description='compute diff from answers stored in LMDB and write diffs to LMDB') + description="compute diff from answers stored in LMDB and write diffs to LMDB" + ) cli.add_arg_envdir(parser) cli.add_arg_config(parser) cli.add_arg_datafile(parser) args = parser.parse_args() datafile = cli.get_datafile(args, check_exists=False) - criteria = args.cfg['diff']['criteria'] - target = args.cfg['diff']['target'] - servers = args.cfg['servers']['names'] + criteria = args.cfg["diff"]["criteria"] + target = args.cfg["diff"]["target"] + servers = args.cfg["servers"]["names"] with LMDB(args.envdir) as lmdb_: # NOTE: To avoid an lmdb.BadRslotError, probably caused by weird @@ -126,12 +129,13 @@ def main(): dnsreplies_factory = DNSRepliesFactory(servers) compare_func = partial( - compare_lmdb_wrapper, criteria, target, dnsreplies_factory) + compare_lmdb_wrapper, criteria, target, dnsreplies_factory + ) with pool.Pool() as p: for _ in p.imap_unordered(compare_func, qid_stream, chunksize=10): pass export_json(datafile, report) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/orchestrator.py b/orchestrator.py index bbbfb4b..3ee3ee1 100755 --- a/orchestrator.py +++ b/orchestrator.py @@ -12,18 +12,22 @@ from respdiff.database import LMDB, MetaDatabase def main(): cli.setup_logging() parser = argparse.ArgumentParser( - description='read queries from LMDB, send them in parallel to servers ' - 'listed in configuration file, and record answers into LMDB') + description="read queries from LMDB, send them in parallel to servers " + "listed in configuration file, and record answers into LMDB" + ) cli.add_arg_envdir(parser) cli.add_arg_config(parser) - parser.add_argument('--ignore-timeout', action="store_true", - help='continue despite consecutive timeouts from resolvers') + parser.add_argument( + "--ignore-timeout", + action="store_true", + help="continue despite consecutive timeouts from resolvers", + ) args = parser.parse_args() sendrecv.module_init(args) with LMDB(args.envdir) as lmdb: - meta = MetaDatabase(lmdb, args.cfg['servers']['names'], create=True) + meta 
= MetaDatabase(lmdb, args.cfg["servers"]["names"], create=True) meta.write_version() meta.write_start_time() @@ -35,24 +39,25 @@ def main(): try: # process queries in parallel with pool.Pool( - processes=args.cfg['sendrecv']['jobs'], - initializer=sendrecv.worker_init) as p: + processes=args.cfg["sendrecv"]["jobs"], initializer=sendrecv.worker_init + ) as p: i = 0 for qkey, blob in p.imap_unordered( - sendrecv.worker_perform_query, qstream, chunksize=100): + sendrecv.worker_perform_query, qstream, chunksize=100 + ): i += 1 if i % 10000 == 0: - logging.info('Received {:d} answers'.format(i)) + logging.info("Received {:d} answers".format(i)) txn.put(qkey, blob) except KeyboardInterrupt: - logging.info('SIGINT received, exiting...') + logging.info("SIGINT received, exiting...") sys.exit(130) except RuntimeError as err: logging.error(err) sys.exit(1) finally: # attempt to preserve data if something went wrong (or not) - logging.debug('Comitting LMDB transaction...') + logging.debug("Comitting LMDB transaction...") txn.commit() meta.write_end_time() diff --git a/qexport.py b/qexport.py index 0198be6..85d8589 100755 --- a/qexport.py +++ b/qexport.py @@ -14,28 +14,29 @@ from respdiff.typing import QID def get_qids_to_export( - args: argparse.Namespace, - reports: Sequence[DiffReport] - ) -> Set[QID]: + args: argparse.Namespace, reports: Sequence[DiffReport] +) -> Set[QID]: qids = set() # type: Set[QID] for report in reports: if args.failing: if report.summary is None: raise ValueError( - "Report {} is missing summary!".format(report.fileorigin)) + "Report {} is missing summary!".format(report.fileorigin) + ) failing_qids = set(report.summary.keys()) qids.update(failing_qids) if args.unstable: if report.other_disagreements is None: raise ValueError( - "Report {} is missing other disagreements!".format(report.fileorigin)) + "Report {} is missing other disagreements!".format( + report.fileorigin + ) + ) unstable_qids = report.other_disagreements.queries qids.update(unstable_qids) if args.qidlist: - with open(args.qidlist, encoding='UTF-8') as qidlist_file: - qids.update(int(qid.strip()) - for qid in qidlist_file - if qid.strip()) + with open(args.qidlist, encoding="UTF-8") as qidlist_file: + qids.update(int(qid.strip()) for qid in qidlist_file if qid.strip()) return qids @@ -49,7 +50,7 @@ def export_qids_to_qname_qtype(qids: Set[QID], lmdb, file=sys.stdout): try: query = qwire_to_qname_qtype(qwire) except ValueError as exc: - logging.debug('Omitting QID %d from export: %s', qid, exc) + logging.debug("Omitting QID %d from export: %s", qid, exc) else: print(query, file=file) @@ -60,7 +61,7 @@ def export_qids_to_qname(qids: Set[QID], lmdb, file=sys.stdout): try: qname = qwire_to_qname(qwire) except ValueError as exc: - logging.debug('Omitting QID %d from export: %s', qid, exc) + logging.debug("Omitting QID %d from export: %s", qid, exc) else: if qname not in domains: print(qname, file=file) @@ -71,36 +72,51 @@ def export_qids_to_base64url(qids: Set[QID], lmdb, file=sys.stdout): wires = set() # type: Set[bytes] for _, qwire in get_query_iterator(lmdb, qids): if qwire not in wires: - print(base64.urlsafe_b64encode(qwire).decode('ascii'), file=file) + print(base64.urlsafe_b64encode(qwire).decode("ascii"), file=file) wires.add(qwire) def main(): cli.setup_logging() - parser = argparse.ArgumentParser(description="export queries from reports' summaries") - cli.add_arg_report_filename(parser, nargs='+') - parser.add_argument('--envdir', type=str, - help="LMDB environment (required when output format isn't 
'qid')") - parser.add_argument('-f', '--format', type=str, choices=['query', 'qid', 'domain', 'base64url'], - default='domain', help="output data format") - parser.add_argument('-o', '--output', type=str, help='output file') - parser.add_argument('--failing', action='store_true', help="get target disagreements") - parser.add_argument('--unstable', action='store_true', help="get upstream unstable") - parser.add_argument('--qidlist', type=str, help='path to file with list of QIDs to export') + parser = argparse.ArgumentParser( + description="export queries from reports' summaries" + ) + cli.add_arg_report_filename(parser, nargs="+") + parser.add_argument( + "--envdir", + type=str, + help="LMDB environment (required when output format isn't 'qid')", + ) + parser.add_argument( + "-f", + "--format", + type=str, + choices=["query", "qid", "domain", "base64url"], + default="domain", + help="output data format", + ) + parser.add_argument("-o", "--output", type=str, help="output file") + parser.add_argument( + "--failing", action="store_true", help="get target disagreements" + ) + parser.add_argument("--unstable", action="store_true", help="get upstream unstable") + parser.add_argument( + "--qidlist", type=str, help="path to file with list of QIDs to export" + ) args = parser.parse_args() - if args.format != 'qid' and not args.envdir: + if args.format != "qid" and not args.envdir: logging.critical("--envdir required when output format isn't 'qid'") sys.exit(1) if not args.failing and not args.unstable and not args.qidlist: - logging.critical('No filter selected!') + logging.critical("No filter selected!") sys.exit(1) reports = cli.get_reports_from_filenames(args) if not reports: - logging.critical('No reports found!') + logging.critical("No reports found!") sys.exit(1) try: @@ -110,21 +126,21 @@ def main(): sys.exit(1) with cli.smart_open(args.output) as fh: - if args.format == 'qid': + if args.format == "qid": export_qids(qids, fh) else: with LMDB(args.envdir, readonly=True) as lmdb: lmdb.open_db(LMDB.QUERIES) - if args.format == 'query': + if args.format == "query": export_qids_to_qname_qtype(qids, lmdb, fh) - elif args.format == 'domain': + elif args.format == "domain": export_qids_to_qname(qids, lmdb, fh) - elif args.format == 'base64url': + elif args.format == "base64url": export_qids_to_base64url(qids, lmdb, fh) else: - raise ValueError('unsupported output format') + raise ValueError("unsupported output format") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/qprep.py b/qprep.py index bedd7bf..6f94c97 100755 --- a/qprep.py +++ b/qprep.py @@ -26,7 +26,7 @@ def read_lines(instream): i = 1 for line in instream: if i % REPORT_CHUNKS == 0: - logging.info('Read %d lines', i) + logging.info("Read %d lines", i) line = line.strip() if line: yield (i, line, line) @@ -62,14 +62,14 @@ def parse_pcap(pcap_file): pcap_file = dpkt.pcap.Reader(pcap_file) for _, frame in pcap_file: if i % REPORT_CHUNKS == 0: - logging.info('Read %d frames', i) - yield (i, frame, 'frame no. {}'.format(i)) + logging.info("Read %d frames", i) + yield (i, frame, "frame no. 
{}".format(i)) i += 1 def wrk_process_line( - args: Tuple[int, str, str] - ) -> Tuple[Optional[int], Optional[bytes]]: + args: Tuple[int, str, str] +) -> Tuple[Optional[int], Optional[bytes]]: """ Worker: parse input line, creates a packet in binary format @@ -80,16 +80,19 @@ def wrk_process_line( try: msg = msg_from_text(line) if blacklist.is_blacklisted(msg): - logging.debug('Blacklisted query "%s", skipping QID %d', - log_repr, qid) + logging.debug('Blacklisted query "%s", skipping QID %d', log_repr, qid) return None, None return qid, msg.to_wire() except (ValueError, struct.error, dns.exception.DNSException) as ex: - logging.error('Invalid query specification "%s": %s, skipping QID %d', line, ex, qid) + logging.error( + 'Invalid query specification "%s": %s, skipping QID %d', line, ex, qid + ) return None, None -def wrk_process_frame(args: Tuple[int, bytes, str]) -> Tuple[Optional[int], Optional[bytes]]: +def wrk_process_frame( + args: Tuple[int, bytes, str] +) -> Tuple[Optional[int], Optional[bytes]]: """ Worker: convert packet from pcap to binary data """ @@ -99,10 +102,8 @@ def wrk_process_frame(args: Tuple[int, bytes, str]) -> Tuple[Optional[int], Opti def wrk_process_wire_packet( - qid: int, - wire_packet: bytes, - log_repr: str - ) -> Tuple[Optional[int], Optional[bytes]]: + qid: int, wire_packet: bytes, log_repr: str +) -> Tuple[Optional[int], Optional[bytes]]: """ Worker: Return packet's data if it's not blacklisted @@ -117,8 +118,7 @@ def wrk_process_wire_packet( pass else: if blacklist.is_blacklisted(msg): - logging.debug('Blacklisted query "%s", skipping QID %d', - log_repr, qid) + logging.debug('Blacklisted query "%s", skipping QID %d', log_repr, qid) return None, None return qid, wire_packet @@ -140,11 +140,14 @@ def msg_from_text(text): try: qname, qtype = text.split() except ValueError as e: - raise ValueError('space is only allowed as separator between qname qtype') from e - qname = dns.name.from_text(qname.encode('ascii')) + raise ValueError( + "space is only allowed as separator between qname qtype" + ) from e + qname = dns.name.from_text(qname.encode("ascii")) qtype = int_or_fromtext(qtype, dns.rdatatype.from_text) - msg = dns.message.make_query(qname, qtype, dns.rdataclass.IN, - want_dnssec=True, payload=4096) + msg = dns.message.make_query( + qname, qtype, dns.rdataclass.IN, want_dnssec=True, payload=4096 + ) return msg @@ -152,22 +155,31 @@ def main(): cli.setup_logging() parser = argparse.ArgumentParser( formatter_class=argparse.RawTextHelpFormatter, - description='Convert queries data from standard input and store ' - 'wire format into LMDB "queries" DB.') + description="Convert queries data from standard input and store " + 'wire format into LMDB "queries" DB.', + ) cli.add_arg_envdir(parser) - parser.add_argument('-f', '--in-format', type=str, choices=['text', 'pcap'], default='text', - help='define format for input data, default value is text\n' - 'Expected input for "text" is: "<qname> <RR type>", ' - 'one query per line.\n' - 'Expected input for "pcap" is content of the pcap file.') - parser.add_argument('--pcap-file', type=argparse.FileType('rb')) + parser.add_argument( + "-f", + "--in-format", + type=str, + choices=["text", "pcap"], + default="text", + help="define format for input data, default value is text\n" + 'Expected input for "text" is: "<qname> <RR type>", ' + "one query per line.\n" + 'Expected input for "pcap" is content of the pcap file.', + ) + parser.add_argument("--pcap-file", type=argparse.FileType("rb")) args = parser.parse_args() - if 
args.in_format == 'text' and args.pcap_file: - logging.critical("Argument --pcap-file can be use only in combination with -f pcap") + if args.in_format == "text" and args.pcap_file: + logging.critical( + "Argument --pcap-file can be use only in combination with -f pcap" + ) sys.exit(1) - if args.in_format == 'pcap' and not args.pcap_file: + if args.in_format == "pcap" and not args.pcap_file: logging.critical("Missing path to pcap file, use argument --pcap-file") sys.exit(1) @@ -176,12 +188,12 @@ def main(): txn = lmdb.env.begin(qdb, write=True) try: with pool.Pool( - initializer=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN) - ) as workers: - if args.in_format == 'text': + initializer=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN) + ) as workers: + if args.in_format == "text": data_stream = read_lines(sys.stdin) method = wrk_process_line - elif args.in_format == 'pcap': + elif args.in_format == "pcap": data_stream = parse_pcap(args.pcap_file) method = wrk_process_frame else: @@ -192,16 +204,16 @@ def main(): key = qid2key(qid) txn.put(key, wire) except KeyboardInterrupt: - logging.info('SIGINT received, exiting...') + logging.info("SIGINT received, exiting...") sys.exit(130) except RuntimeError as err: logging.error(err) sys.exit(1) finally: # attempt to preserve data if something went wrong (or not) - logging.debug('Comitting LMDB transaction...') + logging.debug("Comitting LMDB transaction...") txn.commit() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/respdiff/__init__.py b/respdiff/__init__.py index 65432d4..797e22d 100644 --- a/respdiff/__init__.py +++ b/respdiff/__init__.py @@ -2,7 +2,7 @@ import sys # LMDB isn't portable between LE/BE platforms # upon import, check we're on a little endian platform -assert sys.byteorder == 'little', 'Big endian platforms are not supported' +assert sys.byteorder == "little", "Big endian platforms are not supported" # Check minimal Python version -assert sys.version_info >= (3, 5, 2), 'Minimal supported Python version is 3.5.2' +assert sys.version_info >= (3, 5, 2), "Minimal supported Python version is 3.5.2" diff --git a/respdiff/blacklist.py b/respdiff/blacklist.py index e050c56..a2c757c 100644 --- a/respdiff/blacklist.py +++ b/respdiff/blacklist.py @@ -2,19 +2,37 @@ import dns from dns.message import Message -_BLACKLIST_SUBDOMAINS = [dns.name.from_text(name) for name in [ - 'local.', - # dotnxdomain.net and dashnxdomain.net are used by APNIC for ephemeral - # single-query tests so there is no point in asking these repeatedly - 'dotnxdomain.net.', 'dashnxdomain.net.', - # Some malware names from shadowserver.org - # Asking them repeatedly may give impression that we're infected. - 'ak-is.net', 'boceureid.com', 'dkfdbbtmeeow.net', 'dssjistyyqhs.online', - 'gowwmcawjbuwrvrwc.com', 'hhjxobtybumg.pw', 'iqodqxy.info', 'k5service.org', - 'kibot.pw', 'mdtrlvfjppreyv.in', 'oxekbtumwc.info', 'raz4756pi7zop.com', - 'rbqfe.net', 'theyardservice.com', 'uwnlzvl.info', 'weg11uvrym3.com', - 'winsrw.com', 'xfwnbk.info', 'xmrimtl.net', -]] +_BLACKLIST_SUBDOMAINS = [ + dns.name.from_text(name) + for name in [ + "local.", + # dotnxdomain.net and dashnxdomain.net are used by APNIC for ephemeral + # single-query tests so there is no point in asking these repeatedly + "dotnxdomain.net.", + "dashnxdomain.net.", + # Some malware names from shadowserver.org + # Asking them repeatedly may give impression that we're infected. 
+ "ak-is.net", + "boceureid.com", + "dkfdbbtmeeow.net", + "dssjistyyqhs.online", + "gowwmcawjbuwrvrwc.com", + "hhjxobtybumg.pw", + "iqodqxy.info", + "k5service.org", + "kibot.pw", + "mdtrlvfjppreyv.in", + "oxekbtumwc.info", + "raz4756pi7zop.com", + "rbqfe.net", + "theyardservice.com", + "uwnlzvl.info", + "weg11uvrym3.com", + "winsrw.com", + "xfwnbk.info", + "xmrimtl.net", + ] +] def is_blacklisted(dnsmsg: Message) -> bool: @@ -23,7 +41,7 @@ def is_blacklisted(dnsmsg: Message) -> bool: """ try: flags = dns.flags.to_text(dnsmsg.flags).split() - if 'QR' in flags: # not a query + if "QR" in flags: # not a query return True if len(dnsmsg.question) != 1: # weird but valid packet (maybe DNS Cookies) @@ -32,8 +50,7 @@ def is_blacklisted(dnsmsg: Message) -> bool: # there is not standard describing common behavior for ANY/RRSIG query if question.rdtype in {dns.rdatatype.ANY, dns.rdatatype.RRSIG}: return True - return any(question.name.is_subdomain(name) - for name in _BLACKLIST_SUBDOMAINS) + return any(question.name.is_subdomain(name) for name in _BLACKLIST_SUBDOMAINS) except Exception: # weird stuff, it's better to test resolver with this as well! return False diff --git a/respdiff/cfg.py b/respdiff/cfg.py index 60bb050..87acab8 100644 --- a/respdiff/cfg.py +++ b/respdiff/cfg.py @@ -12,8 +12,20 @@ import dns.inet ALL_FIELDS = [ - 'timeout', 'malformed', 'opcode', 'question', 'rcode', 'flags', 'answertypes', - 'answerrrsigs', 'answer', 'authority', 'additional', 'edns', 'nsid'] + "timeout", + "malformed", + "opcode", + "question", + "rcode", + "flags", + "answertypes", + "answerrrsigs", + "answer", + "authority", + "additional", + "edns", + "nsid", +] ALL_FIELDS_SET = set(ALL_FIELDS) @@ -30,45 +42,45 @@ def comma_list(lstr): """ Split string 'a, b' into list [a, b] """ - return [name.strip() for name in lstr.split(',')] + return [name.strip() for name in lstr.split(",")] def transport_opt(ostr): - if ostr not in {'udp', 'tcp', 'tls'}: - raise ValueError('unsupported transport') + if ostr not in {"udp", "tcp", "tls"}: + raise ValueError("unsupported transport") return ostr # declarative config format description for always-present sections # dict structure: dict[section name][key name] = (type, required) _CFGFMT = { - 'sendrecv': { - 'timeout': (float, True), - 'jobs': (int, True), - 'time_delay_min': (float, True), - 'time_delay_max': (float, True), - 'max_timeouts': (int, False), + "sendrecv": { + "timeout": (float, True), + "jobs": (int, True), + "time_delay_min": (float, True), + "time_delay_max": (float, True), + "max_timeouts": (int, False), }, - 'servers': { - 'names': (comma_list, True), + "servers": { + "names": (comma_list, True), }, - 'diff': { - 'target': (str, True), - 'criteria': (comma_list, True), + "diff": { + "target": (str, True), + "criteria": (comma_list, True), }, - 'report': { - 'field_weights': (comma_list, True), + "report": { + "field_weights": (comma_list, True), }, } # declarative config format description for per-server section # dict structure: dict[key name] = type _CFGFMT_SERVER = { - 'ip': (ipaddr_check, True), - 'port': (int, True), - 'transport': (transport_opt, True), - 'graph_color': (str, False), - 'restart_script': (str, False), + "ip": (ipaddr_check, True), + "port": (int, True), + "transport": (transport_opt, True), + "graph_color": (str, False), + "restart_script": (str, False), } @@ -87,20 +99,25 @@ def cfg2dict_convert(fmt, cparser): for valname, (valfmt, valreq) in sectfmt.items(): try: if not cparser[sectname][valname].strip(): - raise ValueError('empty 
values are not allowed') + raise ValueError("empty values are not allowed") sectdict[valname] = valfmt(cparser[sectname][valname]) except ValueError as ex: - raise ValueError('config section [{}] key "{}" has invalid format: ' - '{}; expected format: {}'.format( - sectname, valname, ex, valfmt)) from ex + raise ValueError( + 'config section [{}] key "{}" has invalid format: ' + "{}; expected format: {}".format(sectname, valname, ex, valfmt) + ) from ex except KeyError as ex: if valreq: - raise KeyError('config section [{}] key "{}" not found'.format( - sectname, valname)) from ex + raise KeyError( + 'config section [{}] key "{}" not found'.format( + sectname, valname + ) + ) from ex unsupported_keys = set(cparser[sectname].keys()) - set(sectfmt.keys()) if unsupported_keys: - raise ValueError('unexpected keys {} in section [{}]'.format( - unsupported_keys, sectname)) + raise ValueError( + "unexpected keys {} in section [{}]".format(unsupported_keys, sectname) + ) return cdict @@ -109,38 +126,53 @@ def cfg2dict_check_sect(fmt, cfg): Check non-existence of unhandled config sections. """ supported_sections = set(fmt.keys()) - present_sections = set(cfg.keys()) - {'DEFAULT'} + present_sections = set(cfg.keys()) - {"DEFAULT"} unsupported_sections = present_sections - supported_sections if unsupported_sections: - raise ValueError('unexpected config sections {}'.format( - ', '.join('[{}]'.format(sn) for sn in unsupported_sections))) + raise ValueError( + "unexpected config sections {}".format( + ", ".join("[{}]".format(sn) for sn in unsupported_sections) + ) + ) def cfg2dict_check_diff(cdict): """ Check if diff target is listed among servers. """ - if cdict['diff']['target'] not in cdict['servers']['names']: - raise ValueError('[diff] target value "{}" must be listed in [servers] names'.format( - cdict['diff']['target'])) + if cdict["diff"]["target"] not in cdict["servers"]["names"]: + raise ValueError( + '[diff] target value "{}" must be listed in [servers] names'.format( + cdict["diff"]["target"] + ) + ) def cfg2dict_check_fields(cdict): """Check if all fields are known and that all have a weight assigned""" - unknown_criteria = set(cdict['diff']['criteria']) - ALL_FIELDS_SET + unknown_criteria = set(cdict["diff"]["criteria"]) - ALL_FIELDS_SET if unknown_criteria: - raise ValueError('[diff] criteria: unknown fields: {}'.format( - ', '.join(['"{}"'.format(field) for field in unknown_criteria]))) + raise ValueError( + "[diff] criteria: unknown fields: {}".format( + ", ".join(['"{}"'.format(field) for field in unknown_criteria]) + ) + ) - unknown_field_weights = set(cdict['report']['field_weights']) - ALL_FIELDS_SET + unknown_field_weights = set(cdict["report"]["field_weights"]) - ALL_FIELDS_SET if unknown_field_weights: - raise ValueError('[report] field_weights: unknown fields: {}'.format( - ', '.join(['"{}"'.format(field) for field in unknown_field_weights]))) + raise ValueError( + "[report] field_weights: unknown fields: {}".format( + ", ".join(['"{}"'.format(field) for field in unknown_field_weights]) + ) + ) - missing_field_weights = ALL_FIELDS_SET - set(cdict['report']['field_weights']) + missing_field_weights = ALL_FIELDS_SET - set(cdict["report"]["field_weights"]) if missing_field_weights: - raise ValueError('[report] field_weights: missing fields: {}'.format( - ', '.join(['"{}"'.format(field) for field in missing_field_weights]))) + raise ValueError( + "[report] field_weights: missing fields: {}".format( + ", ".join(['"{}"'.format(field) for field in missing_field_weights]) + ) + ) def 
read_cfg(filename): @@ -155,10 +187,11 @@ def read_cfg(filename): try: parser = configparser.ConfigParser( - delimiters='=', - comment_prefixes='#', + delimiters="=", + comment_prefixes="#", interpolation=None, - empty_lines_in_values=False) + empty_lines_in_values=False, + ) parser.read(filename) # parse things which must be present @@ -166,7 +199,7 @@ def read_cfg(filename): # parse variable server-specific data cfgfmt_servers = _CFGFMT.copy() - for server in cdict['servers']['names']: + for server in cdict["servers"]["names"]: cfgfmt_servers[server] = _CFGFMT_SERVER cdict = cfg2dict_convert(cfgfmt_servers, parser) @@ -177,13 +210,13 @@ def read_cfg(filename): # check fields (criteria, field_weights) cfg2dict_check_fields(cdict) except Exception as exc: - logging.critical('Failed to parse config: %s', exc) + logging.critical("Failed to parse config: %s", exc) raise ValueError(exc) from exc return cdict -if __name__ == '__main__': +if __name__ == "__main__": from pprint import pprint import sys diff --git a/respdiff/cli.py b/respdiff/cli.py index b9e7ed0..6e0a098 100644 --- a/respdiff/cli.py +++ b/respdiff/cli.py @@ -20,10 +20,10 @@ ChangeStatsTuple = Tuple[int, Optional[float], Optional[int], Optional[float]] ChangeStatsTupleStr = Tuple[int, Optional[float], Optional[str], Optional[float]] LOGGING_LEVEL = logging.DEBUG -CONFIG_FILENAME = 'respdiff.cfg' -REPORT_FILENAME = 'report.json' -STATS_FILENAME = 'stats.json' -DNSVIZ_FILENAME = 'dnsviz.json' +CONFIG_FILENAME = "respdiff.cfg" +REPORT_FILENAME = "report.json" +STATS_FILENAME = "stats.json" +DNSVIZ_FILENAME = "dnsviz.json" DEFAULT_PRINT_QUERY_LIMIT = 10 @@ -36,7 +36,7 @@ def read_stats(filename: str) -> SummaryStatistics: def _handle_empty_report(exc: Exception, skip_empty: bool): if skip_empty: - logging.debug('%s Omitting...', exc) + logging.debug("%s Omitting...", exc) else: logging.error(str(exc)) raise ValueError(exc) @@ -51,77 +51,107 @@ def read_report(filename: str, skip_empty: bool = False) -> Optional[DiffReport] def load_summaries( - reports: Sequence[DiffReport], - skip_empty: bool = False - ) -> Sequence[Summary]: + reports: Sequence[DiffReport], skip_empty: bool = False +) -> Sequence[Summary]: summaries = [] for report in reports: if report.summary is None: _handle_empty_report( ValueError('Empty diffsum in "{}"!'.format(report.fileorigin)), - skip_empty) + skip_empty, + ) else: summaries.append(report.summary) return summaries def setup_logging(level: int = LOGGING_LEVEL) -> None: - logging.basicConfig(format='%(asctime)s %(levelname)8s %(message)s', level=level) - logger = logging.getLogger('matplotlib') + logging.basicConfig(format="%(asctime)s %(levelname)8s %(message)s", level=level) + logger = logging.getLogger("matplotlib") # set WARNING for Matplotlib logger.setLevel(logging.WARNING) def add_arg_config(parser: ArgumentParser) -> None: - parser.add_argument('-c', '--config', type=read_cfg, - default=CONFIG_FILENAME, dest='cfg', - help='config file (default: {})'.format(CONFIG_FILENAME)) + parser.add_argument( + "-c", + "--config", + type=read_cfg, + default=CONFIG_FILENAME, + dest="cfg", + help="config file (default: {})".format(CONFIG_FILENAME), + ) def add_arg_envdir(parser: ArgumentParser) -> None: - parser.add_argument('envdir', type=str, - help='LMDB environment to read/write queries, answers and diffs') + parser.add_argument( + "envdir", + type=str, + help="LMDB environment to read/write queries, answers and diffs", + ) def add_arg_datafile(parser: ArgumentParser) -> None: - parser.add_argument('-d', 
'--datafile', type=str, - help='JSON report file (default: <envdir>/{})'.format( - REPORT_FILENAME)) + parser.add_argument( + "-d", + "--datafile", + type=str, + help="JSON report file (default: <envdir>/{})".format(REPORT_FILENAME), + ) def add_arg_limit(parser: ArgumentParser) -> None: - parser.add_argument('-l', '--limit', type=int, - default=DEFAULT_PRINT_QUERY_LIMIT, - help='number of displayed mismatches in fields (default: {}; ' - 'use 0 to display all)'.format(DEFAULT_PRINT_QUERY_LIMIT)) + parser.add_argument( + "-l", + "--limit", + type=int, + default=DEFAULT_PRINT_QUERY_LIMIT, + help="number of displayed mismatches in fields (default: {}; " + "use 0 to display all)".format(DEFAULT_PRINT_QUERY_LIMIT), + ) def add_arg_stats(parser: ArgumentParser) -> None: - parser.add_argument('-s', '--stats', type=read_stats, - default=STATS_FILENAME, - help='statistics file (default: {})'.format(STATS_FILENAME)) - - -def add_arg_stats_filename(parser: ArgumentParser, default: str = STATS_FILENAME) -> None: - parser.add_argument('-s', '--stats', type=str, - default=default, dest='stats_filename', - help='statistics file (default: {})'.format(default)) + parser.add_argument( + "-s", + "--stats", + type=read_stats, + default=STATS_FILENAME, + help="statistics file (default: {})".format(STATS_FILENAME), + ) + + +def add_arg_stats_filename( + parser: ArgumentParser, default: str = STATS_FILENAME +) -> None: + parser.add_argument( + "-s", + "--stats", + type=str, + default=default, + dest="stats_filename", + help="statistics file (default: {})".format(default), + ) def add_arg_dnsviz(parser: ArgumentParser, default: str = DNSVIZ_FILENAME) -> None: - parser.add_argument('--dnsviz', type=str, default=default, - help='dnsviz grok output (default: {})'.format(default)) + parser.add_argument( + "--dnsviz", + type=str, + default=default, + help="dnsviz grok output (default: {})".format(default), + ) def add_arg_report(parser: ArgumentParser) -> None: - parser.add_argument('report', type=read_report, nargs='*', - help='JSON report file(s)') + parser.add_argument( + "report", type=read_report, nargs="*", help="JSON report file(s)" + ) -def add_arg_report_filename(parser: ArgumentParser, nargs='*') -> None: - parser.add_argument('report', type=str, nargs=nargs, - help='JSON report file(s)') +def add_arg_report_filename(parser: ArgumentParser, nargs="*") -> None: + parser.add_argument("report", type=str, nargs=nargs, help="JSON report file(s)") def get_reports_from_filenames(args: Namespace) -> Sequence[DiffReport]: @@ -133,7 +163,9 @@ def get_reports_from_filenames(args: Namespace) -> Sequence[DiffReport]: return reports -def get_datafile(args: Namespace, key: str = 'datafile', check_exists: bool = True) -> str: +def get_datafile( + args: Namespace, key: str = "datafile", check_exists: bool = True +) -> str: datafile = getattr(args, key, None) if datafile is None: datafile = os.path.join(args.envdir, REPORT_FILENAME) @@ -155,8 +187,8 @@ def check_metadb_servers_version(lmdb, servers: Sequence[str]) -> None: @contextlib.contextmanager def smart_open(filename=None): - if filename and filename != '-': - fh = open(filename, 'w', encoding='UTF-8') + if filename and filename != "-": + fh = open(filename, "w", encoding="UTF-8") else: fh = sys.stdout @@ -168,47 +200,47 @@ def smart_open(filename=None): def format_stats_line( - description: str, - number: int, - pct: float = None, - diff: int = None, - diff_pct: float = None, - additional: str = None - ) -> str: + description: str, + number: int, + pct: float = None, + 
diff: int = None, + diff_pct: float = None, + additional: str = None, +) -> str: s = {} # type: Dict[str, str] - s['description'] = '{:21s}'.format(description) - s['number'] = '{:8d}'.format(number) - s['pct'] = '{:6.2f} %'.format(pct) if pct is not None else ' ' * 8 - s['additional'] = '{:30s}'.format(additional) if additional is not None else ' ' * 30 - s['diff'] = '{:+6d}'.format(diff) if diff is not None else ' ' * 6 - s['diff_pct'] = '{:+7.2f} %'.format(diff_pct) if diff_pct is not None else ' ' * 9 + s["description"] = "{:21s}".format(description) + s["number"] = "{:8d}".format(number) + s["pct"] = "{:6.2f} %".format(pct) if pct is not None else " " * 8 + s["additional"] = ( + "{:30s}".format(additional) if additional is not None else " " * 30 + ) + s["diff"] = "{:+6d}".format(diff) if diff is not None else " " * 6 + s["diff_pct"] = "{:+7.2f} %".format(diff_pct) if diff_pct is not None else " " * 9 - return '{description} {number} {pct} {additional} {diff} {diff_pct}'.format(**s) + return "{description} {number} {pct} {additional} {diff} {diff_pct}".format(**s) def get_stats_data( - n: int, - total: int = None, - ref_n: int = None, - ) -> ChangeStatsTuple: + n: int, + total: int = None, + ref_n: int = None, +) -> ChangeStatsTuple: """ Return absolute and relative data statistics Optionally, the data is compared with a reference. """ - def percentage( - dividend: Number, - divisor: Optional[Number] - ) -> Optional[float]: + + def percentage(dividend: Number, divisor: Optional[Number]) -> Optional[float]: """Return dividend/divisor value in %""" if divisor is None: return None if divisor == 0: if dividend > 0: - return float('+inf') + return float("+inf") if dividend < 0: - return float('-inf') - return float('nan') + return float("-inf") + return float("nan") return dividend * 100.0 / divisor pct = percentage(n, total) @@ -223,25 +255,20 @@ def get_stats_data( def get_table_stats_row( - count: int, - total: int, - ref_count: Optional[int] = None - ) -> Union[StatsTuple, ChangeStatsTupleStr]: - n, pct, diff, diff_pct = get_stats_data( # type: ignore - count, - total, - ref_count) + count: int, total: int, ref_count: Optional[int] = None +) -> Union[StatsTuple, ChangeStatsTupleStr]: + n, pct, diff, diff_pct = get_stats_data(count, total, ref_count) # type: ignore if ref_count is None: return n, pct - s_diff = '{:+d}'.format(diff) if diff is not None else None + s_diff = "{:+d}".format(diff) if diff is not None else None return n, pct, s_diff, diff_pct def print_fields_overview( - field_counters: Mapping[FieldLabel, Counter], - n_disagreements: int, - ref_field_counters: Optional[Mapping[FieldLabel, Counter]] = None, - ) -> None: + field_counters: Mapping[FieldLabel, Counter], + n_disagreements: int, + ref_field_counters: Optional[Mapping[FieldLabel, Counter]] = None, +) -> None: rows = [] def get_field_count(counter: Counter) -> int: @@ -256,117 +283,164 @@ def print_fields_overview( if ref_field_counters is not None: ref_counter = ref_field_counters.get(field, Counter()) ref_field_count = get_field_count(ref_counter) - rows.append((field, *get_table_stats_row( - field_count, n_disagreements, ref_field_count))) + rows.append( + (field, *get_table_stats_row(field_count, n_disagreements, ref_field_count)) + ) - headers = ['Field', 'Count', '% of mismatches'] + headers = ["Field", "Count", "% of mismatches"] if ref_field_counters is not None: - headers.extend(['Change', 'Change (%)']) + headers.extend(["Change", "Change (%)"]) - print('== Target Disagreements') - print(tabulate( # type: 
ignore - sorted(rows, key=lambda data: data[1], reverse=True), # type: ignore - headers, - tablefmt='simple', - floatfmt=('s', 'd', '.2f', 's', '+.2f'))) # type: ignore - print('') + print("== Target Disagreements") + print( + tabulate( # type: ignore + sorted(rows, key=lambda data: data[1], reverse=True), # type: ignore + headers, + tablefmt="simple", + floatfmt=("s", "d", ".2f", "s", "+.2f"), + ) + ) # type: ignore + print("") def print_field_mismatch_stats( - field: FieldLabel, - counter: Counter, - n_disagreements: int, - ref_counter: Counter = None - ) -> None: + field: FieldLabel, + counter: Counter, + n_disagreements: int, + ref_counter: Counter = None, +) -> None: rows = [] ref_count = None for mismatch, count in counter.items(): if ref_counter is not None: ref_count = ref_counter[mismatch] - rows.append(( - DataMismatch.format_value(mismatch.exp_val), - DataMismatch.format_value(mismatch.got_val), - *get_table_stats_row( - count, n_disagreements, ref_count))) - - headers = ['Expected', 'Got', 'Count', '% of mimatches'] + rows.append( + ( + DataMismatch.format_value(mismatch.exp_val), + DataMismatch.format_value(mismatch.got_val), + *get_table_stats_row(count, n_disagreements, ref_count), + ) + ) + + headers = ["Expected", "Got", "Count", "% of mimatches"] if ref_counter is not None: - headers.extend(['Change', 'Change (%)']) + headers.extend(["Change", "Change (%)"]) print('== Field "{}" mismatch statistics'.format(field)) - print(tabulate( # type: ignore - sorted(rows, key=lambda data: data[2], reverse=True), # type: ignore - headers, - tablefmt='simple', - floatfmt=('s', 's', 'd', '.2f', 's', '+.2f'))) # type: ignore - print('') + print( + tabulate( # type: ignore + sorted(rows, key=lambda data: data[2], reverse=True), # type: ignore + headers, + tablefmt="simple", + floatfmt=("s", "s", "d", ".2f", "s", "+.2f"), + ) + ) # type: ignore + print("") def print_global_stats(report: DiffReport, reference: DiffReport = None) -> None: - ref_duration = getattr(reference, 'duration', None) - ref_total_answers = getattr(reference, 'total_answers', None) - ref_total_queries = getattr(reference, 'total_queries', None) - - if (report.duration is None - or report.total_answers is None - or report.total_queries is None): + ref_duration = getattr(reference, "duration", None) + ref_total_answers = getattr(reference, "total_answers", None) + ref_total_queries = getattr(reference, "total_queries", None) + + if ( + report.duration is None + or report.total_answers is None + or report.total_queries is None + ): raise RuntimeError("Report doesn't containt necassary data!") - print('== Global statistics') - print(format_stats_line('duration', *get_stats_data( - report.duration, ref_n=ref_duration), - additional='seconds')) - print(format_stats_line('queries', *get_stats_data( - report.total_queries, ref_n=ref_total_queries))) - print(format_stats_line('answers', *get_stats_data( - report.total_answers, report.total_queries, - ref_total_answers), - additional='of queries')) - print('') + print("== Global statistics") + print( + format_stats_line( + "duration", + *get_stats_data(report.duration, ref_n=ref_duration), + additional="seconds" + ) + ) + print( + format_stats_line( + "queries", *get_stats_data(report.total_queries, ref_n=ref_total_queries) + ) + ) + print( + format_stats_line( + "answers", + *get_stats_data( + report.total_answers, report.total_queries, ref_total_answers + ), + additional="of queries" + ) + ) + print("") def print_differences_stats(report: DiffReport, reference: DiffReport = 
None) -> None: - ref_summary = getattr(reference, 'summary', None) - ref_manual_ignore = getattr(ref_summary, 'manual_ignore', None) - ref_upstream_unstable = getattr(ref_summary, 'upstream_unstable', None) - ref_not_reproducible = getattr(ref_summary, 'not_reproducible', None) + ref_summary = getattr(reference, "summary", None) + ref_manual_ignore = getattr(ref_summary, "manual_ignore", None) + ref_upstream_unstable = getattr(ref_summary, "upstream_unstable", None) + ref_not_reproducible = getattr(ref_summary, "not_reproducible", None) ref_target_disagrees = len(ref_summary) if ref_summary is not None else None if report.summary is None: raise RuntimeError("Report doesn't containt necassary data!") - print('== Differences statistics') - print(format_stats_line('manually ignored', *get_stats_data( - report.summary.manual_ignore, report.total_answers, - ref_manual_ignore), - additional='of answers (ignoring)')) - print(format_stats_line('upstream unstable', *get_stats_data( - report.summary.upstream_unstable, report.total_answers, - ref_upstream_unstable), - additional='of answers (ignoring)')) - print(format_stats_line('not 100% reproducible', *get_stats_data( - report.summary.not_reproducible, report.total_answers, - ref_not_reproducible), - additional='of answers (ignoring)')) - print(format_stats_line('target disagrees', *get_stats_data( - len(report.summary), report.summary.usable_answers, - ref_target_disagrees), - additional='of not ignored answers')) - print('') + print("== Differences statistics") + print( + format_stats_line( + "manually ignored", + *get_stats_data( + report.summary.manual_ignore, report.total_answers, ref_manual_ignore + ), + additional="of answers (ignoring)" + ) + ) + print( + format_stats_line( + "upstream unstable", + *get_stats_data( + report.summary.upstream_unstable, + report.total_answers, + ref_upstream_unstable, + ), + additional="of answers (ignoring)" + ) + ) + print( + format_stats_line( + "not 100% reproducible", + *get_stats_data( + report.summary.not_reproducible, + report.total_answers, + ref_not_reproducible, + ), + additional="of answers (ignoring)" + ) + ) + print( + format_stats_line( + "target disagrees", + *get_stats_data( + len(report.summary), report.summary.usable_answers, ref_target_disagrees + ), + additional="of not ignored answers" + ) + ) + print("") def print_mismatch_queries( - field: FieldLabel, - mismatch: DataMismatch, - queries: Sequence[Tuple[str, int, str]], - limit: Optional[int] = DEFAULT_PRINT_QUERY_LIMIT - ) -> None: + field: FieldLabel, + mismatch: DataMismatch, + queries: Sequence[Tuple[str, int, str]], + limit: Optional[int] = DEFAULT_PRINT_QUERY_LIMIT, +) -> None: if limit == 0: limit = None def sort_key(data: Tuple[str, int, str]) -> Tuple[int, int]: - order = ['+', ' ', '-'] + order = ["+", " ", "-"] try: return order.index(data[0]), -data[1] except ValueError: @@ -379,13 +453,10 @@ def print_mismatch_queries( to_print = to_print[:limit] print('== Field "{}", mismatch "{}" query details'.format(field, mismatch)) - print(format_line('', 'Count', 'Query')) + print(format_line("", "Count", "Query")) for diff, count, query in to_print: print(format_line(diff, str(count), query)) if limit is not None and limit < len(queries): - print(format_line( - 'x', - str(len(queries) - limit), - 'queries omitted')) - print('') + print(format_line("x", str(len(queries) - limit), "queries omitted")) + print("") diff --git a/respdiff/database.py b/respdiff/database.py index 13abc18..bbce4e6 100644 --- a/respdiff/database.py +++ 
b/respdiff/database.py @@ -4,7 +4,16 @@ import os import struct import time from typing import ( # noqa - Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Sequence) + Any, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Tuple, + Sequence, +) import dns.exception import dns.message @@ -13,53 +22,56 @@ import lmdb from .typing import ResolverID, QID, QKey, WireFormat -BIN_FORMAT_VERSION = '2018-05-21' +BIN_FORMAT_VERSION = "2018-05-21" def qid2key(qid: QID) -> QKey: - return struct.pack('<I', qid) + return struct.pack("<I", qid) def key2qid(key: QKey) -> QID: - return struct.unpack('<I', key)[0] + return struct.unpack("<I", key)[0] class LMDB: - ANSWERS = b'answers' - DIFFS = b'diffs' - QUERIES = b'queries' - META = b'meta' + ANSWERS = b"answers" + DIFFS = b"diffs" + QUERIES = b"queries" + META = b"meta" ENV_DEFAULTS = { - 'map_size': 10 * 1024**3, # 10 G - 'max_readers': 384, - 'max_dbs': 5, - 'max_spare_txns': 64, + "map_size": 10 * 1024**3, # 10 G + "max_readers": 384, + "max_dbs": 5, + "max_spare_txns": 64, } # type: Dict[str, Any] DB_OPEN_DEFAULTS = { - 'integerkey': False, + "integerkey": False, # surprisingly, optimal configuration seems to be # native integer as database key *without* # integerkey support in LMDB } # type: Dict[str, Any] - def __init__(self, path: str, create: bool = False, - readonly: bool = False, fast: bool = False) -> None: + def __init__( + self, + path: str, + create: bool = False, + readonly: bool = False, + fast: bool = False, + ) -> None: self.path = path self.dbs = {} # type: Dict[bytes, Any] self.config = LMDB.ENV_DEFAULTS.copy() - self.config.update({ - 'path': path, - 'create': create, - 'readonly': readonly - }) + self.config.update({"path": path, "create": create, "readonly": readonly}) if fast: # unsafe on crashes, but faster - self.config.update({ - 'writemap': True, - 'sync': False, - 'map_async': True, - }) + self.config.update( + { + "writemap": True, + "sync": False, + "map_async": True, + } + ) if not os.path.exists(self.path): os.makedirs(self.path) @@ -71,17 +83,25 @@ class LMDB: def __exit__(self, exc_type, exc_val, exc_tb): self.env.close() - def open_db(self, dbname: bytes, create: bool = False, - check_notexists: bool = False, drop: bool = False): + def open_db( + self, + dbname: bytes, + create: bool = False, + check_notexists: bool = False, + drop: bool = False, + ): assert self.env is not None, "LMDB wasn't initialized!" if not create and not self.exists_db(dbname): msg = 'LMDB environment "{}" does not contain DB "{}"! '.format( - self.path, dbname.decode('utf-8')) + self.path, dbname.decode("utf-8") + ) raise RuntimeError(msg) if check_notexists and self.exists_db(dbname): - msg = ('LMDB environment "{}" already contains DB "{}"! ' - 'Overwritting it would invalidate data in the environment, ' - 'terminating.').format(self.path, dbname.decode('utf-8')) + msg = ( + 'LMDB environment "{}" already contains DB "{}"! ' + "Overwritting it would invalidate data in the environment, " + "terminating." 
+ ).format(self.path, dbname.decode("utf-8")) raise RuntimeError(msg) if drop: try: @@ -106,7 +126,9 @@ class LMDB: try: return self.dbs[dbname] except KeyError as e: - raise ValueError("Database {} isn't open!".format(dbname.decode('utf-8'))) from e + raise ValueError( + "Database {} isn't open!".format(dbname.decode("utf-8")) + ) from e def key_stream(self, dbname: bytes) -> Iterator[bytes]: """yield all keys from given db""" @@ -129,70 +151,69 @@ class DNSReply: TIMEOUT_INT = 4294967295 SIZEOF_INT = 4 SIZEOF_SHORT = 2 - WIREFORMAT_VALID = 'Valid' + WIREFORMAT_VALID = "Valid" def __init__(self, wire: Optional[WireFormat], time_: float = 0) -> None: if wire is None: - self.wire = b'' - self.time = float('+inf') + self.wire = b"" + self.time = float("+inf") else: self.wire = wire self.time = time_ @property def timeout(self) -> bool: - return self.time == float('+inf') + return self.time == float("+inf") def __eq__(self, other) -> bool: if self.timeout and other.timeout: return True # float equality comparison: use 10^-7 tolerance since it's less than available # resoltuion from the time_int integer value (which is 10^-6) - return self.wire == other.wire and \ - abs(self.time - other.time) < 10 ** -7 + return self.wire == other.wire and abs(self.time - other.time) < 10**-7 @property def time_int(self) -> int: - if self.time == float('+inf'): + if self.time == float("+inf"): return self.TIMEOUT_INT - value = round(self.time * (10 ** 6)) + value = round(self.time * (10**6)) if value > self.TIMEOUT_INT: raise ValueError( 'Maximum time value exceeded: (value: "{}", max: {})'.format( - value, self.TIMEOUT_INT)) + value, self.TIMEOUT_INT + ) + ) return value @property def binary(self) -> bytes: length = len(self.wire) - return struct.pack('<I', self.time_int) + struct.pack('<H', length) + self.wire + return struct.pack("<I", self.time_int) + struct.pack("<H", length) + self.wire @classmethod - def from_binary(cls, buff: bytes) -> Tuple['DNSReply', bytes]: + def from_binary(cls, buff: bytes) -> Tuple["DNSReply", bytes]: if len(buff) < (cls.SIZEOF_INT + cls.SIZEOF_SHORT): - raise ValueError('Missing data in binary format') + raise ValueError("Missing data in binary format") offset = 0 - time_int, = struct.unpack_from('<I', buff, offset) + (time_int,) = struct.unpack_from("<I", buff, offset) offset += cls.SIZEOF_INT - length, = struct.unpack_from('<H', buff, offset) + (length,) = struct.unpack_from("<H", buff, offset) offset += cls.SIZEOF_SHORT - wire = buff[offset:(offset+length)] + wire = buff[offset : (offset + length)] offset += length if len(wire) != length: - raise ValueError('Missing data in binary format') + raise ValueError("Missing data in binary format") if time_int == cls.TIMEOUT_INT: - time_ = float('+inf') + time_ = float("+inf") else: - time_ = time_int / (10 ** 6) + time_ = time_int / (10**6) reply = DNSReply(wire, time_) return reply, buff[offset:] - def parse_wire( - self - ) -> Tuple[Optional[dns.message.Message], str]: + def parse_wire(self) -> Tuple[Optional[dns.message.Message], str]: try: return dns.message.from_wire(self.wire), self.WIREFORMAT_VALID except dns.exception.FormError as exc: @@ -201,9 +222,10 @@ class DNSReply: class DNSRepliesFactory: """Thread-safe factory to parse DNSReply objects from binary blob.""" + def __init__(self, servers: Sequence[ResolverID]) -> None: if not servers: - raise ValueError('One or more servers have to be specified') + raise ValueError("One or more servers have to be specified") self.servers = servers def parse(self, buff: bytes) -> 
Dict[ResolverID, DNSReply]: @@ -212,12 +234,12 @@ class DNSRepliesFactory: reply, buff = DNSReply.from_binary(buff) replies[server] = reply if buff: - raise ValueError('Trailing data in buffer') + raise ValueError("Trailing data in buffer") return replies def serialize(self, replies: Mapping[ResolverID, DNSReply]) -> bytes: if len(replies) > len(self.servers): - raise ValueError('Extra unexpected data to serialize!') + raise ValueError("Extra unexpected data to serialize!") data = [] for server in self.servers: try: @@ -226,11 +248,11 @@ class DNSRepliesFactory: raise ValueError('Missing reply for server "{}"!'.format(server)) from e else: data.append(reply.binary) - return b''.join(data) + return b"".join(data) class Database(ABC): - DB_NAME = b'' + DB_NAME = b"" def __init__(self, lmdb_, create: bool = False) -> None: self.lmdb = lmdb_ @@ -242,14 +264,16 @@ class Database(ABC): # ensure database is open if self.db is None: if not self.DB_NAME: - raise RuntimeError('No database to initialize!') + raise RuntimeError("No database to initialize!") try: self.db = self.lmdb.get_db(self.DB_NAME) except ValueError: try: self.db = self.lmdb.open_db(self.DB_NAME, create=self.create) except lmdb.Error as exc: - raise RuntimeError('Failed to open LMDB database: {}'.format(exc)) from exc + raise RuntimeError( + "Failed to open LMDB database: {}".format(exc) + ) from exc with self.lmdb.env.begin(self.db, write=write) as txn: yield txn @@ -258,8 +282,11 @@ class Database(ABC): with self.transaction() as txn: data = txn.get(key) if data is None: - raise KeyError("Missing '{}' key in '{}' database!".format( - key.decode('ascii'), self.DB_NAME.decode('ascii'))) + raise KeyError( + "Missing '{}' key in '{}' database!".format( + key.decode("ascii"), self.DB_NAME.decode("ascii") + ) + ) return data def write_key(self, key: bytes, value: bytes) -> None: @@ -269,18 +296,15 @@ class Database(ABC): class MetaDatabase(Database): DB_NAME = LMDB.META - KEY_VERSION = b'version' - KEY_START_TIME = b'start_time' - KEY_END_TIME = b'end_time' - KEY_SERVERS = b'servers' - KEY_NAME = b'name' + KEY_VERSION = b"version" + KEY_START_TIME = b"start_time" + KEY_END_TIME = b"end_time" + KEY_SERVERS = b"servers" + KEY_NAME = b"name" def __init__( - self, - lmdb_, - servers: Sequence[ResolverID], - create: bool = False - ) -> None: + self, lmdb_, servers: Sequence[ResolverID], create: bool = False + ) -> None: super().__init__(lmdb_, create) if create: self.write_servers(servers) @@ -291,38 +315,41 @@ class MetaDatabase(Database): def read_servers(self) -> List[ResolverID]: servers = [] ndata = self.read_key(self.KEY_SERVERS) - n, = struct.unpack('<I', ndata) + (n,) = struct.unpack("<I", ndata) for i in range(n): - key = self.KEY_NAME + str(i).encode('ascii') + key = self.KEY_NAME + str(i).encode("ascii") server = self.read_key(key) - servers.append(server.decode('ascii')) + servers.append(server.decode("ascii")) return servers def write_servers(self, servers: Sequence[ResolverID]) -> None: if not servers: raise ValueError("Empty list of servers!") - n = struct.pack('<I', len(servers)) + n = struct.pack("<I", len(servers)) self.write_key(self.KEY_SERVERS, n) for i, server in enumerate(servers): - key = self.KEY_NAME + str(i).encode('ascii') - self.write_key(key, server.encode('ascii')) + key = self.KEY_NAME + str(i).encode("ascii") + self.write_key(key, server.encode("ascii")) def check_servers(self, servers: Sequence[ResolverID]) -> None: db_servers = self.read_servers() if not servers == db_servers: raise NotImplementedError( 
'Servers defined in config differ from the ones in "meta" database! ' - '(config: "{}", meta db: "{}")'.format(servers, db_servers)) + '(config: "{}", meta db: "{}")'.format(servers, db_servers) + ) def write_version(self) -> None: - self.write_key(self.KEY_VERSION, BIN_FORMAT_VERSION.encode('ascii')) + self.write_key(self.KEY_VERSION, BIN_FORMAT_VERSION.encode("ascii")) def check_version(self) -> None: - version = self.read_key(self.KEY_VERSION).decode('ascii') + version = self.read_key(self.KEY_VERSION).decode("ascii") if version != BIN_FORMAT_VERSION: raise NotImplementedError( 'LMDB version mismatch! (expected "{}", got "{}")'.format( - BIN_FORMAT_VERSION, version)) + BIN_FORMAT_VERSION, version + ) + ) def write_start_time(self, timestamp: Optional[int] = None) -> None: self._write_timestamp(self.KEY_START_TIME, timestamp) @@ -342,10 +369,10 @@ class MetaDatabase(Database): except KeyError: return None else: - return struct.unpack('<I', data)[0] + return struct.unpack("<I", data)[0] def _write_timestamp(self, key: bytes, timestamp: Optional[int]) -> None: if timestamp is None: timestamp = round(time.time()) - data = struct.pack('<I', timestamp) + data = struct.pack("<I", timestamp) self.write_key(key, data) diff --git a/respdiff/dataformat.py b/respdiff/dataformat.py index 19d20e7..affab27 100644 --- a/respdiff/dataformat.py +++ b/respdiff/dataformat.py @@ -4,8 +4,20 @@ from collections import Counter import collections.abc import json from typing import ( # noqa - Any, Callable, Dict, Hashable, ItemsView, Iterator, KeysView, Mapping, - Optional, Set, Sequence, Tuple, Union) + Any, + Callable, + Dict, + Hashable, + ItemsView, + Iterator, + KeysView, + Mapping, + Optional, + Set, + Sequence, + Tuple, + Union, +) from .match import DataMismatch from .typing import FieldLabel, QID @@ -20,23 +32,26 @@ class InvalidFileFormat(Exception): class JSONDataObject: """Object class for (de)serialization into JSON-compatible dictionary.""" + _ATTRIBUTES = {} # type: Mapping[str, Tuple[RestoreFunction, SaveFunction]] def __init__(self, **kwargs): # pylint: disable=unused-argument - self.fileorigin = '' + self.fileorigin = "" def export_json(self, filename: str) -> None: json_string = json.dumps(self.save(), indent=2) - with open(filename, 'w', encoding='UTF-8') as f: + with open(filename, "w", encoding="UTF-8") as f: f.write(json_string) @classmethod def from_json(cls, filename: str): try: - with open(filename, encoding='UTF-8') as f: + with open(filename, encoding="UTF-8") as f: restore_dict = json.load(f) except json.decoder.JSONDecodeError as e: - raise InvalidFileFormat("Couldn't parse JSON file: {}".format(filename)) from e + raise InvalidFileFormat( + "Couldn't parse JSON file: {}".format(filename) + ) from e inst = cls(_restore_dict=restore_dict) inst.fileorigin = filename return inst @@ -54,11 +69,11 @@ class JSONDataObject: return restore_dict def _restore_attr( - self, - restore_dict: Mapping[str, Any], - key: str, - restore_func: RestoreFunction = None - ) -> None: + self, + restore_dict: Mapping[str, Any], + key: str, + restore_func: RestoreFunction = None, + ) -> None: """ Restore attribute from key in dictionary. If it's missing or None, don't call restore_func() and leave attribute's value default. 
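For context: the `_ATTRIBUTES` table above maps each JSON key to a `(restore_func, save_func)` pair; `save_func` converts an attribute into a JSON-compatible value, `restore_func` converts it back, and a missing or None key leaves the attribute at its constructor default. Below is a minimal, standalone sketch of that convention, not the respdiff code itself: the `QueryCounter` name is hypothetical, while the `(set, list)` pair mirrors the one `DisagreementsCounter` declares later in this diff.

    import json
    from typing import Any, Callable, Dict, Mapping, Optional, Tuple

    RestoreFunction = Optional[Callable[[Any], Any]]
    SaveFunction = Optional[Callable[[Any], Any]]


    class QueryCounter:
        # same convention as _ATTRIBUTES: key -> (restore_func, save_func)
        _ATTRIBUTES = {
            "queries": (set, list),
        }  # type: Mapping[str, Tuple[RestoreFunction, SaveFunction]]

        def __init__(self) -> None:
            self.queries = set()  # type: set

        def save(self) -> Dict[str, Any]:
            # save_func (here: list) produces a JSON-compatible value;
            # a None attribute is stored as-is, without calling save_func
            out = {}  # type: Dict[str, Any]
            for key, (_, save_func) in self._ATTRIBUTES.items():
                value = getattr(self, key)
                out[key] = save_func(value) if save_func and value is not None else value
            return out

        def restore(self, restore_dict: Mapping[str, Any]) -> None:
            # a missing or None key leaves the attribute's default untouched
            for key, (restore_func, _) in self._ATTRIBUTES.items():
                value = restore_dict.get(key)
                if value is None:
                    continue
                setattr(self, key, restore_func(value) if restore_func else value)


    counter = QueryCounter()
    counter.queries.update({1, 2, 3})
    blob = json.dumps(counter.save())   # sets aren't JSON-serializable; lists are
    clone = QueryCounter()
    clone.restore(json.loads(blob))     # list -> set on the way back in
    assert clone.queries == {1, 2, 3}

The round trip shows why the pair is needed: `set` has no JSON representation, so it is saved as a list and rebuilt as a set on restore.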
@@ -72,11 +87,7 @@ class JSONDataObject: value = restore_func(value) setattr(self, key, value) - def _save_attr( - self, - key: str, - save_func: SaveFunction = None - ) -> Mapping[str, Any]: + def _save_attr(self, key: str, save_func: SaveFunction = None) -> Mapping[str, Any]: """ Save attribute as a key in dictionary. If the attribute is None, save it as such (without calling save_func()). @@ -89,6 +100,7 @@ class JSONDataObject: class Diff(collections.abc.Mapping): """Read-only representation of mismatches in each field for a single query""" + __setitem__ = None __delitem__ = None @@ -107,19 +119,17 @@ class Diff(collections.abc.Mapping): return iter(self._mismatches) def get_significant_field( - self, - field_weights: Sequence[FieldLabel] - ) -> Tuple[Optional[FieldLabel], Optional[DataMismatch]]: + self, field_weights: Sequence[FieldLabel] + ) -> Tuple[Optional[FieldLabel], Optional[DataMismatch]]: for significant_field in field_weights: if significant_field in self: return significant_field, self[significant_field] return None, None def __repr__(self) -> str: - return 'Diff({})'.format( - ', '.join([ - repr(mismatch) for mismatch in self.values() - ])) + return "Diff({})".format( + ", ".join([repr(mismatch) for mismatch in self.values()]) + ) def __eq__(self, other) -> bool: if len(self) != len(other): @@ -142,9 +152,9 @@ class Disagreements(collections.abc.Mapping, JSONDataObject): """ def __init__( - self, - _restore_dict: Optional[Mapping[str, Any]] = None, - ) -> None: + self, + _restore_dict: Optional[Mapping[str, Any]] = None, + ) -> None: """ `_restore_dict` is used to restore from JSON, minimal format: "fields": { @@ -161,40 +171,44 @@ class Disagreements(collections.abc.Mapping, JSONDataObject): """ super().__init__() self._fields = collections.defaultdict( - lambda: collections.defaultdict(set) - ) # type: Dict[FieldLabel, Dict[DataMismatch, Set[QID]]] + lambda: collections.defaultdict(set) + ) # type: Dict[FieldLabel, Dict[DataMismatch, Set[QID]]] if _restore_dict is not None: self.restore(_restore_dict) def restore(self, restore_dict: Mapping[str, Any]) -> None: super().restore(restore_dict) - for field_label, field_data in restore_dict['fields'].items(): - for mismatch_data in field_data['mismatches']: + for field_label, field_data in restore_dict["fields"].items(): + for mismatch_data in field_data["mismatches"]: mismatch = DataMismatch( - mismatch_data['exp_val'], - mismatch_data['got_val']) - self._fields[field_label][mismatch] = set(mismatch_data['queries']) + mismatch_data["exp_val"], mismatch_data["got_val"] + ) + self._fields[field_label][mismatch] = set(mismatch_data["queries"]) def save(self) -> Dict[str, Any]: fields = {} for field, field_data in self._fields.items(): mismatches = [] for mismatch, mismatch_data in field_data.items(): - mismatches.append({ - 'count': len(mismatch_data), - 'exp_val': mismatch.exp_val, - 'got_val': mismatch.got_val, - 'queries': list(mismatch_data), - }) + mismatches.append( + { + "count": len(mismatch_data), + "exp_val": mismatch.exp_val, + "got_val": mismatch.got_val, + "queries": list(mismatch_data), + } + ) fields[field] = { - 'count': len(mismatches), - 'mismatches': mismatches, + "count": len(mismatches), + "mismatches": mismatches, } restore_dict = super().save() or {} - restore_dict.update({ - 'count': self.count, - 'fields': fields, - }) + restore_dict.update( + { + "count": self.count, + "fields": fields, + } + ) return restore_dict def add_mismatch(self, field: FieldLabel, mismatch: DataMismatch, qid: QID) -> None: @@ 
-205,9 +219,8 @@ class Disagreements(collections.abc.Mapping, JSONDataObject): return self._fields.keys() def get_field_mismatches( - self, - field: FieldLabel - ) -> ItemsView[DataMismatch, Set[QID]]: + self, field: FieldLabel + ) -> ItemsView[DataMismatch, Set[QID]]: return self._fields[field].items() @property @@ -239,7 +252,7 @@ class Disagreements(collections.abc.Mapping, JSONDataObject): class DisagreementsCounter(JSONDataObject): _ATTRIBUTES = { - 'queries': (set, list), + "queries": (set, list), } def __init__(self, _restore_dict: Mapping[str, int] = None) -> None: @@ -257,17 +270,17 @@ class DisagreementsCounter(JSONDataObject): class Summary(Disagreements): """Disagreements, where each query has no more than one mismatch.""" + _ATTRIBUTES = { - 'upstream_unstable': (None, None), - 'usable_answers': (None, None), - 'not_reproducible': (None, None), - 'manual_ignore': (None, None), + "upstream_unstable": (None, None), + "usable_answers": (None, None), + "not_reproducible": (None, None), + "manual_ignore": (None, None), } def __init__( - self, - _restore_dict: Optional[Mapping[FieldLabel, Mapping[str, Any]]] = None - ) -> None: + self, _restore_dict: Optional[Mapping[FieldLabel, Mapping[str, Any]]] = None + ) -> None: self.usable_answers = 0 self.upstream_unstable = 0 self.not_reproducible = 0 @@ -276,17 +289,17 @@ class Summary(Disagreements): def add_mismatch(self, field: FieldLabel, mismatch: DataMismatch, qid: QID) -> None: if qid in self.keys(): - raise ValueError('QID {} already exists in Summary!'.format(qid)) + raise ValueError("QID {} already exists in Summary!".format(qid)) self._fields[field][mismatch].add(qid) @staticmethod def from_report( - report: 'DiffReport', - field_weights: Sequence[FieldLabel], - reproducibility_threshold: float = 1, - without_diffrepro: bool = False, - ignore_qids: Optional[Set[QID]] = None - ) -> 'Summary': + report: "DiffReport", + field_weights: Sequence[FieldLabel], + reproducibility_threshold: float = 1, + without_diffrepro: bool = False, + ignore_qids: Optional[Set[QID]] = None, + ) -> "Summary": """ Get summary of disagreements above the specified reproduciblity threshold [0, 1]. @@ -294,9 +307,11 @@ class Summary(Disagreements): Optionally, provide a list of known unstable and/or failing QIDs which will be ignored. 
""" - if (report.other_disagreements is None - or report.target_disagreements is None - or report.total_answers is None): + if ( + report.other_disagreements is None + or report.target_disagreements is None + or report.total_answers is None + ): raise RuntimeError("Report has insufficient data to create Summary") if ignore_qids is None: @@ -315,7 +330,9 @@ class Summary(Disagreements): if reprocounter.retries != reprocounter.upstream_stable: summary.upstream_unstable += 1 continue # filter unstable upstream - reproducibility = float(reprocounter.verified) / reprocounter.retries + reproducibility = ( + float(reprocounter.verified) / reprocounter.retries + ) if reproducibility < reproducibility_threshold: summary.not_reproducible += 1 continue # filter less reproducible than threshold @@ -324,7 +341,8 @@ class Summary(Disagreements): summary.add_mismatch(field, mismatch, qid) summary.usable_answers = ( - report.total_answers - summary.upstream_unstable - summary.not_reproducible) + report.total_answers - summary.upstream_unstable - summary.not_reproducible + ) return summary def get_field_counters(self) -> Mapping[FieldLabel, Counter]: @@ -339,20 +357,23 @@ class Summary(Disagreements): class ReproCounter(JSONDataObject): _ATTRIBUTES = { - 'retries': (None, None), # total amount of attempts to reproduce - 'upstream_stable': (None, None), # number of cases, where others disagree - 'verified': (None, None), # the query fails, and the diff is same (reproduced) - 'different_failure': (None, None) # the query fails, but the diff doesn't match + "retries": (None, None), # total amount of attempts to reproduce + "upstream_stable": (None, None), # number of cases, where others disagree + "verified": (None, None), # the query fails, and the diff is same (reproduced) + "different_failure": ( + None, + None, + ), # the query fails, but the diff doesn't match } def __init__( - self, - retries: int = 0, - upstream_stable: int = 0, - verified: int = 0, - different_failure: int = 0, - _restore_dict: Optional[Mapping[str, int]] = None - ) -> None: + self, + retries: int = 0, + upstream_stable: int = 0, + verified: int = 0, + different_failure: int = 0, + _restore_dict: Optional[Mapping[str, int]] = None, + ) -> None: super().__init__() self.retries = retries self.upstream_stable = upstream_stable @@ -371,13 +392,16 @@ class ReproCounter(JSONDataObject): self.retries == other.retries and self.upstream_stable == other.upstream_stable and self.verified == other.verified - and self.different_failure == other.different_failure) + and self.different_failure == other.different_failure + ) class ReproData(collections.abc.Mapping, JSONDataObject): def __init__(self, _restore_dict: Optional[Mapping[str, Any]] = None) -> None: super().__init__() - self._counters = collections.defaultdict(ReproCounter) # type: Dict[QID, ReproCounter] + self._counters = collections.defaultdict( + ReproCounter + ) # type: Dict[QID, ReproCounter] if _restore_dict is not None: self.restore(_restore_dict) @@ -406,41 +430,41 @@ class ReproData(collections.abc.Mapping, JSONDataObject): yield from self._counters.keys() -QueryData = collections.namedtuple('QueryData', 'total, others_disagree, target_disagrees') +QueryData = collections.namedtuple( + "QueryData", "total, others_disagree, target_disagrees" +) class DiffReport(JSONDataObject): # pylint: disable=too-many-instance-attributes _ATTRIBUTES = { - 'start_time': (None, None), - 'end_time': (None, None), - 'total_queries': (None, None), - 'total_answers': (None, None), - 
'other_disagreements': ( + "start_time": (None, None), + "end_time": (None, None), + "total_queries": (None, None), + "total_answers": (None, None), + "other_disagreements": ( lambda x: DisagreementsCounter(_restore_dict=x), - lambda x: x.save()), - 'target_disagreements': ( + lambda x: x.save(), + ), + "target_disagreements": ( lambda x: Disagreements(_restore_dict=x), - lambda x: x.save()), - 'summary': ( - lambda x: Summary(_restore_dict=x), - lambda x: x.save()), - 'reprodata': ( - lambda x: ReproData(_restore_dict=x), - lambda x: x.save()), + lambda x: x.save(), + ), + "summary": (lambda x: Summary(_restore_dict=x), lambda x: x.save()), + "reprodata": (lambda x: ReproData(_restore_dict=x), lambda x: x.save()), } def __init__( - self, - start_time: Optional[int] = None, - end_time: Optional[int] = None, - total_queries: Optional[int] = None, - total_answers: Optional[int] = None, - other_disagreements: Optional[DisagreementsCounter] = None, - target_disagreements: Optional[Disagreements] = None, - summary: Optional[Summary] = None, - reprodata: Optional[ReproData] = None, - _restore_dict: Optional[Mapping[str, Any]] = None - ) -> None: + self, + start_time: Optional[int] = None, + end_time: Optional[int] = None, + total_queries: Optional[int] = None, + total_answers: Optional[int] = None, + other_disagreements: Optional[DisagreementsCounter] = None, + target_disagreements: Optional[Disagreements] = None, + summary: Optional[Summary] = None, + reprodata: Optional[ReproData] = None, + _restore_dict: Optional[Mapping[str, Any]] = None, + ) -> None: super().__init__() self.start_time = start_time self.end_time = end_time diff --git a/respdiff/dnsviz.py b/respdiff/dnsviz.py index 7748512..010037c 100644 --- a/respdiff/dnsviz.py +++ b/respdiff/dnsviz.py @@ -6,19 +6,19 @@ import dns import dns.name -TYPES = 'A,AAAA,CNAME' -KEYS_ERROR = ['error', 'errors'] -KEYS_WARNING = ['warnings'] +TYPES = "A,AAAA,CNAME" +KEYS_ERROR = ["error", "errors"] +KEYS_WARNING = ["warnings"] class DnsvizDomainResult: def __init__( - self, - errors: Optional[Mapping[str, List]] = None, - warnings: Optional[Mapping[str, List]] = None, - ) -> None: + self, + errors: Optional[Mapping[str, List]] = None, + warnings: Optional[Mapping[str, List]] = None, + ) -> None: super().__init__() - self.errors = defaultdict(list) # type: Dict[str, List] + self.errors = defaultdict(list) # type: Dict[str, List] self.warnings = defaultdict(list) # type: Dict[str, List] if errors is not None: self.errors.update(errors) @@ -31,10 +31,10 @@ class DnsvizDomainResult: def _find_keys( - kv_iter: Iterable[Tuple[Any, Any]], - keys: Optional[List[Any]] = None, - path: Optional[List[Any]] = None - ) -> Iterator[Tuple[Any, Any]]: + kv_iter: Iterable[Tuple[Any, Any]], + keys: Optional[List[Any]] = None, + path: Optional[List[Any]] = None, +) -> Iterator[Tuple[Any, Any]]: assert isinstance(keys, list) if path is None: path = [] @@ -59,25 +59,26 @@ class DnsvizGrok(dict): domain = key_path[0] if domain not in self.domains: self.domains[domain] = DnsvizDomainResult() - path = '_'.join([str(value) for value in key_path]) + path = "_".join([str(value) for value in key_path]) if key_path[-1] in KEYS_ERROR: self.domains[domain].errors[path] += messages else: self.domains[domain].warnings[path] += messages @staticmethod - def from_json(filename: str) -> 'DnsvizGrok': - with open(filename, encoding='UTF-8') as f: + def from_json(filename: str) -> "DnsvizGrok": + with open(filename, encoding="UTF-8") as f: grok_data = json.load(f) if not 
isinstance(grok_data, dict): raise RuntimeError( - "File {} doesn't contain dnsviz grok json data".format(filename)) + "File {} doesn't contain dnsviz grok json data".format(filename) + ) return DnsvizGrok(grok_data) def error_domains(self) -> List[dns.name.Name]: err_domains = [] for domain, data in self.domains.items(): if data.is_error: - assert domain[-1] == '.' + assert domain[-1] == "." err_domains.append(dns.name.from_text(domain)) return err_domains diff --git a/respdiff/match.py b/respdiff/match.py index 80e429e..f1dac68 100644 --- a/respdiff/match.py +++ b/respdiff/match.py @@ -1,7 +1,15 @@ import collections import logging from typing import ( # noqa - Any, Dict, Hashable, Iterator, Mapping, Optional, Sequence, Tuple) + Any, + Dict, + Hashable, + Iterator, + Mapping, + Optional, + Sequence, + Tuple, +) import dns.rdatatype from dns.rrset import RRset @@ -21,8 +29,10 @@ class DataMismatch(Exception): if isinstance(val, dns.rrset.RRset): return str(val) logging.warning( - 'DataMismatch: unknown value type (%s), casting to str', type(val), - stack_info=True) + "DataMismatch: unknown value type (%s), casting to str", + type(val), + stack_info=True, + ) return str(val) exp_val = convert_val_type(exp_val) @@ -37,21 +47,23 @@ class DataMismatch(Exception): @staticmethod def format_value(value: MismatchValue) -> str: if isinstance(value, list): - value = ' '.join(value) + value = " ".join(value) return str(value) def __str__(self) -> str: return "expected '{}' got '{}'".format( - self.format_value(self.exp_val), - self.format_value(self.got_val)) + self.format_value(self.exp_val), self.format_value(self.got_val) + ) def __repr__(self) -> str: - return 'DataMismatch({}, {})'.format(self.exp_val, self.got_val) + return "DataMismatch({}, {})".format(self.exp_val, self.got_val) def __eq__(self, other) -> bool: - return (isinstance(other, DataMismatch) - and tuple(self.exp_val) == tuple(other.exp_val) - and tuple(self.got_val) == tuple(other.got_val)) + return ( + isinstance(other, DataMismatch) + and tuple(self.exp_val) == tuple(other.exp_val) + and tuple(self.got_val) == tuple(other.got_val) + ) @property def key(self) -> Tuple[Hashable, Hashable]: @@ -67,14 +79,14 @@ class DataMismatch(Exception): def compare_val(exp_val: MismatchValue, got_val: MismatchValue): - """ Compare values, throw exception if different. """ + """Compare values, throw exception if different.""" if exp_val != got_val: raise DataMismatch(str(exp_val), str(got_val)) return True def compare_rrs(expected: Sequence[RRset], got: Sequence[RRset]): - """ Compare lists of RR sets, throw exception if different. 
""" + """Compare lists of RR sets, throw exception if different.""" for rr in expected: if rr not in got: raise DataMismatch(expected, got) @@ -87,17 +99,17 @@ def compare_rrs(expected: Sequence[RRset], got: Sequence[RRset]): def compare_rrs_types( - exp_val: Sequence[RRset], - got_val: Sequence[RRset], - compare_rrsigs: bool): + exp_val: Sequence[RRset], got_val: Sequence[RRset], compare_rrsigs: bool +): """sets of RR types in both sections must match""" + def rr_ordering_key(rrset): return rrset.covers if compare_rrsigs else rrset.rdtype def key_to_text(rrtype): if not compare_rrsigs: return dns.rdatatype.to_text(rrtype) - return 'RRSIG(%s)' % dns.rdatatype.to_text(rrtype) + return "RRSIG(%s)" % dns.rdatatype.to_text(rrtype) def filter_by_rrsig(seq, rrsig): for el in seq: @@ -105,51 +117,60 @@ def compare_rrs_types( if el_rrsig == rrsig: yield el - exp_types = frozenset(rr_ordering_key(rrset) - for rrset in filter_by_rrsig(exp_val, compare_rrsigs)) - got_types = frozenset(rr_ordering_key(rrset) - for rrset in filter_by_rrsig(got_val, compare_rrsigs)) + exp_types = frozenset( + rr_ordering_key(rrset) for rrset in filter_by_rrsig(exp_val, compare_rrsigs) + ) + got_types = frozenset( + rr_ordering_key(rrset) for rrset in filter_by_rrsig(got_val, compare_rrsigs) + ) if exp_types != got_types: raise DataMismatch( tuple(key_to_text(i) for i in sorted(exp_types)), - tuple(key_to_text(i) for i in sorted(got_types))) + tuple(key_to_text(i) for i in sorted(got_types)), + ) def match_part( # pylint: disable=inconsistent-return-statements - exp_msg: dns.message.Message, - got_msg: dns.message.Message, - criteria: FieldLabel - ): - """ Compare scripted reply to given message using single criteria. """ - if criteria == 'opcode': + exp_msg: dns.message.Message, got_msg: dns.message.Message, criteria: FieldLabel +): + """Compare scripted reply to given message using single criteria.""" + if criteria == "opcode": return compare_val(exp_msg.opcode(), got_msg.opcode()) - elif criteria == 'flags': - return compare_val(dns.flags.to_text(exp_msg.flags), dns.flags.to_text(got_msg.flags)) - elif criteria == 'rcode': - return compare_val(dns.rcode.to_text(exp_msg.rcode()), dns.rcode.to_text(got_msg.rcode())) - elif criteria == 'question': + elif criteria == "flags": + return compare_val( + dns.flags.to_text(exp_msg.flags), dns.flags.to_text(got_msg.flags) + ) + elif criteria == "rcode": + return compare_val( + dns.rcode.to_text(exp_msg.rcode()), dns.rcode.to_text(got_msg.rcode()) + ) + elif criteria == "question": question_match = compare_rrs(exp_msg.question, got_msg.question) if not exp_msg.question: # 0 RRs, nothing else to compare return True - assert len(exp_msg.question) == 1, "multiple question in single DNS query unsupported" - case_match = compare_val(got_msg.question[0].name.labels, exp_msg.question[0].name.labels) + assert ( + len(exp_msg.question) == 1 + ), "multiple question in single DNS query unsupported" + case_match = compare_val( + got_msg.question[0].name.labels, exp_msg.question[0].name.labels + ) return question_match and case_match - elif criteria in ('answer', 'ttl'): + elif criteria in ("answer", "ttl"): return compare_rrs(exp_msg.answer, got_msg.answer) - elif criteria == 'answertypes': + elif criteria == "answertypes": return compare_rrs_types(exp_msg.answer, got_msg.answer, compare_rrsigs=False) - elif criteria == 'answerrrsigs': + elif criteria == "answerrrsigs": return compare_rrs_types(exp_msg.answer, got_msg.answer, compare_rrsigs=True) - elif criteria == 'authority': + elif 
criteria == "authority": return compare_rrs(exp_msg.authority, got_msg.authority) - elif criteria == 'additional': + elif criteria == "additional": return compare_rrs(exp_msg.additional, got_msg.additional) - elif criteria == 'edns': + elif criteria == "edns": if got_msg.edns != exp_msg.edns: raise DataMismatch(str(exp_msg.edns), str(got_msg.edns)) if got_msg.payload != exp_msg.payload: raise DataMismatch(str(exp_msg.payload), str(got_msg.payload)) - elif criteria == 'nsid': + elif criteria == "nsid": nsid_opt = None for opt in exp_msg.options: if opt.otype == dns.edns.NSID: @@ -159,23 +180,21 @@ def match_part( # pylint: disable=inconsistent-return-statements for opt in got_msg.options: if opt.otype == dns.edns.NSID: if not nsid_opt: - raise DataMismatch('', str(opt.data)) + raise DataMismatch("", str(opt.data)) if opt == nsid_opt: return True else: raise DataMismatch(str(nsid_opt.data), str(opt.data)) if nsid_opt: - raise DataMismatch(str(nsid_opt.data), '') + raise DataMismatch(str(nsid_opt.data), "") else: raise NotImplementedError('unknown match request "%s"' % criteria) def match( - expected: DNSReply, - got: DNSReply, - match_fields: Sequence[FieldLabel] - ) -> Iterator[Tuple[FieldLabel, DataMismatch]]: - """ Compare scripted reply to given message based on match criteria. """ + expected: DNSReply, got: DNSReply, match_fields: Sequence[FieldLabel] +) -> Iterator[Tuple[FieldLabel, DataMismatch]]: + """Compare scripted reply to given message based on match criteria.""" exp_msg, exp_res = expected.parse_wire() got_msg, got_res = got.parse_wire() exp_malformed = exp_res != DNSReply.WIREFORMAT_VALID @@ -183,15 +202,16 @@ def match( if expected.timeout or got.timeout: if not expected.timeout: - yield 'timeout', DataMismatch('answer', 'timeout') + yield "timeout", DataMismatch("answer", "timeout") if not got.timeout: - yield 'timeout', DataMismatch('timeout', 'answer') + yield "timeout", DataMismatch("timeout", "answer") elif exp_malformed or got_malformed: if exp_res == got_res: logging.warning( - 'match: DNS replies malformed in the same way! (%s)', exp_res) + "match: DNS replies malformed in the same way! (%s)", exp_res + ) else: - yield 'malformed', DataMismatch(exp_res, got_res) + yield "malformed", DataMismatch(exp_res, got_res) if expected.timeout or got.timeout or exp_malformed or got_malformed: return # don't attempt to match any other fields @@ -208,19 +228,19 @@ def match( def diff_pair( - answers: Mapping[ResolverID, DNSReply], - criteria: Sequence[FieldLabel], - name1: ResolverID, - name2: ResolverID - ) -> Iterator[Tuple[FieldLabel, DataMismatch]]: + answers: Mapping[ResolverID, DNSReply], + criteria: Sequence[FieldLabel], + name1: ResolverID, + name2: ResolverID, +) -> Iterator[Tuple[FieldLabel, DataMismatch]]: yield from match(answers[name1], answers[name2], criteria) def transitive_equality( - answers: Mapping[ResolverID, DNSReply], - criteria: Sequence[FieldLabel], - resolvers: Sequence[ResolverID] - ) -> bool: + answers: Mapping[ResolverID, DNSReply], + criteria: Sequence[FieldLabel], + resolvers: Sequence[ResolverID], +) -> bool: """ Compare answers from all resolvers. Optimization is based on transitivity of equivalence relation. 
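A minimal sketch of the idea named in the docstring above: because equality of parsed answers is an equivalence relation, comparing every other resolver against a single pivot gives the same verdict as checking all pairs, in O(n) instead of O(n^2) comparisons. Plain values stand in for parsed answers and the resolver names are purely illustrative:

    from itertools import combinations

    def answers_agree(answers: dict) -> bool:
        """Pivot-based agreement check, mirroring transitive_equality:
        if pivot == b and pivot == c, transitivity gives b == c."""
        names = list(answers)
        pivot = names[0]
        return all(answers[pivot] == answers[other] for other in names[1:])

    def answers_agree_naive(answers: dict) -> bool:
        """All-pairs check; same verdict, more comparisons."""
        return all(a == b for a, b in combinations(answers.values(), 2))

    # Illustrative resolver names; integers stand in for parsed answers.
    assert answers_agree({"kresd": 1, "bind": 1, "unbound": 1})
    assert not answers_agree({"kresd": 1, "bind": 2, "unbound": 1})
    assert answers_agree_naive({"kresd": 1, "bind": 1, "unbound": 1})
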
@@ -228,16 +248,19 @@ def transitive_equality( assert len(resolvers) >= 2 res_a = resolvers[0] # compare all others to this resolver res_others = resolvers[1:] - return all(map( - lambda res_b: not any(diff_pair(answers, criteria, res_a, res_b)), - res_others)) + return all( + map( + lambda res_b: not any(diff_pair(answers, criteria, res_a, res_b)), + res_others, + ) + ) def compare( - answers: Mapping[ResolverID, DNSReply], - criteria: Sequence[FieldLabel], - target: ResolverID - ) -> Tuple[bool, Optional[Mapping[FieldLabel, DataMismatch]]]: + answers: Mapping[ResolverID, DNSReply], + criteria: Sequence[FieldLabel], + target: ResolverID, +) -> Tuple[bool, Optional[Mapping[FieldLabel, DataMismatch]]]: others = list(answers.keys()) try: others.remove(target) diff --git a/respdiff/qstats.py b/respdiff/qstats.py index 94a155a..33b895f 100644 --- a/respdiff/qstats.py +++ b/respdiff/qstats.py @@ -7,8 +7,12 @@ from .dataformat import DiffReport, JSONDataObject, QueryData from .typing import QID -UPSTREAM_UNSTABLE_THRESHOLD = 0.1 # consider query unstable when 10 % of results are unstable -ALLOWED_FAIL_THRESHOLD = 0.05 # ignore up to 5 % of FAIL results for a given query (as noise) +UPSTREAM_UNSTABLE_THRESHOLD = ( + 0.1 # consider query unstable when 10 % of results are unstable +) +ALLOWED_FAIL_THRESHOLD = ( + 0.05 # ignore up to 5 % of FAIL results for a given query (as noise) +) class QueryStatus(Enum): @@ -27,16 +31,16 @@ def get_query_status(query_data: QueryData) -> QueryStatus: class QueryStatistics(JSONDataObject): _ATTRIBUTES = { - 'failing': (set, list), - 'unstable': (set, list), + "failing": (set, list), + "unstable": (set, list), } def __init__( - self, - failing: Optional[Set[QID]] = None, - unstable: Optional[Set[QID]] = None, - _restore_dict: Optional[Mapping[str, Any]] = None - ) -> None: + self, + failing: Optional[Set[QID]] = None, + unstable: Optional[Set[QID]] = None, + _restore_dict: Optional[Mapping[str, Any]] = None, + ) -> None: super().__init__() self.failing = failing if failing is not None else set() self.unstable = unstable if unstable is not None else set() @@ -51,9 +55,9 @@ class QueryStatistics(JSONDataObject): self.unstable.add(qid) @staticmethod - def from_reports(reports: Sequence[DiffReport]) -> 'QueryStatistics': + def from_reports(reports: Sequence[DiffReport]) -> "QueryStatistics": """Create query statistics from multiple reports - usually used as a reference""" - others_disagree = collections.Counter() # type: collections.Counter + others_disagree = collections.Counter() # type: collections.Counter target_disagrees = collections.Counter() # type: collections.Counter reprodata_present = False @@ -68,7 +72,9 @@ class QueryStatistics(JSONDataObject): for qid in report.target_disagreements: target_disagrees[qid] += 1 if reprodata_present: - logging.warning("reprodata ignored when creating query stability statistics") + logging.warning( + "reprodata ignored when creating query stability statistics" + ) # evaluate total = len(reports) @@ -77,5 +83,6 @@ class QueryStatistics(JSONDataObject): suspect_queries.update(target_disagrees.keys()) for qid in suspect_queries: query_statistics.add_query( - qid, QueryData(total, others_disagree[qid], target_disagrees[qid])) + qid, QueryData(total, others_disagree[qid], target_disagrees[qid]) + ) return query_statistics diff --git a/respdiff/query.py b/respdiff/query.py index 99b52b6..9b6b0ef 100644 --- a/respdiff/query.py +++ b/respdiff/query.py @@ -9,10 +9,7 @@ from .database import LMDB, qid2key from .typing import QID, 
WireFormat -def get_query_iterator( - lmdb_, - qids: Iterable[QID] - ) -> Iterator[Tuple[QID, WireFormat]]: +def get_query_iterator(lmdb_, qids: Iterable[QID]) -> Iterator[Tuple[QID, WireFormat]]: qdb = lmdb_.get_db(LMDB.QUERIES) with lmdb_.env.begin(qdb) as txn: for qid in qids: @@ -25,9 +22,9 @@ def qwire_to_qname(qwire: WireFormat) -> str: try: qmsg = dns.message.from_wire(qwire) except dns.exception.DNSException as exc: - raise ValueError('unable to parse qname from wire format') from exc + raise ValueError("unable to parse qname from wire format") from exc if not qmsg.question: - raise ValueError('no qname in wire format') + raise ValueError("no qname in wire format") return qmsg.question[0].name @@ -36,12 +33,12 @@ def qwire_to_qname_qtype(qwire: WireFormat) -> str: try: qmsg = dns.message.from_wire(qwire) except dns.exception.DNSException as exc: - raise ValueError('unable to parse qname from wire format') from exc + raise ValueError("unable to parse qname from wire format") from exc if not qmsg.question: - raise ValueError('no qname in wire format') - return '{} {}'.format( - qmsg.question[0].name, - dns.rdatatype.to_text(qmsg.question[0].rdtype)) + raise ValueError("no qname in wire format") + return "{} {}".format( + qmsg.question[0].name, dns.rdatatype.to_text(qmsg.question[0].rdtype) + ) def qwire_to_msgid_qname_qtype(qwire: WireFormat) -> str: @@ -49,46 +46,47 @@ def qwire_to_msgid_qname_qtype(qwire: WireFormat) -> str: try: qmsg = dns.message.from_wire(qwire) except dns.exception.DNSException as exc: - raise ValueError('unable to parse qname from wire format') from exc + raise ValueError("unable to parse qname from wire format") from exc if not qmsg.question: - raise ValueError('no qname in wire format') - return '[{:05d}] {} {}'.format( - qmsg.id, - qmsg.question[0].name, - dns.rdatatype.to_text(qmsg.question[0].rdtype)) + raise ValueError("no qname in wire format") + return "[{:05d}] {} {}".format( + qmsg.id, qmsg.question[0].name, dns.rdatatype.to_text(qmsg.question[0].rdtype) + ) def convert_queries( - query_iterator: Iterator[Tuple[QID, WireFormat]], - qwire_to_text_func: Callable[[WireFormat], str] = qwire_to_qname_qtype - ) -> Counter: + query_iterator: Iterator[Tuple[QID, WireFormat]], + qwire_to_text_func: Callable[[WireFormat], str] = qwire_to_qname_qtype, +) -> Counter: qcounter = Counter() # type: Counter for qid, qwire in query_iterator: try: text = qwire_to_text_func(qwire) except ValueError as exc: - logging.debug('Omitting QID %d: %s', qid, exc) + logging.debug("Omitting QID %d: %s", qid, exc) else: qcounter[text] += 1 return qcounter def get_printable_queries_format( - queries_mismatch: Counter, - queries_all: Counter = None, # all queries (needed for comparison with ref) - ref_queries_mismatch: Counter = None, # ref queries for the same mismatch - ref_queries_all: Counter = None # ref queries from all mismatches - ) -> Sequence[Tuple[str, int, str]]: + queries_mismatch: Counter, + queries_all: Counter = None, # all queries (needed for comparison with ref) + ref_queries_mismatch: Counter = None, # ref queries for the same mismatch + ref_queries_all: Counter = None, # ref queries from all mismatches +) -> Sequence[Tuple[str, int, str]]: def get_query_diff(query: str) -> str: - if (ref_queries_mismatch is None - or ref_queries_all is None - or queries_all is None): - return ' ' # no reference to compare to + if ( + ref_queries_mismatch is None + or ref_queries_all is None + or queries_all is None + ): + return " " # no reference to compare to if query in 
queries_mismatch and query not in ref_queries_all: - return '+' # previously unseen query has appeared + return "+" # previously unseen query has appeared if query in ref_queries_mismatch and query not in queries_all: - return '-' # query no longer appears in any mismatch category - return ' ' # no change, or query has moved to a different mismatch category + return "-" # query no longer appears in any mismatch category + return " " # no change, or query has moved to a different mismatch category query_set = set(queries_mismatch.keys()) if ref_queries_mismatch is not None: @@ -101,9 +99,9 @@ def get_printable_queries_format( for query in query_set: diff = get_query_diff(query) count = queries_mismatch[query] - if diff == ' ' and count == 0: + if diff == " " and count == 0: continue # omit queries that just moved between categories - if diff == '-': + if diff == "-": assert ref_queries_mismatch is not None count = ref_queries_mismatch[query] # show how many cases were removed queries.append((diff, count, query)) diff --git a/respdiff/repro.py b/respdiff/repro.py index e9af030..8828887 100644 --- a/respdiff/repro.py +++ b/respdiff/repro.py @@ -4,11 +4,27 @@ from multiprocessing import pool import random import subprocess from typing import ( # noqa - AbstractSet, Any, Iterator, Iterable, Mapping, Optional, Sequence, Tuple, - TypeVar, Union) + AbstractSet, + Any, + Iterator, + Iterable, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) from .database import ( - DNSRepliesFactory, DNSReply, key2qid, ResolverID, qid2key, QKey, WireFormat) + DNSRepliesFactory, + DNSReply, + key2qid, + ResolverID, + qid2key, + QKey, + WireFormat, +) from .dataformat import Diff, DiffReport, FieldLabel from .match import compare from .query import get_query_iterator @@ -16,25 +32,25 @@ from .sendrecv import worker_perform_single_query from .typing import QID # noqa -T = TypeVar('T') +T = TypeVar("T") def restart_resolver(script_path: str) -> None: try: subprocess.check_call(script_path) except subprocess.CalledProcessError as exc: - logging.warning('Resolver restart failed (exit code %d): %s', - exc.returncode, script_path) + logging.warning( + "Resolver restart failed (exit code %d): %s", exc.returncode, script_path + ) except PermissionError: - logging.warning('Resolver restart failed (permission error): %s', - script_path) + logging.warning("Resolver restart failed (permission error): %s", script_path) def get_restart_scripts(config: Mapping[str, Any]) -> Mapping[ResolverID, str]: restart_scripts = {} - for resolver in config['servers']['names']: + for resolver in config["servers"]["names"]: try: - restart_scripts[resolver] = config[resolver]['restart_script'] + restart_scripts[resolver] = config[resolver]["restart_script"] except KeyError: logging.warning('No restart script available for "%s"!', resolver) return restart_scripts @@ -51,12 +67,12 @@ def chunker(iterable: Iterable[T], size: int) -> Iterator[Iterable[T]]: def process_answers( - qkey: QKey, - answers: Mapping[ResolverID, DNSReply], - report: DiffReport, - criteria: Sequence[FieldLabel], - target: ResolverID - ) -> None: + qkey: QKey, + answers: Mapping[ResolverID, DNSReply], + report: DiffReport, + criteria: Sequence[FieldLabel], + target: ResolverID, +) -> None: if report.target_disagreements is None or report.reprodata is None: raise RuntimeError("Report doesn't contain necessary data!") qid = key2qid(qkey) @@ -75,15 +91,17 @@ def process_answers( def query_stream_from_disagreements( - lmdb, - report: DiffReport, - skip_unstable: 
bool = True, - skip_non_reproducible: bool = True, - shuffle: bool = True - ) -> Iterator[Tuple[QKey, WireFormat]]: + lmdb, + report: DiffReport, + skip_unstable: bool = True, + skip_non_reproducible: bool = True, + shuffle: bool = True, +) -> Iterator[Tuple[QKey, WireFormat]]: if report.target_disagreements is None or report.reprodata is None: raise RuntimeError("Report doesn't contain necessary data!") - qids = report.target_disagreements.keys() # type: Union[Sequence[QID], AbstractSet[QID]] + qids = ( + report.target_disagreements.keys() + ) # type: Union[Sequence[QID], AbstractSet[QID]] if shuffle: # create a new, randomized list from disagreements qids = list(qids) @@ -94,23 +112,23 @@ def query_stream_from_disagreements( reprocounter = report.reprodata[qid] # verify if answers are stable if skip_unstable and reprocounter.retries != reprocounter.upstream_stable: - logging.debug('Skipping QID %7d: unstable upstream', diff.qid) + logging.debug("Skipping QID %7d: unstable upstream", diff.qid) continue if skip_non_reproducible and reprocounter.retries != reprocounter.verified: - logging.debug('Skipping QID %7d: not 100 %% reproducible', diff.qid) + logging.debug("Skipping QID %7d: not 100 %% reproducible", diff.qid) continue yield qid2key(qid), qwire def reproduce_queries( - query_stream: Iterator[Tuple[QKey, WireFormat]], - report: DiffReport, - dnsreplies_factory: DNSRepliesFactory, - criteria: Sequence[FieldLabel], - target: ResolverID, - restart_scripts: Optional[Mapping[ResolverID, str]] = None, - nproc: int = 1 - ) -> None: + query_stream: Iterator[Tuple[QKey, WireFormat]], + report: DiffReport, + dnsreplies_factory: DNSRepliesFactory, + criteria: Sequence[FieldLabel], + target: ResolverID, + restart_scripts: Optional[Mapping[ResolverID, str]] = None, + nproc: int = 1, +) -> None: if restart_scripts is None: restart_scripts = {} with pool.Pool(processes=nproc) as p: @@ -121,12 +139,14 @@ def reproduce_queries( restart_resolver(script) process_args = [args for args in process_args if args is not None] - for qkey, replies_data, in p.imap_unordered( - worker_perform_single_query, - process_args, - chunksize=1): + for ( + qkey, + replies_data, + ) in p.imap_unordered( + worker_perform_single_query, process_args, chunksize=1 + ): replies = dnsreplies_factory.parse(replies_data) process_answers(qkey, replies, report, criteria, target) done += len(process_args) - logging.debug('Processed {:4d} queries'.format(done)) + logging.debug("Processed {:4d} queries".format(done)) diff --git a/respdiff/sendrecv.py b/respdiff/sendrecv.py index 862f038..ba30f5a 100644 --- a/respdiff/sendrecv.py +++ b/respdiff/sendrecv.py @@ -8,7 +8,6 @@ The entire module keeps a global state, which enables its easy use with both threads or processes. Make sure not to break this compatibility. 
""" - from argparse import Namespace import logging import random @@ -19,7 +18,15 @@ import ssl import struct import time import threading -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple # noqa: type hints +from typing import ( + Any, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, +) # noqa: type hints import dns.inet import dns.message @@ -41,7 +48,9 @@ ResolverSockets = Sequence[Tuple[ResolverID, Socket, IsStreamFlag]] # module-wide state variables __resolvers = [] # type: Sequence[Tuple[ResolverID, IP, Protocol, Port]] __worker_state = threading.local() -__max_timeouts = 10 # crash when N consecutive timeouts are received from a single resolver +__max_timeouts = ( + 10 # crash when N consecutive timeouts are received from a single resolver +) __ignore_timeout = False __timeout = 16 __time_delay_min = 0 @@ -76,11 +85,11 @@ def module_init(args: Namespace) -> None: global __dnsreplies_factory __resolvers = get_resolvers(args.cfg) - __timeout = args.cfg['sendrecv']['timeout'] - __time_delay_min = args.cfg['sendrecv']['time_delay_min'] - __time_delay_max = args.cfg['sendrecv']['time_delay_max'] + __timeout = args.cfg["sendrecv"]["timeout"] + __time_delay_min = args.cfg["sendrecv"]["time_delay_min"] + __time_delay_max = args.cfg["sendrecv"]["time_delay_max"] try: - __max_timeouts = args.cfg['sendrecv']['max_timeouts'] + __max_timeouts = args.cfg["sendrecv"]["max_timeouts"] except KeyError: pass try: @@ -135,7 +144,9 @@ def worker_perform_query(args: Tuple[QKey, WireFormat]) -> Tuple[QKey, RepliesBl return qkey, blob -def worker_perform_single_query(args: Tuple[QKey, WireFormat]) -> Tuple[QKey, RepliesBlob]: +def worker_perform_single_query( + args: Tuple[QKey, WireFormat] +) -> Tuple[QKey, RepliesBlob]: """Perform a single DNS query with setup and teardown of sockets. Used by diffrepro.""" qkey, qwire = args worker_reinit() @@ -150,12 +161,12 @@ def worker_perform_single_query(args: Tuple[QKey, WireFormat]) -> Tuple[QKey, Re def get_resolvers( - config: Mapping[str, Any] - ) -> Sequence[Tuple[ResolverID, IP, Protocol, Port]]: + config: Mapping[str, Any] +) -> Sequence[Tuple[ResolverID, IP, Protocol, Port]]: resolvers = [] - for resname in config['servers']['names']: + for resname in config["servers"]["names"]: rescfg = config[resname] - resolvers.append((resname, rescfg['ip'], rescfg['transport'], rescfg['port'])) + resolvers.append((resname, rescfg["ip"], rescfg["transport"], rescfg["port"])) return resolvers @@ -170,7 +181,9 @@ def _check_timeout(replies: Mapping[ResolverID, DNSReply]) -> None: raise RuntimeError( "Resolver '{}' timed-out {:d} times in a row. " "Use '--ignore-timeout' to supress this error.".format( - resolver, __max_timeouts)) + resolver, __max_timeouts + ) + ) def make_ssl_context(): @@ -179,19 +192,19 @@ def make_ssl_context(): # https://docs.python.org/3/library/ssl.html#tls-1-3 # NOTE forcing TLS v1.2 is hacky, because of different py3/openssl versions... 
- if getattr(ssl, 'PROTOCOL_TLS', None) is not None: + if getattr(ssl, "PROTOCOL_TLS", None) is not None: context = ssl.SSLContext(ssl.PROTOCOL_TLS) # pylint: disable=no-member else: context = ssl.SSLContext() - if getattr(ssl, 'maximum_version', None) is not None: + if getattr(ssl, "maximum_version", None) is not None: context.maximum_version = ssl.TLSVersion.TLSv1_2 # pylint: disable=no-member else: context.options |= ssl.OP_NO_SSLv2 context.options |= ssl.OP_NO_SSLv3 context.options |= ssl.OP_NO_TLSv1 context.options |= ssl.OP_NO_TLSv1_1 - if getattr(ssl, 'OP_NO_TLSv1_3', None) is not None: + if getattr(ssl, "OP_NO_TLSv1_3", None) is not None: context.options |= ssl.OP_NO_TLSv1_3 # pylint: disable=no-member # turn off certificate verification @@ -201,7 +214,9 @@ def make_ssl_context(): return context -def sock_init(retry: int = 3) -> Tuple[Selector, Sequence[Tuple[ResolverID, Socket, IsStreamFlag]]]: +def sock_init( + retry: int = 3, +) -> Tuple[Selector, Sequence[Tuple[ResolverID, Socket, IsStreamFlag]]]: sockets = [] selector = selectors.DefaultSelector() for name, ipaddr, transport, port in __resolvers: @@ -211,22 +226,22 @@ def sock_init(retry: int = 3) -> Tuple[Selector, Sequence[Tuple[ResolverID, Sock elif af == dns.inet.AF_INET6: destination = (ipaddr, port, 0, 0) else: - raise NotImplementedError('AF') + raise NotImplementedError("AF") - if transport in {'tcp', 'tls'}: + if transport in {"tcp", "tls"}: socktype = socket.SOCK_STREAM isstream = True - elif transport == 'udp': + elif transport == "udp": socktype = socket.SOCK_DGRAM isstream = False else: - raise NotImplementedError('transport: {}'.format(transport)) + raise NotImplementedError("transport: {}".format(transport)) # attempt to connect to socket attempt = 1 while attempt <= retry: sock = socket.socket(af, socktype, 0) - if transport == 'tls': + if transport == "tls": ctx = make_ssl_context() sock = ctx.wrap_socket(sock) try: @@ -234,7 +249,9 @@ def sock_init(retry: int = 3) -> Tuple[Selector, Sequence[Tuple[ResolverID, Sock except ConnectionRefusedError as e: # TCP socket is closed raise RuntimeError( "socket: Failed to connect to {dest[0]} port {dest[1]}".format( - dest=destination)) from e + dest=destination + ) + ) from e except OSError as exc: if exc.errno != 0 and not isinstance(exc, ConnectionResetError): raise @@ -245,7 +262,9 @@ def sock_init(retry: int = 3) -> Tuple[Selector, Sequence[Tuple[ResolverID, Sock if attempt > retry: raise RuntimeError( "socket: Failed to connect to {dest[0]} port {dest[1]}".format( - dest=destination)) from exc + dest=destination + ) + ) from exc else: break sock.setblocking(False) @@ -261,11 +280,11 @@ def _recv_msg(sock: Socket, isstream: IsStreamFlag) -> WireFormat: try: blength = sock.recv(2) except ssl.SSLWantReadError as e: - raise TcpDnsLengthError('failed to recv DNS packet length') from e + raise TcpDnsLengthError("failed to recv DNS packet length") from e else: if len(blength) != 2: # FIN / RST - raise TcpDnsLengthError('failed to recv DNS packet length') - (length, ) = struct.unpack('!H', blength) + raise TcpDnsLengthError("failed to recv DNS packet length") + (length,) = struct.unpack("!H", blength) else: length = 65535 # max. 
UDP message size, no IPv6 jumbograms return sock.recv(length) @@ -274,11 +293,13 @@ def _recv_msg(sock: Socket, isstream: IsStreamFlag) -> WireFormat: def _create_sendbuf(dnsdata: WireFormat, isstream: IsStreamFlag) -> bytes: if isstream: # prepend length, RFC 1035 section 4.2.2 length = len(dnsdata) - return struct.pack('!H', length) + dnsdata + return struct.pack("!H", length) + dnsdata return dnsdata -def _get_resolver_from_sock(sockets: ResolverSockets, sock: Socket) -> Optional[ResolverID]: +def _get_resolver_from_sock( + sockets: ResolverSockets, sock: Socket +) -> Optional[ResolverID]: for resolver, resolver_sock, _ in sockets: if sock == resolver_sock: return resolver @@ -286,11 +307,8 @@ def _get_resolver_from_sock(sockets: ResolverSockets, sock: Socket) -> Optional[ def _recv_from_resolvers( - selector: Selector, - sockets: ResolverSockets, - msgid: bytes, - timeout: float - ) -> Tuple[Dict[ResolverID, DNSReply], bool]: + selector: Selector, sockets: ResolverSockets, msgid: bytes, timeout: float +) -> Tuple[Dict[ResolverID, DNSReply], bool]: def raise_resolver_exc(sock, exc): resolver = _get_resolver_from_sock(sockets, sock) @@ -298,7 +316,6 @@ def _recv_from_resolvers( raise ResolverConnectionError(resolver, str(exc)) from exc raise exc - start_time = time.perf_counter() end_time = start_time + timeout replies = {} # type: Dict[ResolverID, DNSReply] @@ -332,11 +349,11 @@ def _recv_from_resolvers( def _send_recv_parallel( - dgram: WireFormat, # DNS message suitable for UDP transport - selector: Selector, - sockets: ResolverSockets, - timeout: float - ) -> Tuple[Mapping[ResolverID, DNSReply], ReinitFlag]: + dgram: WireFormat, # DNS message suitable for UDP transport + selector: Selector, + sockets: ResolverSockets, + timeout: float, +) -> Tuple[Mapping[ResolverID, DNSReply], ReinitFlag]: # send queries for resolver, sock, isstream in sockets: sendbuf = _create_sendbuf(dgram, isstream) @@ -358,10 +375,10 @@ def _send_recv_parallel( def send_recv_parallel( - dgram: WireFormat, # DNS message suitable for UDP transport - timeout: float, - reinit_on_tcpfin: bool = True - ) -> Mapping[ResolverID, DNSReply]: + dgram: WireFormat, # DNS message suitable for UDP transport + timeout: float, + reinit_on_tcpfin: bool = True, +) -> Mapping[ResolverID, DNSReply]: problematic = [] for _ in range(CONN_RESET_RETRIES + 1): try: # get sockets and selector @@ -389,5 +406,7 @@ def send_recv_parallel( worker_deinit() # re-establish connection worker_reinit() raise RuntimeError( - 'ConnectionError received {} times in a row ({}), exiting!'.format( - CONN_RESET_RETRIES + 1, ', '.join(problematic))) + "ConnectionError received {} times in a row ({}), exiting!".format( + CONN_RESET_RETRIES + 1, ", ".join(problematic) + ) + ) diff --git a/respdiff/stats.py b/respdiff/stats.py index f3820fb..4a2ceb1 100644 --- a/respdiff/stats.py +++ b/respdiff/stats.py @@ -23,24 +23,26 @@ class Stats(JSONDataObject): Example: samples = [1540, 1613, 1489] """ + _ATTRIBUTES = { - 'samples': (None, None), - 'threshold': (None, None), + "samples": (None, None), + "threshold": (None, None), } class SamplePosition(Enum): """Position of a single sample against the rest of the distribution.""" + ABOVE_REF = 1 ABOVE_THRESHOLD = 2 NORMAL = 3 BELOW_REF = 4 def __init__( - self, - samples: Sequence[float] = None, - threshold: Optional[float] = None, - _restore_dict: Optional[Mapping[str, float]] = None - ) -> None: + self, + samples: Sequence[float] = None, + threshold: Optional[float] = None, + _restore_dict: Optional[Mapping[str, 
float]] = None, + ) -> None: """ samples contain the entire data set of reference values of this parameter. If no custom threshold is provided, it is calculated automagically. @@ -73,9 +75,9 @@ class Stats(JSONDataObject): return max(self.samples) def get_percentile_rank(self, sample: float) -> float: - return scipy.stats.percentileofscore(self.samples, sample, kind='weak') + return scipy.stats.percentileofscore(self.samples, sample, kind="weak") - def evaluate_sample(self, sample: float) -> 'Stats.SamplePosition': + def evaluate_sample(self, sample: float) -> "Stats.SamplePosition": if sample < self.min: return Stats.SamplePosition.BELOW_REF elif sample > self.max: @@ -107,17 +109,15 @@ class MismatchStatistics(dict, JSONDataObject): """ _ATTRIBUTES = { - 'total': ( - lambda x: Stats(_restore_dict=x), - lambda x: x.save()), + "total": (lambda x: Stats(_restore_dict=x), lambda x: x.save()), } def __init__( - self, - mismatch_counters_list: Optional[Sequence[Counter]] = None, - sample_size: Optional[int] = None, - _restore_dict: Optional[Mapping[str, Any]] = None - ) -> None: + self, + mismatch_counters_list: Optional[Sequence[Counter]] = None, + sample_size: Optional[int] = None, + _restore_dict: Optional[Mapping[str, Any]] = None, + ) -> None: super().__init__() self.total = None if mismatch_counters_list is not None and sample_size is not None: @@ -128,15 +128,15 @@ class MismatchStatistics(dict, JSONDataObject): n += count mismatch_key = str(mismatch.key) samples[mismatch_key].append(count) - samples['total'].append(n) + samples["total"].append(n) # fill in missing samples for seq in samples.values(): seq.extend([0] * (sample_size - len(seq))) # create stats from samples - self.total = Stats(samples['total']) - del samples['total'] + self.total = Stats(samples["total"]) + del samples["total"] for mismatch_key, stats_seq in samples.items(): self[mismatch_key] = Stats(stats_seq) elif _restore_dict is not None: @@ -166,17 +166,18 @@ class FieldStatistics(dict, JSONDataObject): """ def __init__( - self, - summaries_list: Optional[Sequence[Summary]] = None, - _restore_dict: Optional[Mapping[str, Any]] = None - ) -> None: + self, + summaries_list: Optional[Sequence[Summary]] = None, + _restore_dict: Optional[Mapping[str, Any]] = None, + ) -> None: super().__init__() if summaries_list is not None: field_counters_list = [d.get_field_counters() for d in summaries_list] for field in ALL_FIELDS: mismatch_counters_list = [fc[field] for fc in field_counters_list] self[field] = MismatchStatistics( - mismatch_counters_list, len(summaries_list)) + mismatch_counters_list, len(summaries_list) + ) elif _restore_dict is not None: self.restore(_restore_dict) @@ -194,32 +195,20 @@ class FieldStatistics(dict, JSONDataObject): class SummaryStatistics(JSONDataObject): _ATTRIBUTES = { - 'sample_size': (None, None), - 'upstream_unstable': ( - lambda x: Stats(_restore_dict=x), - lambda x: x.save()), - 'usable_answers': ( - lambda x: Stats(_restore_dict=x), - lambda x: x.save()), - 'not_reproducible': ( - lambda x: Stats(_restore_dict=x), - lambda x: x.save()), - 'target_disagreements': ( - lambda x: Stats(_restore_dict=x), - lambda x: x.save()), - 'fields': ( - lambda x: FieldStatistics(_restore_dict=x), - lambda x: x.save()), - 'queries': ( - lambda x: QueryStatistics(_restore_dict=x), - lambda x: x.save()), + "sample_size": (None, None), + "upstream_unstable": (lambda x: Stats(_restore_dict=x), lambda x: x.save()), + "usable_answers": (lambda x: Stats(_restore_dict=x), lambda x: x.save()), + "not_reproducible": 
(lambda x: Stats(_restore_dict=x), lambda x: x.save()), + "target_disagreements": (lambda x: Stats(_restore_dict=x), lambda x: x.save()), + "fields": (lambda x: FieldStatistics(_restore_dict=x), lambda x: x.save()), + "queries": (lambda x: QueryStatistics(_restore_dict=x), lambda x: x.save()), } def __init__( - self, - reports: Sequence[DiffReport] = None, - _restore_dict: Mapping[str, Any] = None - ) -> None: + self, + reports: Sequence[DiffReport] = None, + _restore_dict: Mapping[str, Any] = None, + ) -> None: super().__init__() self.sample_size = None self.upstream_unstable = None @@ -233,15 +222,18 @@ class SummaryStatistics(JSONDataObject): usable_reports = [] for report in reports: if report.summary is None: - logging.warning('Empty diffsum in %s Omitting...', report.fileorigin) + logging.warning( + "Empty diffsum in %s Omitting...", report.fileorigin + ) else: usable_reports.append(report) summaries = [ - report.summary for report in reports if report.summary is not None] + report.summary for report in reports if report.summary is not None + ] assert len(summaries) == len(usable_reports) if not summaries: - raise ValueError('No summaries found in reports!') + raise ValueError("No summaries found in reports!") self.sample_size = len(summaries) self.upstream_unstable = Stats([s.upstream_unstable for s in summaries]) diff --git a/statcmp.py b/statcmp.py index c82adf3..3039a68 100755 --- a/statcmp.py +++ b/statcmp.py @@ -14,16 +14,17 @@ from respdiff.typing import FieldLabel import matplotlib import matplotlib.axes import matplotlib.ticker -matplotlib.use('Agg') + +matplotlib.use("Agg") import matplotlib.pyplot as plt # noqa -COLOR_OK = 'tab:blue' -COLOR_GOOD = 'tab:green' -COLOR_BAD = 'xkcd:bright red' -COLOR_THRESHOLD = 'tab:orange' -COLOR_BG = 'tab:gray' -COLOR_LABEL = 'black' +COLOR_OK = "tab:blue" +COLOR_GOOD = "tab:green" +COLOR_BAD = "xkcd:bright red" +COLOR_THRESHOLD = "tab:orange" +COLOR_BG = "tab:gray" +COLOR_LABEL = "black" VIOLIN_FIGSIZE = (3, 6) @@ -36,7 +37,9 @@ SAMPLE_COLORS = { class AxisMarker: - def __init__(self, position: float, width: float = 0.7, color: str = COLOR_BG) -> None: + def __init__( + self, position: float, width: float = 0.7, color: str = COLOR_BG + ) -> None: self.position = position self.width = width self.color = color @@ -48,19 +51,20 @@ class AxisMarker: def plot_violin( - ax: matplotlib.axes.Axes, - violin_data: Sequence[float], - markers: Sequence[AxisMarker], - label: str, - color: str = COLOR_LABEL - ) -> None: - ax.set_title(label, fontdict={'fontsize': 14}, color=color) + ax: matplotlib.axes.Axes, + violin_data: Sequence[float], + markers: Sequence[AxisMarker], + label: str, + color: str = COLOR_LABEL, +) -> None: + ax.set_title(label, fontdict={"fontsize": 14}, color=color) # plot violin graph - violin_parts = ax.violinplot(violin_data, bw_method=0.07, - showmedians=False, showextrema=False) + violin_parts = ax.violinplot( + violin_data, bw_method=0.07, showmedians=False, showextrema=False + ) # set violin background color - for pc in violin_parts['bodies']: + for pc in violin_parts["bodies"]: pc.set_facecolor(COLOR_BG) pc.set_edgecolor(COLOR_BG) @@ -69,10 +73,10 @@ def plot_violin( marker.draw(ax) # turn off axis spines - for sp in ['right', 'top', 'bottom']: - ax.spines[sp].set_color('none') + for sp in ["right", "top", "bottom"]: + ax.spines[sp].set_color("none") # move the left ax spine to center - ax.spines['left'].set_position(('data', 1)) + ax.spines["left"].set_position(("data", 1)) # customize axis ticks 
ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator()) @@ -80,8 +84,11 @@ def plot_violin( if max(violin_data) == 0: # fix tick at 0 when there's no data ax.yaxis.set_major_locator(matplotlib.ticker.FixedLocator([0])) else: - ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator( - nbins='auto', steps=[1, 2, 4, 5, 10], integer=True)) + ax.yaxis.set_major_locator( + matplotlib.ticker.MaxNLocator( + nbins="auto", steps=[1, 2, 4, 5, 10], integer=True + ) + ) ax.yaxis.set_minor_locator(matplotlib.ticker.NullLocator()) ax.tick_params(labelsize=14) @@ -99,34 +106,41 @@ def _axes_iter(axes, width: int): def eval_and_plot_single( - ax: matplotlib.axes.Axes, - stats: Stats, - label: str, - samples: Sequence[float] - ) -> bool: + ax: matplotlib.axes.Axes, stats: Stats, label: str, samples: Sequence[float] +) -> bool: markers = [] below_min = False above_thr = False for sample in samples: result = stats.evaluate_sample(sample) markers.append(AxisMarker(sample, color=SAMPLE_COLORS[result])) - if result in (Stats.SamplePosition.ABOVE_REF, - Stats.SamplePosition.ABOVE_THRESHOLD): + if result in ( + Stats.SamplePosition.ABOVE_REF, + Stats.SamplePosition.ABOVE_THRESHOLD, + ): above_thr = True logging.error( - ' %s: threshold exceeded! sample: %d / %4.2f%% vs threshold: %d / %4.2f%%', - label, sample, stats.get_percentile_rank(sample), - stats.threshold, stats.get_percentile_rank(stats.threshold)) + " %s: threshold exceeded! sample: %d / %4.2f%% vs threshold: %d / %4.2f%%", + label, + sample, + stats.get_percentile_rank(sample), + stats.threshold, + stats.get_percentile_rank(stats.threshold), + ) elif result == Stats.SamplePosition.BELOW_REF: below_min = True logging.info( - ' %s: new minimum found! new: %d vs prev: %d', - label, sample, stats.min) + " %s: new minimum found! new: %d vs prev: %d", label, sample, stats.min + ) else: logging.info( - ' %s: ok! sample: %d / %4.2f%% vs threshold: %d / %4.2f%%', - label, sample, stats.get_percentile_rank(sample), - stats.threshold, stats.get_percentile_rank(stats.threshold)) + " %s: ok! sample: %d / %4.2f%% vs threshold: %d / %4.2f%%", + label, + sample, + stats.get_percentile_rank(sample), + stats.threshold, + stats.get_percentile_rank(stats.threshold), + ) # add min/med/max markers markers.append(AxisMarker(stats.min, 0.5, COLOR_BG)) @@ -148,11 +162,11 @@ def eval_and_plot_single( def plot_overview( - sumstats: SummaryStatistics, - fields: Sequence[FieldLabel], - summaries: Optional[Sequence[Summary]] = None, - label: str = 'fields_overview' - ) -> bool: + sumstats: SummaryStatistics, + fields: Sequence[FieldLabel], + summaries: Optional[Sequence[Summary]] = None, + label: str = "fields_overview", +) -> bool: """ Plot an overview of all fields using violing graphs. If any summaries are provided, they are drawn in the graphs and also evaluated. 
If any sample in any field exceeds @@ -169,26 +183,33 @@ def plot_overview( fig, axes = plt.subplots( OVERVIEW_Y_FIG, OVERVIEW_X_FIG, - figsize=(OVERVIEW_X_FIG*VIOLIN_FIGSIZE[0], OVERVIEW_Y_FIG*VIOLIN_FIGSIZE[1])) + figsize=( + OVERVIEW_X_FIG * VIOLIN_FIGSIZE[0], + OVERVIEW_Y_FIG * VIOLIN_FIGSIZE[1], + ), + ) ax_it = _axes_iter(axes, OVERVIEW_X_FIG) # target disagreements assert sumstats.target_disagreements is not None samples = [len(summary) for summary in summaries] passed &= eval_and_plot_single( - next(ax_it), sumstats.target_disagreements, 'target_disagreements', samples) + next(ax_it), sumstats.target_disagreements, "target_disagreements", samples + ) # upstream unstable assert sumstats.upstream_unstable is not None samples = [summary.upstream_unstable for summary in summaries] passed &= eval_and_plot_single( - next(ax_it), sumstats.upstream_unstable, 'upstream_unstable', samples) + next(ax_it), sumstats.upstream_unstable, "upstream_unstable", samples + ) # not 100% reproducible assert sumstats.not_reproducible is not None samples = [summary.not_reproducible for summary in summaries] passed &= eval_and_plot_single( - next(ax_it), sumstats.not_reproducible, 'not_reproducible', samples) + next(ax_it), sumstats.not_reproducible, "not_reproducible", samples + ) # fields assert sumstats.fields is not None @@ -201,7 +222,8 @@ def plot_overview( next(ax_it), sumstats.fields[field].total, field, - [len(list(fc[field].elements())) for fc in fcs]) + [len(list(fc[field].elements())) for fc in fcs], + ) # hide unused axis for ax in ax_it: @@ -209,15 +231,21 @@ def plot_overview( # display sample size fig.text( - 0.95, 0.95, - 'stat sample size: {}'.format(len(sumstats.target_disagreements.samples)), - fontsize=18, color=COLOR_BG, ha='right', va='bottom', alpha=0.7) + 0.95, + 0.95, + "stat sample size: {}".format(len(sumstats.target_disagreements.samples)), + fontsize=18, + color=COLOR_BG, + ha="right", + va="bottom", + alpha=0.7, + ) # save image plt.tight_layout() plt.subplots_adjust(top=0.9) fig.suptitle(label, fontsize=22) - plt.savefig('{}.png'.format(label)) + plt.savefig("{}.png".format(label)) plt.close() return passed @@ -226,25 +254,32 @@ def plot_overview( def main(): cli.setup_logging() parser = argparse.ArgumentParser( - description=("Plot and compare reports against statistical data. " - "Returns non-zero exit code if any threshold is exceeded.")) + description=( + "Plot and compare reports against statistical data. " + "Returns non-zero exit code if any threshold is exceeded." + ) + ) cli.add_arg_stats(parser) cli.add_arg_report(parser) cli.add_arg_config(parser) - parser.add_argument('-l', '--label', default='fields_overview', - help='Set plot label. It is also used for the filename.') + parser.add_argument( + "-l", + "--label", + default="fields_overview", + help="Set plot label. 
It is also used for the filename.", + ) args = parser.parse_args() sumstats = args.stats - field_weights = args.cfg['report']['field_weights'] + field_weights = args.cfg["report"]["field_weights"] try: summaries = cli.load_summaries(args.report) except ValueError: sys.exit(1) - logging.info('Start Comparison: %s', args.label) + logging.info("Start Comparison: %s", args.label) passed = plot_overview(sumstats, field_weights, summaries, args.label) if not passed: @@ -253,5 +288,5 @@ def main(): sys.exit(0) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/sumcmp.py b/sumcmp.py index f73e003..d8d1b48 100755 --- a/sumcmp.py +++ b/sumcmp.py @@ -10,7 +10,10 @@ from respdiff import cli from respdiff.database import LMDB from respdiff.dataformat import DiffReport from respdiff.query import ( - convert_queries, get_printable_queries_format, get_query_iterator) + convert_queries, + get_printable_queries_format, + get_query_iterator, +) ANSWERS_DIFFERENCE_THRESHOLD_WARNING = 0.05 @@ -25,27 +28,33 @@ def check_report_summary(report: DiffReport): def check_usable_answers(report: DiffReport, ref_report: DiffReport): if report.summary is None or ref_report.summary is None: raise RuntimeError("Report doesn't contain necessary data!") - answers_difference = math.fabs( - report.summary.usable_answers - ref_report.summary.usable_answers - ) / ref_report.summary.usable_answers + answers_difference = ( + math.fabs(report.summary.usable_answers - ref_report.summary.usable_answers) + / ref_report.summary.usable_answers + ) if answers_difference >= ANSWERS_DIFFERENCE_THRESHOLD_WARNING: - logging.warning('Number of usable answers changed by {:.1f} %!'.format( - answers_difference * 100.0)) + logging.warning( + "Number of usable answers changed by {:.1f} %!".format( + answers_difference * 100.0 + ) + ) def main(): cli.setup_logging() - parser = argparse.ArgumentParser(description='compare two diff summaries') + parser = argparse.ArgumentParser(description="compare two diff summaries") cli.add_arg_config(parser) - parser.add_argument('old_datafile', type=str, help='report to compare against') - parser.add_argument('new_datafile', type=str, help='report to compare evaluate') - cli.add_arg_envdir(parser) # TODO remove when we no longer need to read queries from lmdb + parser.add_argument("old_datafile", type=str, help="report to compare against") + parser.add_argument("new_datafile", type=str, help="report to compare evaluate") + cli.add_arg_envdir( + parser + ) # TODO remove when we no longer need to read queries from lmdb cli.add_arg_limit(parser) args = parser.parse_args() - report = DiffReport.from_json(cli.get_datafile(args, key='new_datafile')) - field_weights = args.cfg['report']['field_weights'] - ref_report = DiffReport.from_json(cli.get_datafile(args, key='old_datafile')) + report = DiffReport.from_json(cli.get_datafile(args, key="new_datafile")) + field_weights = args.cfg["report"]["field_weights"] + ref_report = DiffReport.from_json(cli.get_datafile(args, key="old_datafile")) check_report_summary(report) check_report_summary(ref_report) @@ -63,7 +72,9 @@ def main(): if field not in field_counters: field_counters[field] = Counter() - cli.print_fields_overview(field_counters, len(report.summary), ref_field_counters) + cli.print_fields_overview( + field_counters, len(report.summary), ref_field_counters + ) for field in field_weights: if field in field_counters: @@ -76,22 +87,27 @@ def main(): counter[mismatch] = 0 cli.print_field_mismatch_stats( - field, counter, len(report.summary), 
ref_counter) + field, counter, len(report.summary), ref_counter + ) # query details with LMDB(args.envdir, readonly=True) as lmdb: lmdb.open_db(LMDB.QUERIES) queries_all = convert_queries( - get_query_iterator(lmdb, report.summary.keys())) + get_query_iterator(lmdb, report.summary.keys()) + ) ref_queries_all = convert_queries( - get_query_iterator(lmdb, ref_report.summary.keys())) + get_query_iterator(lmdb, ref_report.summary.keys()) + ) for field in field_weights: if field in field_counters: # ensure "disappeared" mismatches are shown field_mismatches = dict(report.summary.get_field_mismatches(field)) - ref_field_mismatches = dict(ref_report.summary.get_field_mismatches(field)) + ref_field_mismatches = dict( + ref_report.summary.get_field_mismatches(field) + ) mismatches = set(field_mismatches.keys()) mismatches.update(ref_field_mismatches.keys()) @@ -99,17 +115,19 @@ def main(): qids = field_mismatches.get(mismatch, set()) queries = convert_queries(get_query_iterator(lmdb, qids)) ref_queries = convert_queries( - get_query_iterator(lmdb, ref_field_mismatches.get(mismatch, set()))) + get_query_iterator( + lmdb, ref_field_mismatches.get(mismatch, set()) + ) + ) cli.print_mismatch_queries( field, mismatch, get_printable_queries_format( - queries, - queries_all, - ref_queries, - ref_queries_all), - args.limit) + queries, queries_all, ref_queries, ref_queries_all + ), + args.limit, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/sumstat.py b/sumstat.py index d17fe9c..5e30d47 100755 --- a/sumstat.py +++ b/sumstat.py @@ -10,12 +10,14 @@ from respdiff.stats import SummaryStatistics def _log_threshold(stats, label): percentile_rank = stats.get_percentile_rank(stats.threshold) - logging.info(' %s: %4.2f percentile rank', label, percentile_rank) + logging.info(" %s: %4.2f percentile rank", label, percentile_rank) def main(): cli.setup_logging() - parser = argparse.ArgumentParser(description='generate statistics file from reports') + parser = argparse.ArgumentParser( + description="generate statistics file from reports" + ) cli.add_arg_report_filename(parser) cli.add_arg_stats_filename(parser) @@ -28,16 +30,16 @@ def main(): logging.critical(exc) sys.exit(1) - logging.info('Total sample size: %d', sumstats.sample_size) - logging.info('Upper boundaries:') - _log_threshold(sumstats.target_disagreements, 'target_disagreements') - _log_threshold(sumstats.upstream_unstable, 'upstream_unstable') - _log_threshold(sumstats.not_reproducible, 'not_reproducible') + logging.info("Total sample size: %d", sumstats.sample_size) + logging.info("Upper boundaries:") + _log_threshold(sumstats.target_disagreements, "target_disagreements") + _log_threshold(sumstats.upstream_unstable, "upstream_unstable") + _log_threshold(sumstats.not_reproducible, "not_reproducible") for field_name, mismatch_stats in sumstats.fields.items(): _log_threshold(mismatch_stats.total, field_name) sumstats.export_json(args.stats_filename) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/lmdb/create_test_lmdb.py b/tests/lmdb/create_test_lmdb.py index 1ce25df..03df81c 100755 --- a/tests/lmdb/create_test_lmdb.py +++ b/tests/lmdb/create_test_lmdb.py @@ -12,13 +12,13 @@ import lmdb VERSION = "2018-05-21" CREATE_ENVS = { "2018-05-21": [ - 'answers_single_server', - 'answers_multiple_servers', + "answers_single_server", + "answers_multiple_servers", ], } -BIN_INT_3000000000 = b'\x00^\xd0\xb2' +BIN_INT_3000000000 = b"\x00^\xd0\xb2" class LMDBExistsError(Exception): @@ -27,68 +27,65 @@ 
class LMDBExistsError(Exception): def open_env(version, name): path = os.path.join(version, name) - if os.path.exists(os.path.join(path, 'data.mdb')): + if os.path.exists(os.path.join(path, "data.mdb")): raise LMDBExistsError if not os.path.exists(path): os.makedirs(path) - return lmdb.Environment( - path=path, - max_dbs=5, - create=True) + return lmdb.Environment(path=path, max_dbs=5, create=True) def create_answers_single_server(): - env = open_env(VERSION, 'answers_single_server') - mdb = env.open_db(key=b'meta', create=True) + env = open_env(VERSION, "answers_single_server") + mdb = env.open_db(key=b"meta", create=True) with env.begin(mdb, write=True) as txn: - txn.put(b'version', VERSION.encode('ascii')) - txn.put(b'start_time', BIN_INT_3000000000) - txn.put(b'end_time', BIN_INT_3000000000) - txn.put(b'servers', struct.pack('<I', 1)) - txn.put(b'name0', 'kresd'.encode('ascii')) + txn.put(b"version", VERSION.encode("ascii")) + txn.put(b"start_time", BIN_INT_3000000000) + txn.put(b"end_time", BIN_INT_3000000000) + txn.put(b"servers", struct.pack("<I", 1)) + txn.put(b"name0", "kresd".encode("ascii")) - adb = env.open_db(key=b'answers', create=True) + adb = env.open_db(key=b"answers", create=True) with env.begin(adb, write=True) as txn: answer = BIN_INT_3000000000 - answer += struct.pack('<H', 1) - answer += b'a' + answer += struct.pack("<H", 1) + answer += b"a" txn.put(BIN_INT_3000000000, answer) def create_answers_multiple_servers(): - env = open_env(VERSION, 'answers_multiple_servers') - mdb = env.open_db(key=b'meta', create=True) + env = open_env(VERSION, "answers_multiple_servers") + mdb = env.open_db(key=b"meta", create=True) with env.begin(mdb, write=True) as txn: - txn.put(b'version', VERSION.encode('ascii')) - txn.put(b'servers', struct.pack('<I', 3)) - txn.put(b'name0', 'kresd'.encode('ascii')) - txn.put(b'name1', 'bind'.encode('ascii')) - txn.put(b'name2', 'unbound'.encode('ascii')) + txn.put(b"version", VERSION.encode("ascii")) + txn.put(b"servers", struct.pack("<I", 3)) + txn.put(b"name0", "kresd".encode("ascii")) + txn.put(b"name1", "bind".encode("ascii")) + txn.put(b"name2", "unbound".encode("ascii")) - adb = env.open_db(key=b'answers', create=True) + adb = env.open_db(key=b"answers", create=True) with env.begin(adb, write=True) as txn: # kresd answer = BIN_INT_3000000000 - answer += struct.pack('<H', 0) + answer += struct.pack("<H", 0) # bind answer += BIN_INT_3000000000 - answer += struct.pack('<H', 2) - answer += b'ab' + answer += struct.pack("<H", 2) + answer += b"ab" # unbound answer += BIN_INT_3000000000 - answer += struct.pack('<H', 1) - answer += b'a' + answer += struct.pack("<H", 1) + answer += b"a" txn.put(BIN_INT_3000000000, answer) def main(): for env_name in CREATE_ENVS[VERSION]: try: - globals()['create_{}'.format(env_name)]() + globals()["create_{}".format(env_name)]() except LMDBExistsError: - print('{} exists, skipping'.format(env_name)) + print("{} exists, skipping".format(env_name)) continue -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/test_cli.py b/tests/test_cli.py index 613c100..3a44186 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,18 +6,21 @@ from pytest import approx from respdiff.cli import get_stats_data -@pytest.mark.parametrize('n, total, ref_n, expected', [ - (9, None, None, (9, None, None, None)), - (9, 10, None, (9, 90, None, None)), - (11, None, 10, (11, None, 1, 10)), - (9, None, 10, (9, None, -1, -10)), - (10, None, 10, (10, None, 0, 0)), - (10, None, 0, (10, None, +10, float('inf'))), - 
(0, None, 0, (0, None, 0, float('nan'))), - (9, 10, 10, (9, 90, -1, -10)), - (9, 10, 90, (9, 90, -81, -81*100.0/90)), - (90, 100, 9, (90, 90, 81, 81*100.0/9)), -]) +@pytest.mark.parametrize( + "n, total, ref_n, expected", + [ + (9, None, None, (9, None, None, None)), + (9, 10, None, (9, 90, None, None)), + (11, None, 10, (11, None, 1, 10)), + (9, None, 10, (9, None, -1, -10)), + (10, None, 10, (10, None, 0, 0)), + (10, None, 0, (10, None, +10, float("inf"))), + (0, None, 0, (0, None, 0, float("nan"))), + (9, 10, 10, (9, 90, -1, -10)), + (9, 10, 90, (9, 90, -81, -81 * 100.0 / 90)), + (90, 100, 9, (90, 90, 81, 81 * 100.0 / 9)), + ], +) def test_get_stats_data(n, total, ref_n, expected): got_n, got_pct, got_diff, got_diff_pct = get_stats_data(n, total, ref_n) assert got_n == expected[0] diff --git a/tests/test_data.py b/tests/test_data.py index 7fd1b92..ec632fa 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -4,32 +4,43 @@ import json import pytest from respdiff.dataformat import ( - Diff, DiffReport, Disagreements, DisagreementsCounter, JSONDataObject, ReproCounter, - ReproData, Summary) + Diff, + DiffReport, + Disagreements, + DisagreementsCounter, + JSONDataObject, + ReproCounter, + ReproData, + Summary, +) from respdiff.match import DataMismatch MISMATCH_DATA = [ - ('timeout', 'answer'), - (['A'], ['A', 'CNAME']), - (['A'], ['A', 'RRSIG(A)']), + ("timeout", "answer"), + (["A"], ["A", "CNAME"]), + (["A"], ["A", "RRSIG(A)"]), ] DIFF_DATA = [ - ('timeout', MISMATCH_DATA[0]), - ('answertypes', MISMATCH_DATA[1]), - ('answerrrsigs', MISMATCH_DATA[2]), - ('answerrrsigs', MISMATCH_DATA[1]), + ("timeout", MISMATCH_DATA[0]), + ("answertypes", MISMATCH_DATA[1]), + ("answerrrsigs", MISMATCH_DATA[2]), + ("answerrrsigs", MISMATCH_DATA[1]), ] -QUERY_DIFF_DATA = list(enumerate([ - (), - (DIFF_DATA[0],), - (DIFF_DATA[0], DIFF_DATA[1]), - (DIFF_DATA[1], DIFF_DATA[0]), - (DIFF_DATA[0], DIFF_DATA[1], DIFF_DATA[2]), - (DIFF_DATA[0], DIFF_DATA[3], DIFF_DATA[1]), -])) +QUERY_DIFF_DATA = list( + enumerate( + [ + (), + (DIFF_DATA[0],), + (DIFF_DATA[0], DIFF_DATA[1]), + (DIFF_DATA[1], DIFF_DATA[0]), + (DIFF_DATA[0], DIFF_DATA[1], DIFF_DATA[2]), + (DIFF_DATA[0], DIFF_DATA[3], DIFF_DATA[1]), + ] + ) +) QUERY_DIFF_JSON = """ { @@ -143,11 +154,17 @@ def test_data_mismatch_init(): DataMismatch(1, 1) -@pytest.mark.parametrize('mismatch_data, expected_key', zip(MISMATCH_DATA, [ - ('timeout', 'answer'), - (('A',), ('A', 'CNAME')), - (('A',), ('A', 'RRSIG(A)')), -])) +@pytest.mark.parametrize( + "mismatch_data, expected_key", + zip( + MISMATCH_DATA, + [ + ("timeout", "answer"), + (("A",), ("A", "CNAME")), + (("A",), ("A", "RRSIG(A)")), + ], + ), +) def test_data_mismatch(mismatch_data, expected_key): mismatch1 = DataMismatch(*mismatch_data) mismatch2 = DataMismatch(*mismatch_data) @@ -162,8 +179,9 @@ def test_data_mismatch(mismatch_data, expected_key): assert mismatch1.key == expected_key -@pytest.mark.parametrize('mismatch1_data, mismatch2_data', - itertools.combinations(MISMATCH_DATA, 2)) +@pytest.mark.parametrize( + "mismatch1_data, mismatch2_data", itertools.combinations(MISMATCH_DATA, 2) +) def test_data_mismatch_differnet_key_hash(mismatch1_data, mismatch2_data): if mismatch1_data == mismatch2_data: return @@ -181,35 +199,38 @@ def test_json_data_object(): assert empty.save() is None # simple scalar, list or dict -- no restore/save callbacks - attrs = {'a': (None, None)} + attrs = {"a": (None, None)} basic = JSONDataObject() basic.a = 1 basic._ATTRIBUTES = attrs data = basic.save() - assert data['a'] == 1 
+ assert data["a"] == 1 restored = JSONDataObject() restored._ATTRIBUTES = attrs restored.restore(data) assert restored.a == basic.a # with save/restore callback - attrs = {'b': (lambda x: x + 1, lambda x: x - 1)} + attrs = {"b": (lambda x: x + 1, lambda x: x - 1)} complex_obj = JSONDataObject() complex_obj.b = 1 complex_obj._ATTRIBUTES = attrs data = complex_obj.save() - assert data['b'] == 0 + assert data["b"] == 0 restored = JSONDataObject() restored._ATTRIBUTES = attrs restored.restore(data) assert restored.b == complex_obj.b + + # pylint: enable=protected-access -@pytest.mark.parametrize('qid, diff_data', QUERY_DIFF_DATA) +@pytest.mark.parametrize("qid, diff_data", QUERY_DIFF_DATA) def test_diff(qid, diff_data): - mismatches = {field: DataMismatch(*mismatch_data) - for field, mismatch_data in diff_data} + mismatches = { + field: DataMismatch(*mismatch_data) for field, mismatch_data in diff_data + } diff = Diff(qid, mismatches) fields = [] @@ -224,28 +245,38 @@ def test_diff(qid, diff_data): if not field_weights: continue field_weights = list(field_weights) - assert diff.get_significant_field(field_weights) == \ - (field_weights[0], mismatches[field_weights[0]]) - field_weights.append('custom') - assert diff.get_significant_field(field_weights) == \ - (field_weights[0], mismatches[field_weights[0]]) - field_weights.insert(0, 'custom2') - assert diff.get_significant_field(field_weights) == \ - (field_weights[1], mismatches[field_weights[1]]) - assert diff.get_significant_field(['non_existent']) == (None, None) + assert diff.get_significant_field(field_weights) == ( + field_weights[0], + mismatches[field_weights[0]], + ) + field_weights.append("custom") + assert diff.get_significant_field(field_weights) == ( + field_weights[0], + mismatches[field_weights[0]], + ) + field_weights.insert(0, "custom2") + assert diff.get_significant_field(field_weights) == ( + field_weights[1], + mismatches[field_weights[1]], + ) + assert diff.get_significant_field(["non_existent"]) == (None, None) # adding or removing items isn't possible with pytest.raises(Exception): - diff['non_existent'] = None # pylint: disable=unsupported-assignment-operation + diff["non_existent"] = None # pylint: disable=unsupported-assignment-operation with pytest.raises(Exception): del diff[list(diff.keys())[0]] # pylint: disable=unsupported-delete-operation def test_diff_equality(): - mismatches_tuple = {'timeout': DataMismatch('answer', 'timeout'), - 'answertypes': DataMismatch(('A',), ('CNAME',))} - mismatches_list = {'timeout': DataMismatch('answer', 'timeout'), - 'answertypes': DataMismatch(['A'], ['CNAME'])} + mismatches_tuple = { + "timeout": DataMismatch("answer", "timeout"), + "answertypes": DataMismatch(("A",), ("CNAME",)), + } + mismatches_list = { + "timeout": DataMismatch("answer", "timeout"), + "answertypes": DataMismatch(["A"], ["CNAME"]), + } # tuple or list doesn't matter assert Diff(1, mismatches_tuple) == Diff(1, mismatches_list) @@ -254,7 +285,7 @@ def test_diff_equality(): assert Diff(1, mismatches_tuple) == Diff(2, mismatches_tuple) # different mismatches - mismatches_tuple['answerrrsigs'] = DataMismatch(('RRSIG(A)',), ('',)) + mismatches_tuple["answerrrsigs"] = DataMismatch(("RRSIG(A)",), ("",)) assert Diff(1, mismatches_tuple) != Diff(1, mismatches_list) @@ -327,8 +358,8 @@ def test_diff_report(): # report with some missing fields partial_data = report.save() - del partial_data['other_disagreements'] - del partial_data['target_disagreements'] + del partial_data["other_disagreements"] + del 
partial_data["target_disagreements"] partial_report = DiffReport(_restore_dict=partial_data) assert partial_report.other_disagreements is None assert partial_report.target_disagreements is None @@ -337,7 +368,7 @@ def test_diff_report(): def test_summary(): - field_weights = ['timeout', 'answertypes', 'aswerrrsigs'] + field_weights = ["timeout", "answertypes", "aswerrrsigs"] report = DiffReport(_restore_dict=json.loads(DIFF_REPORT_JSON)) # no reprodata -- no queries are missing @@ -390,17 +421,17 @@ def test_repro_counter(): assert rc.verified == 1 assert rc.different_failure == 1 - rc = ReproCounter(_restore_dict={'retries': 4}) + rc = ReproCounter(_restore_dict={"retries": 4}) assert rc.retries == 4 assert rc.upstream_stable == 0 assert rc.verified == 0 assert rc.different_failure == 0 data = rc.save() - assert data['retries'] == 4 - assert data['upstream_stable'] == 0 - assert data['verified'] == 0 - assert data['different_failure'] == 0 + assert data["retries"] == 4 + assert data["upstream_stable"] == 0 + assert data["verified"] == 0 + assert data["different_failure"] == 0 assert ReproCounter().save() is None diff --git a/tests/test_database.py b/tests/test_database.py index d70b4e3..4811866 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -2,11 +2,18 @@ import os import pytest from respdiff.database import ( - DNSReply, DNSRepliesFactory, LMDB, MetaDatabase, BIN_FORMAT_VERSION, qid2key) + DNSReply, + DNSRepliesFactory, + LMDB, + MetaDatabase, + BIN_FORMAT_VERSION, + qid2key, +) LMDB_DIR = os.path.join( - os.path.abspath(os.path.dirname(__file__)), 'lmdb', BIN_FORMAT_VERSION) + os.path.abspath(os.path.dirname(__file__)), "lmdb", BIN_FORMAT_VERSION +) def create_reply(wire, time): @@ -18,71 +25,83 @@ def create_reply(wire, time): def test_dns_reply_timeout(): reply = DNSReply(None) assert reply.timeout - assert reply.time == float('+inf') - - -@pytest.mark.parametrize('wire1, time1, wire2, time2, equals', [ - (None, None, None, None, True), - (None, None, None, 1, True), - (b'', None, b'', None, True), - (b'', None, b'', 1, False), - (b'a', None, b'a', None, True), - (b'a', None, b'b', None, False), - (b'a', None, b'aa', None, False), -]) + assert reply.time == float("+inf") + + +@pytest.mark.parametrize( + "wire1, time1, wire2, time2, equals", + [ + (None, None, None, None, True), + (None, None, None, 1, True), + (b"", None, b"", None, True), + (b"", None, b"", 1, False), + (b"a", None, b"a", None, True), + (b"a", None, b"b", None, False), + (b"a", None, b"aa", None, False), + ], +) def test_dns_reply_equals(wire1, time1, wire2, time2, equals): r1 = create_reply(wire1, time1) r2 = create_reply(wire2, time2) assert (r1 == r2) == equals -@pytest.mark.parametrize('time, time_int', [ - (None, 0), - (0, 0), - (1.43, 1430000), - (0.4591856, 459186), -]) +@pytest.mark.parametrize( + "time, time_int", + [ + (None, 0), + (0, 0), + (1.43, 1430000), + (0.4591856, 459186), + ], +) def test_dns_reply_time_int(time, time_int): - reply = create_reply(b'', time) + reply = create_reply(b"", time) assert reply.time_int == time_int DR_TIMEOUT = DNSReply(None) -DR_TIMEOUT_BIN = b'\xff\xff\xff\xff\x00\x00' -DR_EMPTY_0 = DNSReply(b'') -DR_EMPTY_0_BIN = b'\x00\x00\x00\x00\x00\x00' -DR_EMPTY_1 = DNSReply(b'', 1) -DR_EMPTY_1_BIN = b'\x40\x42\x0f\x00\x00\x00' -DR_A_0 = DNSReply(b'a') -DR_A_0_BIN = b'\x00\x00\x00\x00\x01\x00a' -DR_A_1 = DNSReply(b'a', 1) -DR_A_1_BIN = b'\x40\x42\x0f\x00\x01\x00a' -DR_ABCD_1 = DNSReply(b'abcd', 1) -DR_ABCD_1_BIN = b'\x40\x42\x0f\x00\x04\x00abcd' - - 
-@pytest.mark.parametrize('reply, binary', [ - (DR_TIMEOUT, DR_TIMEOUT_BIN), - (DR_EMPTY_0, DR_EMPTY_0_BIN), - (DR_EMPTY_1, DR_EMPTY_1_BIN), - (DR_A_0, DR_A_0_BIN), - (DR_A_1, DR_A_1_BIN), - (DR_ABCD_1, DR_ABCD_1_BIN), -]) +DR_TIMEOUT_BIN = b"\xff\xff\xff\xff\x00\x00" +DR_EMPTY_0 = DNSReply(b"") +DR_EMPTY_0_BIN = b"\x00\x00\x00\x00\x00\x00" +DR_EMPTY_1 = DNSReply(b"", 1) +DR_EMPTY_1_BIN = b"\x40\x42\x0f\x00\x00\x00" +DR_A_0 = DNSReply(b"a") +DR_A_0_BIN = b"\x00\x00\x00\x00\x01\x00a" +DR_A_1 = DNSReply(b"a", 1) +DR_A_1_BIN = b"\x40\x42\x0f\x00\x01\x00a" +DR_ABCD_1 = DNSReply(b"abcd", 1) +DR_ABCD_1_BIN = b"\x40\x42\x0f\x00\x04\x00abcd" + + +@pytest.mark.parametrize( + "reply, binary", + [ + (DR_TIMEOUT, DR_TIMEOUT_BIN), + (DR_EMPTY_0, DR_EMPTY_0_BIN), + (DR_EMPTY_1, DR_EMPTY_1_BIN), + (DR_A_0, DR_A_0_BIN), + (DR_A_1, DR_A_1_BIN), + (DR_ABCD_1, DR_ABCD_1_BIN), + ], +) def test_dns_reply_serialization(reply, binary): assert reply.binary == binary -@pytest.mark.parametrize('binary, reply, remaining', [ - (DR_TIMEOUT_BIN, DR_TIMEOUT, b''), - (DR_EMPTY_0_BIN, DR_EMPTY_0, b''), - (DR_EMPTY_1_BIN, DR_EMPTY_1, b''), - (DR_A_0_BIN, DR_A_0, b''), - (DR_A_1_BIN, DR_A_1, b''), - (DR_ABCD_1_BIN, DR_ABCD_1, b''), - (DR_A_1_BIN + b'a', DR_A_1, b'a'), - (DR_ABCD_1_BIN + b'bcd', DR_ABCD_1, b'bcd'), -]) +@pytest.mark.parametrize( + "binary, reply, remaining", + [ + (DR_TIMEOUT_BIN, DR_TIMEOUT, b""), + (DR_EMPTY_0_BIN, DR_EMPTY_0, b""), + (DR_EMPTY_1_BIN, DR_EMPTY_1, b""), + (DR_A_0_BIN, DR_A_0, b""), + (DR_A_1_BIN, DR_A_1, b""), + (DR_ABCD_1_BIN, DR_ABCD_1, b""), + (DR_A_1_BIN + b"a", DR_A_1, b"a"), + (DR_ABCD_1_BIN + b"bcd", DR_ABCD_1, b"bcd"), + ], +) def test_dns_reply_deserialization(binary, reply, remaining): got_reply, buff = DNSReply.from_binary(binary) assert reply == got_reply @@ -93,18 +112,18 @@ def test_dns_replies_factory(): with pytest.raises(ValueError): DNSRepliesFactory([]) - rf = DNSRepliesFactory(['a']) + rf = DNSRepliesFactory(["a"]) replies = rf.parse(DR_TIMEOUT_BIN) - assert replies['a'] == DR_TIMEOUT + assert replies["a"] == DR_TIMEOUT - rf2 = DNSRepliesFactory(['a', 'b']) + rf2 = DNSRepliesFactory(["a", "b"]) bin_data = DR_A_0_BIN + DR_ABCD_1_BIN replies = rf2.parse(bin_data) - assert replies['a'] == DR_A_0 - assert replies['b'] == DR_ABCD_1 + assert replies["a"] == DR_A_0 + assert replies["b"] == DR_ABCD_1 with pytest.raises(ValueError): - rf2.parse(DR_A_0_BIN + b'a') + rf2.parse(DR_A_0_BIN + b"a") assert rf2.serialize(replies) == bin_data @@ -114,38 +133,38 @@ TIME_3M = 3000.0 def test_lmdb_answers_single_server(): - envdir = os.path.join(LMDB_DIR, 'answers_single_server') + envdir = os.path.join(LMDB_DIR, "answers_single_server") with LMDB(envdir) as lmdb: adb = lmdb.open_db(LMDB.ANSWERS) - meta = MetaDatabase(lmdb, ['kresd']) + meta = MetaDatabase(lmdb, ["kresd"]) assert meta.read_start_time() == INT_3M assert meta.read_end_time() == INT_3M servers = meta.read_servers() assert len(servers) == 1 - assert servers[0] == 'kresd' + assert servers[0] == "kresd" with lmdb.env.begin(adb) as txn: data = txn.get(qid2key(INT_3M)) df = DNSRepliesFactory(servers) replies = df.parse(data) assert len(replies) == 1 - assert replies[servers[0]] == DNSReply(b'a', TIME_3M) + assert replies[servers[0]] == DNSReply(b"a", TIME_3M) def test_lmdb_answers_multiple_servers(): - envdir = os.path.join(LMDB_DIR, 'answers_multiple_servers') + envdir = os.path.join(LMDB_DIR, "answers_multiple_servers") with LMDB(envdir) as lmdb: adb = lmdb.open_db(LMDB.ANSWERS) - meta = MetaDatabase(lmdb, ['kresd', 'bind', 
'unbound']) + meta = MetaDatabase(lmdb, ["kresd", "bind", "unbound"]) assert meta.read_start_time() is None assert meta.read_end_time() is None servers = meta.read_servers() assert len(servers) == 3 - assert servers[0] == 'kresd' - assert servers[1] == 'bind' - assert servers[2] == 'unbound' + assert servers[0] == "kresd" + assert servers[1] == "bind" + assert servers[2] == "unbound" df = DNSRepliesFactory(servers) @@ -154,6 +173,6 @@ def test_lmdb_answers_multiple_servers(): replies = df.parse(data) assert len(replies) == 3 - assert replies[servers[0]] == DNSReply(b'', TIME_3M) - assert replies[servers[1]] == DNSReply(b'ab', TIME_3M) - assert replies[servers[2]] == DNSReply(b'a', TIME_3M) + assert replies[servers[0]] == DNSReply(b"", TIME_3M) + assert replies[servers[1]] == DNSReply(b"ab", TIME_3M) + assert replies[servers[2]] == DNSReply(b"a", TIME_3M) diff --git a/tests/test_qprep_pcap.py b/tests/test_qprep_pcap.py index 16b1bad..157d964 100644 --- a/tests/test_qprep_pcap.py +++ b/tests/test_qprep_pcap.py @@ -5,52 +5,73 @@ import pytest from qprep import wrk_process_frame, wrk_process_wire_packet -@pytest.mark.parametrize('wire', [ - b'', - b'x', - b'xx', -]) +@pytest.mark.parametrize( + "wire", + [ + b"", + b"x", + b"xx", + ], +) def test_wire_input_invalid(wire): - assert wrk_process_wire_packet(1, wire, 'invalid') == (1, wire) - assert wrk_process_wire_packet(1, wire, 'invalid') == (1, wire) + assert wrk_process_wire_packet(1, wire, "invalid") == (1, wire) + assert wrk_process_wire_packet(1, wire, "invalid") == (1, wire) -@pytest.mark.parametrize('wire_hex', [ - # www.audioweb.cz A - 'ed21010000010000000000010377777708617564696f77656202637a00000100010000291000000080000000', -]) +@pytest.mark.parametrize( + "wire_hex", + [ + # www.audioweb.cz A + "ed21010000010000000000010377777708617564696f77656202637a00000100010000291000000080000000", + ], +) def test_wire_input_valid(wire_hex): wire_in = binascii.unhexlify(wire_hex) - qid, wire_out = wrk_process_wire_packet(1, wire_in, 'qid 1') + qid, wire_out = wrk_process_wire_packet(1, wire_in, "qid 1") assert wire_in == wire_out assert qid == 1 -@pytest.mark.parametrize('wire_hex', [ - # test.dotnxdomain.net. A - ('ce970120000100000000000104746573740b646f746e78646f6d61696e036e657400000' - '10001000029100000000000000c000a00084a69fef0f174d87e'), - # 0es-u2af5c077-c56-s1492621913-i00000000.eue.dotnxdomain.net A - ('d72f01000001000000000001273065732d7532616635633037372d6335362d733134393' - '23632313931332d693030303030303030036575650b646f746e78646f6d61696e036e65' - '7400000100010000291000000080000000'), -]) +@pytest.mark.parametrize( + "wire_hex", + [ + # test.dotnxdomain.net. 
A + ( + "ce970120000100000000000104746573740b646f746e78646f6d61696e036e657400000" + "10001000029100000000000000c000a00084a69fef0f174d87e" + ), + # 0es-u2af5c077-c56-s1492621913-i00000000.eue.dotnxdomain.net A + ( + "d72f01000001000000000001273065732d7532616635633037372d6335362d733134393" + "23632313931332d693030303030303030036575650b646f746e78646f6d61696e036e65" + "7400000100010000291000000080000000" + ), + ], +) def test_pcap_input_blacklist(wire_hex): wire = binascii.unhexlify(wire_hex) - assert wrk_process_wire_packet(1, wire, 'qid 1') == (None, None) - - -@pytest.mark.parametrize('frame_hex, wire_hex', [ - # UPD nic.cz A - ('deadbeefcafecafebeefbeef08004500004bf9d000004011940d0202020201010101b533003500375520', - 'b90001200001000000000001036e696302637a0000010001000029100000000000000c000a00081491f8' - '93b0c90b2f'), - # TCP nic.cz A - ('deadbeefcafebeefbeefcafe080045000059e2f2400040066ae80202020201010101ace7003557b51707' - '47583400501800e5568c0000002f', '49e501200001000000000001036e696302637a00000100010000' - '29100000000000000c000a0008a1db546e1d6fa39f'), -]) + assert wrk_process_wire_packet(1, wire, "qid 1") == (None, None) + + +@pytest.mark.parametrize( + "frame_hex, wire_hex", + [ + # UPD nic.cz A + ( + "deadbeefcafecafebeefbeef08004500004bf9d000004011940d0202020201010101b533003500375520", + "b90001200001000000000001036e696302637a0000010001000029100000000000000c000a00081491f8" + "93b0c90b2f", + ), + # TCP nic.cz A + ( + "deadbeefcafebeefbeefcafe080045000059e2f2400040066ae80202020201010101ace7003557b51707" + "47583400501800e5568c0000002f", + "49e501200001000000000001036e696302637a00000100010000" + "29100000000000000c000a0008a1db546e1d6fa39f", + ), + ], +) def test_wrk_process_frame(frame_hex, wire_hex): data = binascii.unhexlify(frame_hex + wire_hex) wire = binascii.unhexlify(wire_hex) - assert wrk_process_frame((1, data, 'qid 1')) == (1, wire) + assert wrk_process_frame((1, data, "qid 1")) == (1, wire) diff --git a/tests/test_qprep_text.py b/tests/test_qprep_text.py index 553b145..ac61864 100644 --- a/tests/test_qprep_text.py +++ b/tests/test_qprep_text.py @@ -7,25 +7,31 @@ import pytest from qprep import wrk_process_line -@pytest.mark.parametrize('line', [ - '', - 'x'*256 + ' A', - '\123x.test. 65536', - '\321.test. 1', - 'test. A,AAAA', - 'test. A, AAAA', -]) +@pytest.mark.parametrize( + "line", + [ + "", + "x" * 256 + " A", + "\123x.test. 65536", + "\321.test. 1", + "test. A,AAAA", + "test. A, AAAA", + ], +) def test_text_input_invalid(line): assert wrk_process_line((1, line, line)) == (None, None) -@pytest.mark.parametrize('qname, qtype', [ - ('x', 'A'), - ('x', 1), - ('blabla.test.', 'TSIG'), -]) +@pytest.mark.parametrize( + "qname, qtype", + [ + ("x", "A"), + ("x", 1), + ("blabla.test.", "TSIG"), + ], +) def test_text_input_valid(qname, qtype): - line = '{} {}'.format(qname, qtype) + line = "{} {}".format(qname, qtype) if isinstance(qtype, int): rdtype = qtype @@ -39,12 +45,15 @@ def test_text_input_valid(qname, qtype): assert qid == 1 -@pytest.mark.parametrize('line', [ - 'test. ANY', - 'test. RRSIG', - 'dotnxdomain.net. 28', - 'something.dotnxdomain.net. A', - 'something.dashnxdomain.net. AAAA', -]) +@pytest.mark.parametrize( + "line", + [ + "test. ANY", + "test. RRSIG", + "dotnxdomain.net. 28", + "something.dotnxdomain.net. A", + "something.dashnxdomain.net. 
AAAA", + ], +) def test_text_input_blacklist(line): assert wrk_process_line((1, line, line)) == (None, None) diff --git a/utils/dns2txt.py b/utils/dns2txt.py index 57ad8e5..d0271db 100644 --- a/utils/dns2txt.py +++ b/utils/dns2txt.py @@ -4,6 +4,6 @@ import sys import dns.message -with open(sys.argv[1], 'rb') as f: +with open(sys.argv[1], "rb") as f: m = dns.message.from_wire(f.read()) print(str(m)) diff --git a/utils/normalize_names.py b/utils/normalize_names.py index 603a288..1d7fe0b 100644 --- a/utils/normalize_names.py +++ b/utils/normalize_names.py @@ -1,13 +1,13 @@ import string import sys -allowed_bytes = (string.ascii_letters + string.digits + '.-_').encode('ascii') +allowed_bytes = (string.ascii_letters + string.digits + ".-_").encode("ascii") trans = {} for i in range(0, 256): if i in allowed_bytes: - trans[i] = bytes(chr(i), encoding='ascii') + trans[i] = bytes(chr(i), encoding="ascii") else: - trans[i] = (r'\%03i' % i).encode('ascii') + trans[i] = (r"\%03i" % i).encode("ascii") # pprint(trans) while True: @@ -16,7 +16,7 @@ while True: break line = line[:-1] # newline - typestart = line.rfind(b' ') + 1 # rightmost space + typestart = line.rfind(b" ") + 1 # rightmost space if not typestart: continue # no RR type!? typetext = line[typestart:] @@ -24,10 +24,10 @@ while True: continue # normalize name - normalized = b'' - for nb in line[:typestart - 1]: + normalized = b"" + for nb in line[: typestart - 1]: normalized += trans[nb] sys.stdout.buffer.write(normalized) - sys.stdout.buffer.write(b' ') + sys.stdout.buffer.write(b" ") sys.stdout.buffer.write(typetext) - sys.stdout.buffer.write(b'\n') + sys.stdout.buffer.write(b"\n") -- GitLab From f2e3e389b863467c5f3714e585fceba4d54dd297 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicki=20K=C5=99=C3=AD=C5=BEek?= <nicki@isc.org> Date: Wed, 5 Feb 2025 15:44:52 +0100 Subject: [PATCH 3/5] Add .mailmap --- .mailmap | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .mailmap diff --git a/.mailmap b/.mailmap new file mode 100644 index 0000000..5e43653 --- /dev/null +++ b/.mailmap @@ -0,0 +1,2 @@ +Nicki KřÞek <nicki@isc.org> <tkrizek@isc.org> +Nicki KřÞek <nicki@isc.org> <tomas.krizek@nic.cz> -- GitLab From fb2e291f6e9db310ae7bff3d1b4d668f0aae6cad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicki=20K=C5=99=C3=AD=C5=BEek?= <nicki@isc.org> Date: Wed, 5 Feb 2025 16:04:45 +0100 Subject: [PATCH 4/5] Minor edits to make linters happy --- pylintrc | 1 - respdiff/database.py | 5 ++--- respdiff/sendrecv.py | 6 ++---- tests/test_data.py | 2 +- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/pylintrc b/pylintrc index 4c50773..e0c0b56 100644 --- a/pylintrc +++ b/pylintrc @@ -12,7 +12,6 @@ disable= line-too-long, # checked by flake8 invalid-name, broad-except, - bad-continuation, global-statement, no-else-return, duplicate-code, diff --git a/respdiff/database.py b/respdiff/database.py index bbce4e6..b78be56 100644 --- a/respdiff/database.py +++ b/respdiff/database.py @@ -199,7 +199,7 @@ class DNSReply: offset += cls.SIZEOF_INT (length,) = struct.unpack_from("<H", buff, offset) offset += cls.SIZEOF_SHORT - wire = buff[offset : (offset + length)] + wire = buff[offset:(offset + length)] offset += length if len(wire) != length: @@ -246,8 +246,7 @@ class DNSRepliesFactory: reply = replies[server] except KeyError as e: raise ValueError('Missing reply for server "{}"!'.format(server)) from e - else: - data.append(reply.binary) + data.append(reply.binary) return b"".join(data) diff --git a/respdiff/sendrecv.py b/respdiff/sendrecv.py index 
ba30f5a..9330c72 100644
--- a/respdiff/sendrecv.py
+++ b/respdiff/sendrecv.py
@@ -21,7 +21,6 @@ import threading
 from typing import (
     Any,
     Dict,
-    List,
     Mapping,
     Optional,
     Sequence,
@@ -281,9 +280,8 @@ def _recv_msg(sock: Socket, isstream: IsStreamFlag) -> WireFormat:
             blength = sock.recv(2)
         except ssl.SSLWantReadError as e:
             raise TcpDnsLengthError("failed to recv DNS packet length") from e
-        else:
-            if len(blength) != 2:  # FIN / RST
-                raise TcpDnsLengthError("failed to recv DNS packet length")
+        if len(blength) != 2:  # FIN / RST
+            raise TcpDnsLengthError("failed to recv DNS packet length")
         (length,) = struct.unpack("!H", blength)
     else:
         length = 65535  # max. UDP message size, no IPv6 jumbograms
diff --git a/tests/test_data.py b/tests/test_data.py
index ec632fa..3bd6ff5 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -151,7 +151,7 @@ REPRODATA_JSON = """
 
 def test_data_mismatch_init():
     with pytest.raises(Exception):
-        DataMismatch(1, 1)
+        DataMismatch(1, 1)  # pylint: disable=pointless-exception-statement
 
 
 @pytest.mark.parametrize(
-- 
GitLab


From 56c4e9a7a0a0681c9dc484eb6bb41d99a21db8c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicki=20K=C5=99=C3=AD=C5=BEek?= <nicki@isc.org>
Date: Wed, 5 Feb 2025 16:05:11 +0100
Subject: [PATCH 5/5] Remove mypy check

Unfortunately, the mypy check has been rotting for years now. Let's not
pretend that this check is maintained; remove the associated code.
---
 .gitlab-ci.yml | 5 -----
 ci/mypy-run.sh | 10 ----------
 2 files changed, 15 deletions(-)
 delete mode 100755 ci/mypy-run.sh

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 3726ba2..bd0a0e1 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -11,11 +11,6 @@ stages:
     - linux
     - amd64
 
-test:mypy:
-  <<: *debian
-  script:
-    - ./ci/mypy-run.sh
-
 test:flake8:
   <<: *debian
   script:
diff --git a/ci/mypy-run.sh b/ci/mypy-run.sh
deleted file mode 100755
index 5454b8a..0000000
--- a/ci/mypy-run.sh
+++ /dev/null
@@ -1,10 +0,0 @@
-#!/usr/bin/env bash
-set -e
-
-# Find Python scripts
-FILES=$(find . \
-    -path './ci' -prune -o \
-    -path './.git' -prune -o \
-    -name '*.py' -print)
-
-python3 -m mypy --install-types --non-interactive --ignore-missing-imports ${FILES}
-- 
GitLab
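
Should an ad-hoc type check still be useful after this removal, below is a
minimal sketch of an equivalent one-off invocation. The mypy flags are taken
verbatim from the deleted ci/mypy-run.sh; discovering files via git ls-files
is an assumed substitute for the original find-based discovery, not part of
the removed script.

    #!/usr/bin/env bash
    set -e
    # One-off type check over the repository's tracked Python files.
    # Flags copied from the removed ci/mypy-run.sh; file discovery via
    # "git ls-files" is an assumption standing in for the original find command.
    python3 -m mypy --install-types --non-interactive --ignore-missing-imports \
        $(git ls-files '*.py')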