diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 3726ba28389932f166335ab67158f300f80be2d0..bd0a0e1d092b8a2b657f743db30211f955eb9795 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -11,11 +11,6 @@ stages:
     - linux
     - amd64
 
-test:mypy:
-  <<: *debian
-  script:
-    - ./ci/mypy-run.sh
-
 test:flake8:
   <<: *debian
   script:
diff --git a/.mailmap b/.mailmap
new file mode 100644
index 0000000000000000000000000000000000000000..5e43653ff5e2a4005e1114d78b8f05f1a1a698f9
--- /dev/null
+++ b/.mailmap
@@ -0,0 +1,2 @@
+Nicki Křížek <nicki@isc.org> <tkrizek@isc.org>
+Nicki Křížek <nicki@isc.org> <tomas.krizek@nic.cz>
diff --git a/ci/mypy-run.sh b/ci/mypy-run.sh
deleted file mode 100755
index 5454b8a37628e74236f1c6d6d5dcb4f4a94cf3ab..0000000000000000000000000000000000000000
--- a/ci/mypy-run.sh
+++ /dev/null
@@ -1,10 +0,0 @@
-#!/usr/bin/env bash
-set -e
-
-# Find Python scripts
-FILES=$(find . \
-    -path './ci' -prune -o \
-    -path './.git' -prune -o \
-    -name '*.py' -print)
-
-python3 -m mypy --install-types --non-interactive --ignore-missing-imports ${FILES}
diff --git a/contrib/job_manager/create.py b/contrib/job_manager/create.py
index fb2970fb706d61032069d8352c408336f5ac1a1f..08d809bba05f8125171e9f50ffa0b9b45c496a0b 100755
--- a/contrib/job_manager/create.py
+++ b/contrib/job_manager/create.py
@@ -14,8 +14,8 @@ import yaml
 
 
 DIR_PATH = os.path.dirname(os.path.realpath(__file__))
-TEST_CASE_DIR = os.path.join(DIR_PATH, 'test_cases')
-FILES_DIR = os.path.join(DIR_PATH, 'files')
+TEST_CASE_DIR = os.path.join(DIR_PATH, "test_cases")
+FILES_DIR = os.path.join(DIR_PATH, "files")
 
 
 def prepare_dir(directory: str, clean: bool = False) -> None:
@@ -29,33 +29,32 @@ def prepare_dir(directory: str, clean: bool = False) -> None:
     except FileExistsError as e:
         raise RuntimeError(
             'Directory "{}" already exists! 
Use -l label / --clean or (re)move the ' - 'directory manually to resolve the issue.'.format(directory)) from e + "directory manually to resolve the issue.".format(directory) + ) from e -def copy_file(name: str, destdir: str, destname: str = ''): +def copy_file(name: str, destdir: str, destname: str = ""): if not destname: destname = name - shutil.copy( - os.path.join(FILES_DIR, name), - os.path.join(destdir, destname)) + shutil.copy(os.path.join(FILES_DIR, name), os.path.join(destdir, destname)) def create_file_from_template( - name: str, - data: Mapping[str, Any], - destdir: str, - destname: str = '', - executable=False - ) -> None: + name: str, + data: Mapping[str, Any], + destdir: str, + destname: str = "", + executable=False, +) -> None: env = jinja2.Environment(loader=jinja2.FileSystemLoader(FILES_DIR)) template = env.get_template(name) rendered = template.render(**data) if not destname: - assert name[-3:] == '.j2' + assert name[-3:] == ".j2" destname = os.path.basename(name)[:-3] dest = os.path.join(destdir, destname) - with open(dest, 'w', encoding='UTF-8') as fh: + with open(dest, "w", encoding="UTF-8") as fh: fh.write(rendered) if executable: @@ -64,86 +63,99 @@ def create_file_from_template( def load_test_case_config(test_case: str) -> Dict[str, Any]: - path = os.path.join(TEST_CASE_DIR, test_case + '.yaml') - with open(path, 'r', encoding='UTF-8') as f: + path = os.path.join(TEST_CASE_DIR, test_case + ".yaml") + with open(path, "r", encoding="UTF-8") as f: return yaml.safe_load(f) def create_resolver_configs(directory: str, config: Dict[str, Any]): - for name, resolver in config['resolvers'].items(): - resolver['name'] = name - resolver['verbose'] = config['verbose'] - if resolver['type'] == 'knot-resolver': - dockerfile_dir = os.path.join(directory, 'docker-knot-resolver') + for name, resolver in config["resolvers"].items(): + resolver["name"] = name + resolver["verbose"] = config["verbose"] + if resolver["type"] == "knot-resolver": + dockerfile_dir = os.path.join(directory, "docker-knot-resolver") if not os.path.exists(dockerfile_dir): os.makedirs(dockerfile_dir) - copy_file('Dockerfile.knot-resolver', dockerfile_dir, 'Dockerfile') - copy_file('kresd.entrypoint.sh', dockerfile_dir) + copy_file("Dockerfile.knot-resolver", dockerfile_dir, "Dockerfile") + copy_file("kresd.entrypoint.sh", dockerfile_dir) create_file_from_template( - 'kresd.conf.j2', resolver, directory, name + '.conf') - elif resolver['type'] == 'unbound': + "kresd.conf.j2", resolver, directory, name + ".conf" + ) + elif resolver["type"] == "unbound": create_file_from_template( - 'unbound.conf.j2', resolver, directory, name + '.conf') - copy_file('cert.pem', directory) - copy_file('key.pem', directory) - copy_file('root.keys', directory) - elif resolver['type'] == 'bind': + "unbound.conf.j2", resolver, directory, name + ".conf" + ) + copy_file("cert.pem", directory) + copy_file("key.pem", directory) + copy_file("root.keys", directory) + elif resolver["type"] == "bind": create_file_from_template( - 'named.conf.j2', resolver, directory, name + '.conf') - copy_file('rfc1912.zones', directory) - copy_file('bind.keys', directory) + "named.conf.j2", resolver, directory, name + ".conf" + ) + copy_file("rfc1912.zones", directory) + copy_file("bind.keys", directory) else: raise NotImplementedError( - "unknown resolver type: '{}'".format(resolver['type'])) + "unknown resolver type: '{}'".format(resolver["type"]) + ) def create_resperf_files(directory: str, config: Dict[str, Any]): - 
create_file_from_template('run_resperf.sh.j2', config, directory, executable=True) - create_file_from_template('docker-compose.yaml.j2', config, directory) + create_file_from_template("run_resperf.sh.j2", config, directory, executable=True) + create_file_from_template("docker-compose.yaml.j2", config, directory) create_resolver_configs(directory, config) def create_distrotest_files(directory: str, config: Dict[str, Any]): - create_file_from_template('run_distrotest.sh.j2', config, directory, executable=True) + create_file_from_template( + "run_distrotest.sh.j2", config, directory, executable=True + ) def create_respdiff_files(directory: str, config: Dict[str, Any]): - create_file_from_template('run_respdiff.sh.j2', config, directory, executable=True) - create_file_from_template('restart-all.sh.j2', config, directory, executable=True) - create_file_from_template('docker-compose.yaml.j2', config, directory) + create_file_from_template("run_respdiff.sh.j2", config, directory, executable=True) + create_file_from_template("restart-all.sh.j2", config, directory, executable=True) + create_file_from_template("docker-compose.yaml.j2", config, directory) create_resolver_configs(directory, config) # omit resolvers without respdiff section from respdiff.cfg - config['resolvers'] = { - name: res for name, res - in config['resolvers'].items() - if 'respdiff' in res} - create_file_from_template('respdiff.cfg.j2', config, directory) + config["resolvers"] = { + name: res for name, res in config["resolvers"].items() if "respdiff" in res + } + create_file_from_template("respdiff.cfg.j2", config, directory) - if config['respdiff_stats']: # copy optional stats file + if config["respdiff_stats"]: # copy optional stats file try: - shutil.copyfile(config['respdiff_stats'], os.path.join(directory, 'stats.json')) + shutil.copyfile( + config["respdiff_stats"], os.path.join(directory, "stats.json") + ) except FileNotFoundError as e: raise RuntimeError( - "Statistics file not found: {}".format(config['respdiff_stats'])) from e + "Statistics file not found: {}".format(config["respdiff_stats"]) + ) from e def create_template_files(directory: str, config: Dict[str, Any]): - if 'respdiff' in config: + if "respdiff" in config: create_respdiff_files(directory, config) - elif 'resperf' in config: + elif "resperf" in config: create_resperf_files(directory, config) - elif 'distrotest' in config: + elif "distrotest" in config: create_distrotest_files(directory, config) -def get_test_case_list(nameglob: str = '') -> List[str]: +def get_test_case_list(nameglob: str = "") -> List[str]: # test cases end with '*.jXXX' implying a number of jobs (less -> longer runtime) # return them in ascending order, so more time consuming test cases run first - return sorted([ - os.path.splitext(os.path.basename(fname))[0] - for fname in glob.glob(os.path.join(TEST_CASE_DIR, '{}*.yaml'.format(nameglob)))], - key=lambda x: x.split('.')[-1]) # sort by job count + return sorted( + [ + os.path.splitext(os.path.basename(fname))[0] + for fname in glob.glob( + os.path.join(TEST_CASE_DIR, "{}*.yaml".format(nameglob)) + ) + ], + key=lambda x: x.split(".")[-1], + ) # sort by job count def create_jobs(args: argparse.Namespace) -> None: @@ -158,20 +170,20 @@ def create_jobs(args: argparse.Namespace) -> None: git_sha = args.sha_or_tag commit_dir = git_sha if args.label is not None: - if ' ' in args.label: - raise RuntimeError('Label may not contain spaces.') - commit_dir += '-' + args.label + if " " in args.label: + raise RuntimeError("Label may not contain 
spaces.") + commit_dir += "-" + args.label for test_case in test_cases: config = load_test_case_config(test_case) - config['git_sha'] = git_sha - config['knot_branch'] = args.knot_branch - config['verbose'] = args.verbose - config['asan'] = args.asan - config['log_keys'] = args.log_keys - config['respdiff_stats'] = args.respdiff_stats - config['obs_repo'] = args.obs_repo - config['package'] = args.package + config["git_sha"] = git_sha + config["knot_branch"] = args.knot_branch + config["verbose"] = args.verbose + config["asan"] = args.asan + config["log_keys"] = args.log_keys + config["respdiff_stats"] = args.respdiff_stats + config["obs_repo"] = args.obs_repo + config["package"] = args.package directory = os.path.join(args.jobs_dir, commit_dir, test_case) prepare_dir(directory, clean=args.clean) @@ -184,55 +196,83 @@ def create_jobs(args: argparse.Namespace) -> None: def main() -> None: parser = argparse.ArgumentParser( - description="Prepare files for docker respdiff job") + description="Prepare files for docker respdiff job" + ) parser.add_argument( - 'sha_or_tag', type=str, - help="Knot Resolver git commit or tag to use (don't use branch!)") + "sha_or_tag", + type=str, + help="Knot Resolver git commit or tag to use (don't use branch!)", + ) parser.add_argument( - '-a', '--all', default='shortlist', - help="Create all test cases which start with expression (default: shortlist)") + "-a", + "--all", + default="shortlist", + help="Create all test cases which start with expression (default: shortlist)", + ) parser.add_argument( - '-t', choices=get_test_case_list(), - help="Create only the specified test case") + "-t", choices=get_test_case_list(), help="Create only the specified test case" + ) parser.add_argument( - '-l', '--label', - help="Assign label for easier job identification and isolation") + "-l", "--label", help="Assign label for easier job identification and isolation" + ) parser.add_argument( - '--clean', action='store_true', - help="Remove target directory if it already exists (use with caution!)") + "--clean", + action="store_true", + help="Remove target directory if it already exists (use with caution!)", + ) parser.add_argument( - '--jobs-dir', default='/var/tmp/respdiff-jobs', - help="Directory with job collections (default: /var/tmp/respdiff-jobs)") + "--jobs-dir", + default="/var/tmp/respdiff-jobs", + help="Directory with job collections (default: /var/tmp/respdiff-jobs)", + ) parser.add_argument( - '--knot-branch', type=str, default='3.1', - help="Build knot-resolver against selected Knot DNS branch") + "--knot-branch", + type=str, + default="3.1", + help="Build knot-resolver against selected Knot DNS branch", + ) parser.add_argument( - '-v', '--verbose', action='store_true', - help="Capture verbose logs (kresd, unbound)") + "-v", + "--verbose", + action="store_true", + help="Capture verbose logs (kresd, unbound)", + ) parser.add_argument( - '--asan', action='store_true', - help="Build with Address Sanitizer") + "--asan", action="store_true", help="Build with Address Sanitizer" + ) parser.add_argument( - '--log-keys', action='store_true', - help="Log TLS session keys (kresd)") + "--log-keys", action="store_true", help="Log TLS session keys (kresd)" + ) parser.add_argument( - '--respdiff-stats', type=str, default='', - help=("Statistics file to generate extra respdiff report(s) with omitted " - "unstable/failing queries")) + "--respdiff-stats", + type=str, + default="", + help=( + "Statistics file to generate extra respdiff report(s) with omitted " + "unstable/failing 
queries" + ), + ) parser.add_argument( - '--obs-repo', type=str, default='knot-resolver-devel', - help=("OBS repository for distrotests (default: knot-resolver-devel)")) + "--obs-repo", + type=str, + default="knot-resolver-devel", + help=("OBS repository for distrotests (default: knot-resolver-devel)"), + ) parser.add_argument( - '--package', type=str, default='knot-resolver', - help=("package for distrotests (default: knot-resolver)")) + "--package", + type=str, + default="knot-resolver", + help=("package for distrotests (default: knot-resolver)"), + ) args = parser.parse_args() create_jobs(args) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig( - format='%(asctime)s %(levelname)8s %(message)s', level=logging.DEBUG) + format="%(asctime)s %(levelname)8s %(message)s", level=logging.DEBUG + ) try: main() except RuntimeError as exc: @@ -241,5 +281,5 @@ if __name__ == '__main__': sys.exit(1) except Exception as exc: logging.debug(traceback.format_exc()) - logging.critical('Unhandled code exception: %s', str(exc)) + logging.critical("Unhandled code exception: %s", str(exc)) sys.exit(2) diff --git a/contrib/job_manager/submit.py b/contrib/job_manager/submit.py index c9a75f77dc9fdd8a77e689c5b36ba5204e8d549d..44db1f5d4b8cec481e813fc186044bf320152488 100755 --- a/contrib/job_manager/submit.py +++ b/contrib/job_manager/submit.py @@ -19,9 +19,9 @@ JOB_STATUS_RUNNING = 2 def get_all_files(directory: str) -> List[str]: files = [] - for filename in glob.iglob('{}/**'.format(directory), recursive=True): + for filename in glob.iglob("{}/**".format(directory), recursive=True): # omit job artifacts (begins with j*) - if os.path.isfile(filename) and not os.path.basename(filename).startswith('j'): + if os.path.isfile(filename) and not os.path.basename(filename).startswith("j"): files.append(os.path.relpath(filename, directory)) return files @@ -30,61 +30,69 @@ def condor_submit(txn, priority: int) -> int: directory = os.getcwd() input_files = get_all_files(directory) - if 'run_respdiff.sh' in input_files: - executable = 'run_respdiff.sh' + if "run_respdiff.sh" in input_files: + executable = "run_respdiff.sh" output_files = [ - 'j$(Cluster).$(Process)_docker.txt', - 'j$(Cluster).$(Process)_report.json', - 'j$(Cluster).$(Process)_report.diffrepro.json', - 'j$(Cluster).$(Process)_report.txt', - 'j$(Cluster).$(Process)_report.diffrepro.txt', - 'j$(Cluster).$(Process)_histogram.tar.gz', - 'j$(Cluster).$(Process)_logs.tar.gz'] - if 'stats.json' in input_files: - output_files.extend([ - 'j$(Cluster).$(Process)_report.noref.json', - 'j$(Cluster).$(Process)_report.noref.txt', - 'j$(Cluster).$(Process)_report.diffrepro.noref.json', - 'j$(Cluster).$(Process)_report.diffrepro.noref.txt', - # 'j$(Cluster).$(Process)_dnsviz.json.gz', - # 'j$(Cluster).$(Process)_report.noref.dnsviz.json', - # 'j$(Cluster).$(Process)_report.noref.dnsviz.txt', - ]) - elif 'run_resperf.sh' in input_files: - executable = 'run_resperf.sh' + "j$(Cluster).$(Process)_docker.txt", + "j$(Cluster).$(Process)_report.json", + "j$(Cluster).$(Process)_report.diffrepro.json", + "j$(Cluster).$(Process)_report.txt", + "j$(Cluster).$(Process)_report.diffrepro.txt", + "j$(Cluster).$(Process)_histogram.tar.gz", + "j$(Cluster).$(Process)_logs.tar.gz", + ] + if "stats.json" in input_files: + output_files.extend( + [ + "j$(Cluster).$(Process)_report.noref.json", + "j$(Cluster).$(Process)_report.noref.txt", + "j$(Cluster).$(Process)_report.diffrepro.noref.json", + "j$(Cluster).$(Process)_report.diffrepro.noref.txt", + # 
'j$(Cluster).$(Process)_dnsviz.json.gz', + # 'j$(Cluster).$(Process)_report.noref.dnsviz.json', + # 'j$(Cluster).$(Process)_report.noref.dnsviz.txt', + ] + ) + elif "run_resperf.sh" in input_files: + executable = "run_resperf.sh" output_files = [ - 'j$(Cluster).$(Process)_exitcode', - 'j$(Cluster).$(Process)_docker.txt', - 'j$(Cluster).$(Process)_resperf.txt', - 'j$(Cluster).$(Process)_logs.tar.gz'] - elif 'run_distrotest.sh' in input_files: - executable = 'run_distrotest.sh' + "j$(Cluster).$(Process)_exitcode", + "j$(Cluster).$(Process)_docker.txt", + "j$(Cluster).$(Process)_resperf.txt", + "j$(Cluster).$(Process)_logs.tar.gz", + ] + elif "run_distrotest.sh" in input_files: + executable = "run_distrotest.sh" output_files = [ - 'j$(Cluster).$(Process)_exitcode', - 'j$(Cluster).$(Process)_vagrant.log.txt'] + "j$(Cluster).$(Process)_exitcode", + "j$(Cluster).$(Process)_vagrant.log.txt", + ] else: raise RuntimeError( "The provided directory doesn't look like a respdiff/resperf job. " - "{}/run_*.sh is missing!".format(directory)) + "{}/run_*.sh is missing!".format(directory) + ) # create batch name from dir structure commit_dir_path, test_case = os.path.split(directory) _, commit_dir = os.path.split(commit_dir_path) - batch_name = commit_dir + '_' + test_case - - submit = Submit({ - 'priority': str(priority), - 'executable': executable, - 'arguments': '$(Cluster) $(Process)', - 'error': 'j$(Cluster).$(Process)_stderr.txt', - 'output': 'j$(Cluster).$(Process)_stdout.txt', - 'log': 'j$(Cluster).$(Process)_log.txt', - 'jobbatchname': batch_name, - 'should_transfer_files': 'YES', - 'when_to_transfer_output': 'ON_EXIT', - 'transfer_input_files': ', '.join(input_files), - 'transfer_output_files': ', '.join(output_files), - }) + batch_name = commit_dir + "_" + test_case + + submit = Submit( + { + "priority": str(priority), + "executable": executable, + "arguments": "$(Cluster) $(Process)", + "error": "j$(Cluster).$(Process)_stderr.txt", + "output": "j$(Cluster).$(Process)_stdout.txt", + "log": "j$(Cluster).$(Process)_log.txt", + "jobbatchname": batch_name, + "should_transfer_files": "YES", + "when_to_transfer_output": "ON_EXIT", + "transfer_input_files": ", ".join(input_files), + "transfer_output_files": ", ".join(output_files), + } + ) return submit.queue(txn) @@ -107,12 +115,17 @@ def condor_wait_for(schedd, job_ids: Sequence[int]) -> None: break # log only status changes - if (remaining != prev_remaining or - running != prev_running or - worst_pos != prev_worst_pos): + if ( + remaining != prev_remaining + or running != prev_running + or worst_pos != prev_worst_pos + ): logging.info( " remaning: %2d (running: %2d) worst queue position: %2d", - remaining, running, worst_pos + 1) + remaining, + running, + worst_pos + 1, + ) prev_remaining = remaining prev_running = running @@ -122,19 +135,19 @@ def condor_wait_for(schedd, job_ids: Sequence[int]) -> None: def condor_check_status(schedd, job_ids: Sequence[int]) -> Tuple[int, int, int]: - all_jobs = schedd.query(True, ['JobPrio', 'ClusterId', 'ProcId', 'JobStatus']) + all_jobs = schedd.query(True, ["JobPrio", "ClusterId", "ProcId", "JobStatus"]) all_jobs = sorted( - all_jobs, - key=lambda x: (-x['JobPrio'], x['ClusterId'], x['ProcId'])) + all_jobs, key=lambda x: (-x["JobPrio"], x["ClusterId"], x["ProcId"]) + ) worst_pos = 0 running = 0 remaining = 0 for i, job in enumerate(all_jobs): - if int(job['ClusterId']) in job_ids: + if int(job["ClusterId"]) in job_ids: remaining += 1 - if int(job['JobStatus']) == JOB_STATUS_RUNNING: + if int(job["JobStatus"]) 
== JOB_STATUS_RUNNING: running += 1 worst_pos = i @@ -143,19 +156,31 @@ def condor_check_status(schedd, job_ids: Sequence[int]) -> Tuple[int, int, int]: def main() -> None: parser = argparse.ArgumentParser( - description="Submit prepared jobs to HTCondor cluster") + description="Submit prepared jobs to HTCondor cluster" + ) parser.add_argument( - '-c', '--count', type=int, default=1, - help="How many times to submit job (default: 1)") + "-c", + "--count", + type=int, + default=1, + help="How many times to submit job (default: 1)", + ) parser.add_argument( - '-p', '--priority', type=int, default=5, - help="Set condor job priority, higher means sooner execution (default: 5)") + "-p", + "--priority", + type=int, + default=5, + help="Set condor job priority, higher means sooner execution (default: 5)", + ) parser.add_argument( - '-w', '--wait', action='store_true', - help="Wait until all submitted jobs are finished") + "-w", + "--wait", + action="store_true", + help="Wait until all submitted jobs are finished", + ) parser.add_argument( - 'job_dir', nargs='+', - help="Path to the job directory to be submitted") + "job_dir", nargs="+", help="Path to the job directory to be submitted" + ) args = parser.parse_args() @@ -170,14 +195,15 @@ def main() -> None: job_ids[directory].append(condor_submit(txn, args.priority)) for directory, jobs in job_ids.items(): - logging.info("%s JOB ID(s): %s", directory, ', '.join(str(j) for j in jobs)) + logging.info("%s JOB ID(s): %s", directory, ", ".join(str(j) for j in jobs)) job_count = sum(len(jobs) for jobs in job_ids.values()) logging.info("%d job(s) successfully submitted!", job_count) if args.wait: logging.info( - 'WAITING for jobs to complete. This can be safely interrupted with Ctl+C...') + "WAITING for jobs to complete. This can be safely interrupted with Ctl+C..." + ) try: condor_wait_for(schedd, list(itertools.chain(*job_ids.values()))) except KeyboardInterrupt: @@ -186,16 +212,17 @@ def main() -> None: logging.info("All jobs done!") -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig( - format='%(asctime)s %(levelname)8s %(message)s', level=logging.DEBUG) + format="%(asctime)s %(levelname)8s %(message)s", level=logging.DEBUG + ) with warnings.catch_warnings(): warnings.simplefilter("error") # trigger UserWarning which causes ImportError try: from htcondor import Submit, Schedd except (ImportError, UserWarning): - logging.error('HTCondor not detected. Use this script on a submit machine.') + logging.error("HTCondor not detected. 
Use this script on a submit machine.") sys.exit(1) try: @@ -206,5 +233,5 @@ if __name__ == '__main__': sys.exit(1) except Exception as exc: logging.debug(traceback.format_exc()) - logging.critical('Unhandled code exception: %s', str(exc)) + logging.critical("Unhandled code exception: %s", str(exc)) sys.exit(2) diff --git a/diffrepro.py b/diffrepro.py index 92ef28a80af548a9222c63c37d0c79f3b5b7ae30..127616e103643ebf1022b887d4d9b322e8f555a2 100755 --- a/diffrepro.py +++ b/diffrepro.py @@ -10,25 +10,30 @@ from respdiff.dataformat import DiffReport, ReproData def main(): cli.setup_logging() parser = argparse.ArgumentParser( - description='attempt to reproduce original diffs from JSON report') + description="attempt to reproduce original diffs from JSON report" + ) cli.add_arg_envdir(parser) cli.add_arg_config(parser) cli.add_arg_datafile(parser) - parser.add_argument('--sequential', action='store_true', default=False, - help='send one query at a time (slower, but more reliable)') + parser.add_argument( + "--sequential", + action="store_true", + default=False, + help="send one query at a time (slower, but more reliable)", + ) args = parser.parse_args() sendrecv.module_init(args) datafile = cli.get_datafile(args) report = DiffReport.from_json(datafile) restart_scripts = repro.get_restart_scripts(args.cfg) - servers = args.cfg['servers']['names'] + servers = args.cfg["servers"]["names"] dnsreplies_factory = DNSRepliesFactory(servers) if args.sequential: nproc = 1 else: - nproc = args.cfg['sendrecv']['jobs'] + nproc = args.cfg["sendrecv"]["jobs"] if report.reprodata is None: report.reprodata = ReproData() @@ -40,12 +45,18 @@ def main(): dstream = repro.query_stream_from_disagreements(lmdb, report) try: repro.reproduce_queries( - dstream, report, dnsreplies_factory, args.cfg['diff']['criteria'], - args.cfg['diff']['target'], restart_scripts, nproc) + dstream, + report, + dnsreplies_factory, + args.cfg["diff"]["criteria"], + args.cfg["diff"]["target"], + restart_scripts, + nproc, + ) finally: # make sure data is saved in case of interrupt report.export_json(datafile) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/diffsum.py b/diffsum.py index 6d375aa77e7cdbee03ac5796e1a1654957b33145..4f53c1fcba1f0e4df1c5d3b1a7cecbab3af0dccb 100755 --- a/diffsum.py +++ b/diffsum.py @@ -4,8 +4,17 @@ import argparse import logging import sys from typing import ( # noqa - Any, Callable, Iterable, Iterator, ItemsView, List, Set, Sequence, Tuple, - Union) + Any, + Callable, + Iterable, + Iterator, + ItemsView, + List, + Set, + Sequence, + Tuple, + Union, +) import dns.message from respdiff import cli @@ -13,41 +22,61 @@ from respdiff.database import LMDB from respdiff.dataformat import DiffReport, Summary from respdiff.dnsviz import DnsvizGrok from respdiff.query import ( - convert_queries, get_printable_queries_format, get_query_iterator, qwire_to_msgid_qname_qtype) + convert_queries, + get_printable_queries_format, + get_query_iterator, + qwire_to_msgid_qname_qtype, +) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description='create a summary report from gathered data stored in LMDB ' - 'and JSON datafile') + description="create a summary report from gathered data stored in LMDB " + "and JSON datafile" + ) cli.add_arg_envdir(parser) cli.add_arg_config(parser) cli.add_arg_datafile(parser) cli.add_arg_limit(parser) - cli.add_arg_stats_filename(parser, default='') - cli.add_arg_dnsviz(parser, default='') - parser.add_argument('--without-dnsviz-errors', 
action='store_true', - help='omit domains that have any errors in DNSViz results') - parser.add_argument('--without-diffrepro', action='store_true', - help='omit reproducibility data from summary') - parser.add_argument('--without-ref-unstable', action='store_true', - help='omit unstable reference queries from summary') - parser.add_argument('--without-ref-failing', action='store_true', - help='omit failing reference queries from summary') + cli.add_arg_stats_filename(parser, default="") + cli.add_arg_dnsviz(parser, default="") + parser.add_argument( + "--without-dnsviz-errors", + action="store_true", + help="omit domains that have any errors in DNSViz results", + ) + parser.add_argument( + "--without-diffrepro", + action="store_true", + help="omit reproducibility data from summary", + ) + parser.add_argument( + "--without-ref-unstable", + action="store_true", + help="omit unstable reference queries from summary", + ) + parser.add_argument( + "--without-ref-failing", + action="store_true", + help="omit failing reference queries from summary", + ) return parser.parse_args() def check_args(args: argparse.Namespace, report: DiffReport): - if (args.without_ref_unstable or args.without_ref_failing) \ - and not args.stats_filename: + if ( + args.without_ref_unstable or args.without_ref_failing + ) and not args.stats_filename: logging.critical("Statistics file must be provided as a reference.") sys.exit(1) if not report.total_answers: - logging.error('No answers in DB!') + logging.error("No answers in DB!") sys.exit(1) if report.target_disagreements is None: - logging.error('JSON report is missing diff data! Did you forget to run msgdiff?') + logging.error( + "JSON report is missing diff data! Did you forget to run msgdiff?" + ) sys.exit(1) @@ -56,7 +85,7 @@ def main(): args = parse_args() datafile = cli.get_datafile(args) report = DiffReport.from_json(datafile) - field_weights = args.cfg['report']['field_weights'] + field_weights = args.cfg["report"]["field_weights"] check_args(args, report) @@ -74,16 +103,18 @@ def main(): report = DiffReport.from_json(datafile) report.summary = Summary.from_report( - report, field_weights, + report, + field_weights, without_diffrepro=args.without_diffrepro, - ignore_qids=ignore_qids) + ignore_qids=ignore_qids, + ) # dnsviz filter: by domain -> need to iterate over disagreements to get QIDs if args.without_dnsviz_errors: try: dnsviz_grok = DnsvizGrok.from_json(args.dnsviz) except (FileNotFoundError, RuntimeError) as exc: - logging.critical('Failed to load dnsviz data: %s', exc) + logging.critical("Failed to load dnsviz data: %s", exc) sys.exit(1) error_domains = dnsviz_grok.error_domains() @@ -93,14 +124,18 @@ def main(): for qid, wire in get_query_iterator(lmdb, report.summary.keys()): msg = dns.message.from_wire(wire) if msg.question: - if any(msg.question[0].name.is_subdomain(name) - for name in error_domains): + if any( + msg.question[0].name.is_subdomain(name) + for name in error_domains + ): ignore_qids.add(qid) report.summary = Summary.from_report( - report, field_weights, + report, + field_weights, without_diffrepro=args.without_diffrepro, - ignore_qids=ignore_qids) + ignore_qids=ignore_qids, + ) cli.print_global_stats(report) cli.print_differences_stats(report) @@ -111,7 +146,8 @@ def main(): for field in field_weights: if field in report.summary.field_labels: cli.print_field_mismatch_stats( - field, field_counters[field], len(report.summary)) + field, field_counters[field], len(report.summary) + ) # query details with LMDB(args.envdir, readonly=True) 
as lmdb: @@ -120,16 +156,18 @@ def main(): for field in field_weights: if field in report.summary.field_labels: for mismatch, qids in report.summary.get_field_mismatches(field): - queries = convert_queries(get_query_iterator(lmdb, qids), - qwire_to_msgid_qname_qtype) + queries = convert_queries( + get_query_iterator(lmdb, qids), qwire_to_msgid_qname_qtype + ) cli.print_mismatch_queries( field, mismatch, get_printable_queries_format(queries), - args.limit) + args.limit, + ) report.export_json(datafile) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/distrcmp.py b/distrcmp.py index 012360d182d186e88e3ee1807ebb0657236ddc2b..21e0dffd403b6dd7e141cdfae1e396054e754670 100755 --- a/distrcmp.py +++ b/distrcmp.py @@ -12,7 +12,9 @@ from respdiff.stats import Stats, SummaryStatistics DEFAULT_COEF = 2 -def belongs_to_distribution(ref: Optional[Stats], new: Optional[Stats], coef: float) -> bool: +def belongs_to_distribution( + ref: Optional[Stats], new: Optional[Stats], coef: float +) -> bool: if ref is None or new is None: return False median = statistics.median(new.samples) @@ -24,7 +26,9 @@ def belongs_to_distribution(ref: Optional[Stats], new: Optional[Stats], coef: fl def belongs_to_all(ref: SummaryStatistics, new: SummaryStatistics, coef: float) -> bool: - if not belongs_to_distribution(ref.target_disagreements, new.target_disagreements, coef): + if not belongs_to_distribution( + ref.target_disagreements, new.target_disagreements, coef + ): return False if not belongs_to_distribution(ref.upstream_unstable, new.upstream_unstable, coef): return False @@ -32,24 +36,38 @@ def belongs_to_all(ref: SummaryStatistics, new: SummaryStatistics, coef: float) return False if ref.fields is not None and new.fields is not None: for field in ref.fields: - if not belongs_to_distribution(ref.fields[field].total, new.fields[field].total, coef): + if not belongs_to_distribution( + ref.fields[field].total, new.fields[field].total, coef + ): return False return True def main(): cli.setup_logging() - parser = argparse.ArgumentParser(description='Check if new samples belong to reference ' - 'distribution. Ends with exitcode 0 if belong, ' - '1 if not,') - parser.add_argument('-r', '--reference', type=cli.read_stats, - help='json statistics file with reference data') + parser = argparse.ArgumentParser( + description="Check if new samples belong to reference " + "distribution. 
Ends with exitcode 0 if belong, " + "1 if not," + ) + parser.add_argument( + "-r", + "--reference", + type=cli.read_stats, + help="json statistics file with reference data", + ) cli.add_arg_report_filename(parser) - parser.add_argument('-c', '--coef', type=float, - default=DEFAULT_COEF, - help=('coeficient for comparation (new belongs to refference if ' - 'its median is closer than COEF * standart deviation of reference ' - 'from reference mean) (default: {})'.format(DEFAULT_COEF))) + parser.add_argument( + "-c", + "--coef", + type=float, + default=DEFAULT_COEF, + help=( + "coeficient for comparation (new belongs to refference if " + "its median is closer than COEF * standart deviation of reference " + "from reference mean) (default: {})".format(DEFAULT_COEF) + ), + ) args = parser.parse_args() reports = cli.get_reports_from_filenames(args) @@ -64,5 +82,5 @@ def main(): sys.exit(0) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/dnsviz.py b/dnsviz.py index a5658cda2eb3737571aeb43bfb2de85601715691..337496e251411916741574439723b6e5ada084a9 100755 --- a/dnsviz.py +++ b/dnsviz.py @@ -11,17 +11,32 @@ import respdiff.dnsviz def main(): cli.setup_logging() parser = argparse.ArgumentParser( - description="use dnsviz to categorize domains (perfect, warnings, errors)") + description="use dnsviz to categorize domains (perfect, warnings, errors)" + ) cli.add_arg_config(parser) cli.add_arg_dnsviz(parser) - parser.add_argument('input', type=str, help='input file with domains (one qname per line)') + parser.add_argument( + "input", type=str, help="input file with domains (one qname per line)" + ) args = parser.parse_args() - njobs = args.cfg['sendrecv']['jobs'] + njobs = args.cfg["sendrecv"]["jobs"] try: - probe = subprocess.run([ - 'dnsviz', 'probe', '-A', '-R', respdiff.dnsviz.TYPES, '-f', - args.input, '-t', str(njobs)], check=True, stdout=subprocess.PIPE) + probe = subprocess.run( + [ + "dnsviz", + "probe", + "-A", + "-R", + respdiff.dnsviz.TYPES, + "-f", + args.input, + "-t", + str(njobs), + ], + check=True, + stdout=subprocess.PIPE, + ) except subprocess.CalledProcessError as exc: logging.critical("dnsviz probe failed: %s", exc) sys.exit(1) @@ -30,12 +45,16 @@ def main(): sys.exit(1) try: - subprocess.run(['dnsviz', 'grok', '-o', args.dnsviz], input=probe.stdout, - check=True, stdout=subprocess.PIPE) + subprocess.run( + ["dnsviz", "grok", "-o", args.dnsviz], + input=probe.stdout, + check=True, + stdout=subprocess.PIPE, + ) except subprocess.CalledProcessError as exc: logging.critical("dnsviz grok failed: %s", exc) sys.exit(1) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/histogram.py b/histogram.py index 5ec82c56175dddbc0a3b04bcc5a289d5d2a7fcb0..0bd1e0d8229c9676a8c013a25011c711dfb5c76f 100755 --- a/histogram.py +++ b/histogram.py @@ -25,7 +25,8 @@ from respdiff.typing import ResolverID # Force matplotlib to use a different backend to handle machines without a display import matplotlib import matplotlib.ticker as mtick -matplotlib.use('Agg') + +matplotlib.use("Agg") import matplotlib.pyplot as plt # noqa @@ -33,9 +34,8 @@ HISTOGRAM_RCODE_MAX = 23 def load_data( - txn: lmdb.Transaction, - dnsreplies_factory: DNSRepliesFactory - ) -> Dict[ResolverID, List[Tuple[float, Optional[int]]]]: + txn: lmdb.Transaction, dnsreplies_factory: DNSRepliesFactory +) -> Dict[ResolverID, List[Tuple[float, Optional[int]]]]: data = {} # type: Dict[ResolverID, List[Tuple[float, Optional[int]]]] cursor = txn.cursor() for value in cursor.iternext(keys=False, 
values=True): @@ -45,17 +45,15 @@ def load_data( # 12 is chosen to be consistent with dnspython's ShortHeader exception rcode = None else: - (flags, ) = struct.unpack('!H', reply.wire[2:4]) - rcode = flags & 0x000f + (flags,) = struct.unpack("!H", reply.wire[2:4]) + rcode = flags & 0x000F data.setdefault(resolver, []).append((reply.time, rcode)) return data def plot_log_percentile_histogram( - data: Dict[str, List[float]], - title: str, - config=None - ) -> None: + data: Dict[str, List[float]], title: str, config=None +) -> None: """ For graph explanation, see https://blog.powerdns.com/2017/11/02/dns-performance-metrics-the-logarithmic-percentile-histogram/ @@ -66,55 +64,59 @@ def plot_log_percentile_histogram( # Distribute sample points along logarithmic X axis percentiles = np.logspace(-3, 2, num=100) - ax.set_xscale('log') - ax.xaxis.set_major_formatter(mtick.FormatStrFormatter('%s')) - ax.set_yscale('log') - ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%s')) + ax.set_xscale("log") + ax.xaxis.set_major_formatter(mtick.FormatStrFormatter("%s")) + ax.set_yscale("log") + ax.yaxis.set_major_formatter(mtick.FormatStrFormatter("%s")) - ax.grid(True, which='major') - ax.grid(True, which='minor', linestyle='dotted', color='#DDDDDD') + ax.grid(True, which="major") + ax.grid(True, which="minor", linestyle="dotted", color="#DDDDDD") - ax.set_xlabel('Slowest percentile') - ax.set_ylabel('Response time [ms]') - ax.set_title('Resolver Response Time' + " - " + title) + ax.set_xlabel("Slowest percentile") + ax.set_ylabel("Response time [ms]") + ax.set_title("Resolver Response Time" + " - " + title) # plot data for server in sorted(data): if data[server]: try: - color = config[server]['graph_color'] + color = config[server]["graph_color"] except (KeyError, TypeError): color = None # convert to ms and sort values = sorted([1000 * x for x in data[server]], reverse=True) - ax.plot(percentiles, - [values[math.ceil(pctl * len(values) / 100) - 1] for pctl in percentiles], lw=2, - label='{:<10}'.format(server) + " " + '{:9d}'.format(len(values)), color=color) + ax.plot( + percentiles, + [ + values[math.ceil(pctl * len(values) / 100) - 1] + for pctl in percentiles + ], + lw=2, + label="{:<10}".format(server) + " " + "{:9d}".format(len(values)), + color=color, + ) plt.legend() def create_histogram( - data: Dict[str, List[float]], - filename: str, - title: str, - config=None - ) -> None: + data: Dict[str, List[float]], filename: str, title: str, config=None +) -> None: # don't plot graphs which don't contain any finite time - if any(any(time < float('+inf') for time in d) for d in data.values()): + if any(any(time < float("+inf") for time in d) for d in data.values()): plot_log_percentile_histogram(data, title, config) # save to file plt.savefig(filename, dpi=300) def histogram_by_rcode( - data: Dict[ResolverID, List[Tuple[float, Optional[int]]]], - filename: str, - title: str, - config=None, - rcode: Optional[int] = None - ) -> None: + data: Dict[ResolverID, List[Tuple[float, Optional[int]]]], + filename: str, + title: str, + config=None, + rcode: Optional[int] = None, +) -> None: def same_rcode(value: Tuple[float, Optional[int]]) -> bool: if rcode is None: if value[1] is None: @@ -125,26 +127,43 @@ def histogram_by_rcode( filtered_by_rcode = { resolver: [time for (time, rc) in filter(same_rcode, values)] - for (resolver, values) in data.items()} + for (resolver, values) in data.items() + } create_histogram(filtered_by_rcode, filename, title, config) def main(): cli.setup_logging() parser = 
argparse.ArgumentParser( - description='Plot query response time histogram from answers stored ' - 'in LMDB') - parser.add_argument('-o', '--output', type=str, default='histogram', - help='output directory for image files (default: histogram)') - parser.add_argument('-f', '--format', type=str, default='png', - help='output image format (default: png)') - parser.add_argument('-c', '--config', default='respdiff.cfg', dest='cfgpath', - help='config file (default: respdiff.cfg)') - parser.add_argument('envdir', type=str, - help='LMDB environment to read answers from') + description="Plot query response time histogram from answers stored " "in LMDB" + ) + parser.add_argument( + "-o", + "--output", + type=str, + default="histogram", + help="output directory for image files (default: histogram)", + ) + parser.add_argument( + "-f", + "--format", + type=str, + default="png", + help="output image format (default: png)", + ) + parser.add_argument( + "-c", + "--config", + default="respdiff.cfg", + dest="cfgpath", + help="config file (default: respdiff.cfg)", + ) + parser.add_argument( + "envdir", type=str, help="LMDB environment to read answers from" + ) args = parser.parse_args() config = cfg.read_cfg(args.cfgpath) - servers = config['servers']['names'] + servers = config["servers"]["names"] dnsreplies_factory = DNSRepliesFactory(servers) with LMDB(args.envdir, readonly=True) as lmdb_: @@ -160,12 +179,16 @@ def main(): data = load_data(txn, dnsreplies_factory) def get_filepath(filename) -> str: - return os.path.join(args.output, filename + '.' + args.format) + return os.path.join(args.output, filename + "." + args.format) if not os.path.exists(args.output): os.makedirs(args.output) - create_histogram({k: [tup[0] for tup in d] for (k, d) in data.items()}, - get_filepath('all'), 'all', config) + create_histogram( + {k: [tup[0] for tup in d] for (k, d) in data.items()}, + get_filepath("all"), + "all", + config, + ) # rcode-specific queries with pool.Pool() as p: @@ -175,9 +198,9 @@ def main(): filepath = get_filepath(rcode_text) fargs.append((data, filepath, rcode_text, config, rcode)) p.starmap(histogram_by_rcode, fargs) - filepath = get_filepath('unparsed') - histogram_by_rcode(data, filepath, 'unparsed queries', config, None) + filepath = get_filepath("unparsed") + histogram_by_rcode(data, filepath, "unparsed queries", config, None) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/msgdiff.py b/msgdiff.py index 10d7b68b8b3121dbff323a85de35a5f477c5d522..43d95a288f757639c10125d09d257d77599cf7ff 100755 --- a/msgdiff.py +++ b/msgdiff.py @@ -10,7 +10,12 @@ from typing import Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple # no from respdiff import cli from respdiff.dataformat import ( - DiffReport, Disagreements, DisagreementsCounter, FieldLabel, QID) + DiffReport, + Disagreements, + DisagreementsCounter, + FieldLabel, + QID, +) from respdiff.database import DNSRepliesFactory, DNSReply, key2qid, LMDB, MetaDatabase from respdiff.match import compare from respdiff.typing import ResolverID @@ -20,9 +25,8 @@ lmdb = None def read_answers_lmdb( - dnsreplies_factory: DNSRepliesFactory, - qid: QID - ) -> Mapping[ResolverID, DNSReply]: + dnsreplies_factory: DNSRepliesFactory, qid: QID +) -> Mapping[ResolverID, DNSReply]: assert lmdb is not None, "LMDB wasn't initialized!" 
adb = lmdb.get_db(LMDB.ANSWERS) with lmdb.env.begin(adb) as txn: @@ -32,11 +36,11 @@ def read_answers_lmdb( def compare_lmdb_wrapper( - criteria: Sequence[FieldLabel], - target: ResolverID, - dnsreplies_factory: DNSRepliesFactory, - qid: QID - ) -> None: + criteria: Sequence[FieldLabel], + target: ResolverID, + dnsreplies_factory: DNSRepliesFactory, + qid: QID, +) -> None: assert lmdb is not None, "LMDB wasn't initialized!" answers = read_answers_lmdb(dnsreplies_factory, qid) others_agree, target_diffs = compare(answers, criteria, target) @@ -69,11 +73,13 @@ def export_json(filename: str, report: DiffReport): # NOTE: msgdiff is the first tool in the toolchain to generate report.json # thus it doesn't make sense to re-use existing report.json file if os.path.exists(filename): - backup_filename = filename + '.bak' + backup_filename = filename + ".bak" os.rename(filename, backup_filename) logging.warning( - 'JSON report already exists, overwriting file. Original ' - 'file backed up as %s', backup_filename) + "JSON report already exists, overwriting file. Original " + "file backed up as %s", + backup_filename, + ) report.export_json(filename) @@ -81,18 +87,14 @@ def prepare_report(lmdb_, servers: Sequence[ResolverID]) -> DiffReport: qdb = lmdb_.open_db(LMDB.QUERIES) adb = lmdb_.open_db(LMDB.ANSWERS) with lmdb_.env.begin() as txn: - total_queries = txn.stat(qdb)['entries'] - total_answers = txn.stat(adb)['entries'] + total_queries = txn.stat(qdb)["entries"] + total_answers = txn.stat(adb)["entries"] meta = MetaDatabase(lmdb_, servers) start_time = meta.read_start_time() end_time = meta.read_end_time() - return DiffReport( - start_time, - end_time, - total_queries, - total_answers) + return DiffReport(start_time, end_time, total_queries, total_answers) def main(): @@ -100,16 +102,17 @@ def main(): cli.setup_logging() parser = argparse.ArgumentParser( - description='compute diff from answers stored in LMDB and write diffs to LMDB') + description="compute diff from answers stored in LMDB and write diffs to LMDB" + ) cli.add_arg_envdir(parser) cli.add_arg_config(parser) cli.add_arg_datafile(parser) args = parser.parse_args() datafile = cli.get_datafile(args, check_exists=False) - criteria = args.cfg['diff']['criteria'] - target = args.cfg['diff']['target'] - servers = args.cfg['servers']['names'] + criteria = args.cfg["diff"]["criteria"] + target = args.cfg["diff"]["target"] + servers = args.cfg["servers"]["names"] with LMDB(args.envdir) as lmdb_: # NOTE: To avoid an lmdb.BadRslotError, probably caused by weird @@ -126,12 +129,13 @@ def main(): dnsreplies_factory = DNSRepliesFactory(servers) compare_func = partial( - compare_lmdb_wrapper, criteria, target, dnsreplies_factory) + compare_lmdb_wrapper, criteria, target, dnsreplies_factory + ) with pool.Pool() as p: for _ in p.imap_unordered(compare_func, qid_stream, chunksize=10): pass export_json(datafile, report) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/orchestrator.py b/orchestrator.py index bbbfb4b8149a96ceaa5e046c899a12e223e161f4..3ee3ee189ca6dce234804e79f6700a5d84fc1a1c 100755 --- a/orchestrator.py +++ b/orchestrator.py @@ -12,18 +12,22 @@ from respdiff.database import LMDB, MetaDatabase def main(): cli.setup_logging() parser = argparse.ArgumentParser( - description='read queries from LMDB, send them in parallel to servers ' - 'listed in configuration file, and record answers into LMDB') + description="read queries from LMDB, send them in parallel to servers " + "listed in configuration file, and record answers 
into LMDB" + ) cli.add_arg_envdir(parser) cli.add_arg_config(parser) - parser.add_argument('--ignore-timeout', action="store_true", - help='continue despite consecutive timeouts from resolvers') + parser.add_argument( + "--ignore-timeout", + action="store_true", + help="continue despite consecutive timeouts from resolvers", + ) args = parser.parse_args() sendrecv.module_init(args) with LMDB(args.envdir) as lmdb: - meta = MetaDatabase(lmdb, args.cfg['servers']['names'], create=True) + meta = MetaDatabase(lmdb, args.cfg["servers"]["names"], create=True) meta.write_version() meta.write_start_time() @@ -35,24 +39,25 @@ def main(): try: # process queries in parallel with pool.Pool( - processes=args.cfg['sendrecv']['jobs'], - initializer=sendrecv.worker_init) as p: + processes=args.cfg["sendrecv"]["jobs"], initializer=sendrecv.worker_init + ) as p: i = 0 for qkey, blob in p.imap_unordered( - sendrecv.worker_perform_query, qstream, chunksize=100): + sendrecv.worker_perform_query, qstream, chunksize=100 + ): i += 1 if i % 10000 == 0: - logging.info('Received {:d} answers'.format(i)) + logging.info("Received {:d} answers".format(i)) txn.put(qkey, blob) except KeyboardInterrupt: - logging.info('SIGINT received, exiting...') + logging.info("SIGINT received, exiting...") sys.exit(130) except RuntimeError as err: logging.error(err) sys.exit(1) finally: # attempt to preserve data if something went wrong (or not) - logging.debug('Comitting LMDB transaction...') + logging.debug("Comitting LMDB transaction...") txn.commit() meta.write_end_time() diff --git a/pylintrc b/pylintrc index 4c507730db437471c8916d3832496dc54afa3f51..e0c0b5697b90e78093c92944407e6a26a45f93c9 100644 --- a/pylintrc +++ b/pylintrc @@ -12,7 +12,6 @@ disable= line-too-long, # checked by flake8 invalid-name, broad-except, - bad-continuation, global-statement, no-else-return, duplicate-code, diff --git a/qexport.py b/qexport.py index 0198be68c2caad47a8cfa0586ff9fae93be6182e..85d85899c659dbe1d93947657611da364173d923 100755 --- a/qexport.py +++ b/qexport.py @@ -14,28 +14,29 @@ from respdiff.typing import QID def get_qids_to_export( - args: argparse.Namespace, - reports: Sequence[DiffReport] - ) -> Set[QID]: + args: argparse.Namespace, reports: Sequence[DiffReport] +) -> Set[QID]: qids = set() # type: Set[QID] for report in reports: if args.failing: if report.summary is None: raise ValueError( - "Report {} is missing summary!".format(report.fileorigin)) + "Report {} is missing summary!".format(report.fileorigin) + ) failing_qids = set(report.summary.keys()) qids.update(failing_qids) if args.unstable: if report.other_disagreements is None: raise ValueError( - "Report {} is missing other disagreements!".format(report.fileorigin)) + "Report {} is missing other disagreements!".format( + report.fileorigin + ) + ) unstable_qids = report.other_disagreements.queries qids.update(unstable_qids) if args.qidlist: - with open(args.qidlist, encoding='UTF-8') as qidlist_file: - qids.update(int(qid.strip()) - for qid in qidlist_file - if qid.strip()) + with open(args.qidlist, encoding="UTF-8") as qidlist_file: + qids.update(int(qid.strip()) for qid in qidlist_file if qid.strip()) return qids @@ -49,7 +50,7 @@ def export_qids_to_qname_qtype(qids: Set[QID], lmdb, file=sys.stdout): try: query = qwire_to_qname_qtype(qwire) except ValueError as exc: - logging.debug('Omitting QID %d from export: %s', qid, exc) + logging.debug("Omitting QID %d from export: %s", qid, exc) else: print(query, file=file) @@ -60,7 +61,7 @@ def export_qids_to_qname(qids: Set[QID], lmdb, 
file=sys.stdout): try: qname = qwire_to_qname(qwire) except ValueError as exc: - logging.debug('Omitting QID %d from export: %s', qid, exc) + logging.debug("Omitting QID %d from export: %s", qid, exc) else: if qname not in domains: print(qname, file=file) @@ -71,36 +72,51 @@ def export_qids_to_base64url(qids: Set[QID], lmdb, file=sys.stdout): wires = set() # type: Set[bytes] for _, qwire in get_query_iterator(lmdb, qids): if qwire not in wires: - print(base64.urlsafe_b64encode(qwire).decode('ascii'), file=file) + print(base64.urlsafe_b64encode(qwire).decode("ascii"), file=file) wires.add(qwire) def main(): cli.setup_logging() - parser = argparse.ArgumentParser(description="export queries from reports' summaries") - cli.add_arg_report_filename(parser, nargs='+') - parser.add_argument('--envdir', type=str, - help="LMDB environment (required when output format isn't 'qid')") - parser.add_argument('-f', '--format', type=str, choices=['query', 'qid', 'domain', 'base64url'], - default='domain', help="output data format") - parser.add_argument('-o', '--output', type=str, help='output file') - parser.add_argument('--failing', action='store_true', help="get target disagreements") - parser.add_argument('--unstable', action='store_true', help="get upstream unstable") - parser.add_argument('--qidlist', type=str, help='path to file with list of QIDs to export') + parser = argparse.ArgumentParser( + description="export queries from reports' summaries" + ) + cli.add_arg_report_filename(parser, nargs="+") + parser.add_argument( + "--envdir", + type=str, + help="LMDB environment (required when output format isn't 'qid')", + ) + parser.add_argument( + "-f", + "--format", + type=str, + choices=["query", "qid", "domain", "base64url"], + default="domain", + help="output data format", + ) + parser.add_argument("-o", "--output", type=str, help="output file") + parser.add_argument( + "--failing", action="store_true", help="get target disagreements" + ) + parser.add_argument("--unstable", action="store_true", help="get upstream unstable") + parser.add_argument( + "--qidlist", type=str, help="path to file with list of QIDs to export" + ) args = parser.parse_args() - if args.format != 'qid' and not args.envdir: + if args.format != "qid" and not args.envdir: logging.critical("--envdir required when output format isn't 'qid'") sys.exit(1) if not args.failing and not args.unstable and not args.qidlist: - logging.critical('No filter selected!') + logging.critical("No filter selected!") sys.exit(1) reports = cli.get_reports_from_filenames(args) if not reports: - logging.critical('No reports found!') + logging.critical("No reports found!") sys.exit(1) try: @@ -110,21 +126,21 @@ def main(): sys.exit(1) with cli.smart_open(args.output) as fh: - if args.format == 'qid': + if args.format == "qid": export_qids(qids, fh) else: with LMDB(args.envdir, readonly=True) as lmdb: lmdb.open_db(LMDB.QUERIES) - if args.format == 'query': + if args.format == "query": export_qids_to_qname_qtype(qids, lmdb, fh) - elif args.format == 'domain': + elif args.format == "domain": export_qids_to_qname(qids, lmdb, fh) - elif args.format == 'base64url': + elif args.format == "base64url": export_qids_to_base64url(qids, lmdb, fh) else: - raise ValueError('unsupported output format') + raise ValueError("unsupported output format") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/qprep.py b/qprep.py index bedd7bf2215aa988c57f621c85ae451abce4dfab..6f94c972277f09b472dafd64fab800813232ec86 100755 --- a/qprep.py +++ b/qprep.py @@ 
-26,7 +26,7 @@ def read_lines(instream): i = 1 for line in instream: if i % REPORT_CHUNKS == 0: - logging.info('Read %d lines', i) + logging.info("Read %d lines", i) line = line.strip() if line: yield (i, line, line) @@ -62,14 +62,14 @@ def parse_pcap(pcap_file): pcap_file = dpkt.pcap.Reader(pcap_file) for _, frame in pcap_file: if i % REPORT_CHUNKS == 0: - logging.info('Read %d frames', i) - yield (i, frame, 'frame no. {}'.format(i)) + logging.info("Read %d frames", i) + yield (i, frame, "frame no. {}".format(i)) i += 1 def wrk_process_line( - args: Tuple[int, str, str] - ) -> Tuple[Optional[int], Optional[bytes]]: + args: Tuple[int, str, str] +) -> Tuple[Optional[int], Optional[bytes]]: """ Worker: parse input line, creates a packet in binary format @@ -80,16 +80,19 @@ def wrk_process_line( try: msg = msg_from_text(line) if blacklist.is_blacklisted(msg): - logging.debug('Blacklisted query "%s", skipping QID %d', - log_repr, qid) + logging.debug('Blacklisted query "%s", skipping QID %d', log_repr, qid) return None, None return qid, msg.to_wire() except (ValueError, struct.error, dns.exception.DNSException) as ex: - logging.error('Invalid query specification "%s": %s, skipping QID %d', line, ex, qid) + logging.error( + 'Invalid query specification "%s": %s, skipping QID %d', line, ex, qid + ) return None, None -def wrk_process_frame(args: Tuple[int, bytes, str]) -> Tuple[Optional[int], Optional[bytes]]: +def wrk_process_frame( + args: Tuple[int, bytes, str] +) -> Tuple[Optional[int], Optional[bytes]]: """ Worker: convert packet from pcap to binary data """ @@ -99,10 +102,8 @@ def wrk_process_frame(args: Tuple[int, bytes, str]) -> Tuple[Optional[int], Opti def wrk_process_wire_packet( - qid: int, - wire_packet: bytes, - log_repr: str - ) -> Tuple[Optional[int], Optional[bytes]]: + qid: int, wire_packet: bytes, log_repr: str +) -> Tuple[Optional[int], Optional[bytes]]: """ Worker: Return packet's data if it's not blacklisted @@ -117,8 +118,7 @@ def wrk_process_wire_packet( pass else: if blacklist.is_blacklisted(msg): - logging.debug('Blacklisted query "%s", skipping QID %d', - log_repr, qid) + logging.debug('Blacklisted query "%s", skipping QID %d', log_repr, qid) return None, None return qid, wire_packet @@ -140,11 +140,14 @@ def msg_from_text(text): try: qname, qtype = text.split() except ValueError as e: - raise ValueError('space is only allowed as separator between qname qtype') from e - qname = dns.name.from_text(qname.encode('ascii')) + raise ValueError( + "space is only allowed as separator between qname qtype" + ) from e + qname = dns.name.from_text(qname.encode("ascii")) qtype = int_or_fromtext(qtype, dns.rdatatype.from_text) - msg = dns.message.make_query(qname, qtype, dns.rdataclass.IN, - want_dnssec=True, payload=4096) + msg = dns.message.make_query( + qname, qtype, dns.rdataclass.IN, want_dnssec=True, payload=4096 + ) return msg @@ -152,22 +155,31 @@ def main(): cli.setup_logging() parser = argparse.ArgumentParser( formatter_class=argparse.RawTextHelpFormatter, - description='Convert queries data from standard input and store ' - 'wire format into LMDB "queries" DB.') + description="Convert queries data from standard input and store " + 'wire format into LMDB "queries" DB.', + ) cli.add_arg_envdir(parser) - parser.add_argument('-f', '--in-format', type=str, choices=['text', 'pcap'], default='text', - help='define format for input data, default value is text\n' - 'Expected input for "text" is: "<qname> <RR type>", ' - 'one query per line.\n' - 'Expected input for "pcap" is 
content of the pcap file.') - parser.add_argument('--pcap-file', type=argparse.FileType('rb')) + parser.add_argument( + "-f", + "--in-format", + type=str, + choices=["text", "pcap"], + default="text", + help="define format for input data, default value is text\n" + 'Expected input for "text" is: "<qname> <RR type>", ' + "one query per line.\n" + 'Expected input for "pcap" is content of the pcap file.', + ) + parser.add_argument("--pcap-file", type=argparse.FileType("rb")) args = parser.parse_args() - if args.in_format == 'text' and args.pcap_file: - logging.critical("Argument --pcap-file can be use only in combination with -f pcap") + if args.in_format == "text" and args.pcap_file: + logging.critical( + "Argument --pcap-file can be use only in combination with -f pcap" + ) sys.exit(1) - if args.in_format == 'pcap' and not args.pcap_file: + if args.in_format == "pcap" and not args.pcap_file: logging.critical("Missing path to pcap file, use argument --pcap-file") sys.exit(1) @@ -176,12 +188,12 @@ def main(): txn = lmdb.env.begin(qdb, write=True) try: with pool.Pool( - initializer=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN) - ) as workers: - if args.in_format == 'text': + initializer=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN) + ) as workers: + if args.in_format == "text": data_stream = read_lines(sys.stdin) method = wrk_process_line - elif args.in_format == 'pcap': + elif args.in_format == "pcap": data_stream = parse_pcap(args.pcap_file) method = wrk_process_frame else: @@ -192,16 +204,16 @@ def main(): key = qid2key(qid) txn.put(key, wire) except KeyboardInterrupt: - logging.info('SIGINT received, exiting...') + logging.info("SIGINT received, exiting...") sys.exit(130) except RuntimeError as err: logging.error(err) sys.exit(1) finally: # attempt to preserve data if something went wrong (or not) - logging.debug('Comitting LMDB transaction...') + logging.debug("Comitting LMDB transaction...") txn.commit() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/respdiff/__init__.py b/respdiff/__init__.py index 65432d4a178f773e72028fa9882328c3daf134cc..797e22d45c60dc917b8f2ab40aef785e1cb276a4 100644 --- a/respdiff/__init__.py +++ b/respdiff/__init__.py @@ -2,7 +2,7 @@ import sys # LMDB isn't portable between LE/BE platforms # upon import, check we're on a little endian platform -assert sys.byteorder == 'little', 'Big endian platforms are not supported' +assert sys.byteorder == "little", "Big endian platforms are not supported" # Check minimal Python version -assert sys.version_info >= (3, 5, 2), 'Minimal supported Python version is 3.5.2' +assert sys.version_info >= (3, 5, 2), "Minimal supported Python version is 3.5.2" diff --git a/respdiff/blacklist.py b/respdiff/blacklist.py index e050c56e7af5a2432b8bd55dc01ce0f47f5e83c4..a2c757c91b7107053460a53d468d9c9bdb258c8c 100644 --- a/respdiff/blacklist.py +++ b/respdiff/blacklist.py @@ -2,19 +2,37 @@ import dns from dns.message import Message -_BLACKLIST_SUBDOMAINS = [dns.name.from_text(name) for name in [ - 'local.', - # dotnxdomain.net and dashnxdomain.net are used by APNIC for ephemeral - # single-query tests so there is no point in asking these repeatedly - 'dotnxdomain.net.', 'dashnxdomain.net.', - # Some malware names from shadowserver.org - # Asking them repeatedly may give impression that we're infected. 
- 'ak-is.net', 'boceureid.com', 'dkfdbbtmeeow.net', 'dssjistyyqhs.online', - 'gowwmcawjbuwrvrwc.com', 'hhjxobtybumg.pw', 'iqodqxy.info', 'k5service.org', - 'kibot.pw', 'mdtrlvfjppreyv.in', 'oxekbtumwc.info', 'raz4756pi7zop.com', - 'rbqfe.net', 'theyardservice.com', 'uwnlzvl.info', 'weg11uvrym3.com', - 'winsrw.com', 'xfwnbk.info', 'xmrimtl.net', -]] +_BLACKLIST_SUBDOMAINS = [ + dns.name.from_text(name) + for name in [ + "local.", + # dotnxdomain.net and dashnxdomain.net are used by APNIC for ephemeral + # single-query tests so there is no point in asking these repeatedly + "dotnxdomain.net.", + "dashnxdomain.net.", + # Some malware names from shadowserver.org + # Asking them repeatedly may give impression that we're infected. + "ak-is.net", + "boceureid.com", + "dkfdbbtmeeow.net", + "dssjistyyqhs.online", + "gowwmcawjbuwrvrwc.com", + "hhjxobtybumg.pw", + "iqodqxy.info", + "k5service.org", + "kibot.pw", + "mdtrlvfjppreyv.in", + "oxekbtumwc.info", + "raz4756pi7zop.com", + "rbqfe.net", + "theyardservice.com", + "uwnlzvl.info", + "weg11uvrym3.com", + "winsrw.com", + "xfwnbk.info", + "xmrimtl.net", + ] +] def is_blacklisted(dnsmsg: Message) -> bool: @@ -23,7 +41,7 @@ def is_blacklisted(dnsmsg: Message) -> bool: """ try: flags = dns.flags.to_text(dnsmsg.flags).split() - if 'QR' in flags: # not a query + if "QR" in flags: # not a query return True if len(dnsmsg.question) != 1: # weird but valid packet (maybe DNS Cookies) @@ -32,8 +50,7 @@ def is_blacklisted(dnsmsg: Message) -> bool: # there is not standard describing common behavior for ANY/RRSIG query if question.rdtype in {dns.rdatatype.ANY, dns.rdatatype.RRSIG}: return True - return any(question.name.is_subdomain(name) - for name in _BLACKLIST_SUBDOMAINS) + return any(question.name.is_subdomain(name) for name in _BLACKLIST_SUBDOMAINS) except Exception: # weird stuff, it's better to test resolver with this as well! 
return False diff --git a/respdiff/cfg.py b/respdiff/cfg.py index 60bb05026ed3b29057cae5786455005535492b74..87acab89fb4d387a5a9fcc0b1cc3ab04e2d5fa08 100644 --- a/respdiff/cfg.py +++ b/respdiff/cfg.py @@ -12,8 +12,20 @@ import dns.inet ALL_FIELDS = [ - 'timeout', 'malformed', 'opcode', 'question', 'rcode', 'flags', 'answertypes', - 'answerrrsigs', 'answer', 'authority', 'additional', 'edns', 'nsid'] + "timeout", + "malformed", + "opcode", + "question", + "rcode", + "flags", + "answertypes", + "answerrrsigs", + "answer", + "authority", + "additional", + "edns", + "nsid", +] ALL_FIELDS_SET = set(ALL_FIELDS) @@ -30,45 +42,45 @@ def comma_list(lstr): """ Split string 'a, b' into list [a, b] """ - return [name.strip() for name in lstr.split(',')] + return [name.strip() for name in lstr.split(",")] def transport_opt(ostr): - if ostr not in {'udp', 'tcp', 'tls'}: - raise ValueError('unsupported transport') + if ostr not in {"udp", "tcp", "tls"}: + raise ValueError("unsupported transport") return ostr # declarative config format description for always-present sections # dict structure: dict[section name][key name] = (type, required) _CFGFMT = { - 'sendrecv': { - 'timeout': (float, True), - 'jobs': (int, True), - 'time_delay_min': (float, True), - 'time_delay_max': (float, True), - 'max_timeouts': (int, False), + "sendrecv": { + "timeout": (float, True), + "jobs": (int, True), + "time_delay_min": (float, True), + "time_delay_max": (float, True), + "max_timeouts": (int, False), }, - 'servers': { - 'names': (comma_list, True), + "servers": { + "names": (comma_list, True), }, - 'diff': { - 'target': (str, True), - 'criteria': (comma_list, True), + "diff": { + "target": (str, True), + "criteria": (comma_list, True), }, - 'report': { - 'field_weights': (comma_list, True), + "report": { + "field_weights": (comma_list, True), }, } # declarative config format description for per-server section # dict structure: dict[key name] = type _CFGFMT_SERVER = { - 'ip': (ipaddr_check, True), - 'port': (int, True), - 'transport': (transport_opt, True), - 'graph_color': (str, False), - 'restart_script': (str, False), + "ip": (ipaddr_check, True), + "port": (int, True), + "transport": (transport_opt, True), + "graph_color": (str, False), + "restart_script": (str, False), } @@ -87,20 +99,25 @@ def cfg2dict_convert(fmt, cparser): for valname, (valfmt, valreq) in sectfmt.items(): try: if not cparser[sectname][valname].strip(): - raise ValueError('empty values are not allowed') + raise ValueError("empty values are not allowed") sectdict[valname] = valfmt(cparser[sectname][valname]) except ValueError as ex: - raise ValueError('config section [{}] key "{}" has invalid format: ' - '{}; expected format: {}'.format( - sectname, valname, ex, valfmt)) from ex + raise ValueError( + 'config section [{}] key "{}" has invalid format: ' + "{}; expected format: {}".format(sectname, valname, ex, valfmt) + ) from ex except KeyError as ex: if valreq: - raise KeyError('config section [{}] key "{}" not found'.format( - sectname, valname)) from ex + raise KeyError( + 'config section [{}] key "{}" not found'.format( + sectname, valname + ) + ) from ex unsupported_keys = set(cparser[sectname].keys()) - set(sectfmt.keys()) if unsupported_keys: - raise ValueError('unexpected keys {} in section [{}]'.format( - unsupported_keys, sectname)) + raise ValueError( + "unexpected keys {} in section [{}]".format(unsupported_keys, sectname) + ) return cdict @@ -109,38 +126,53 @@ def cfg2dict_check_sect(fmt, cfg): Check non-existence of unhandled config 
sections. """ supported_sections = set(fmt.keys()) - present_sections = set(cfg.keys()) - {'DEFAULT'} + present_sections = set(cfg.keys()) - {"DEFAULT"} unsupported_sections = present_sections - supported_sections if unsupported_sections: - raise ValueError('unexpected config sections {}'.format( - ', '.join('[{}]'.format(sn) for sn in unsupported_sections))) + raise ValueError( + "unexpected config sections {}".format( + ", ".join("[{}]".format(sn) for sn in unsupported_sections) + ) + ) def cfg2dict_check_diff(cdict): """ Check if diff target is listed among servers. """ - if cdict['diff']['target'] not in cdict['servers']['names']: - raise ValueError('[diff] target value "{}" must be listed in [servers] names'.format( - cdict['diff']['target'])) + if cdict["diff"]["target"] not in cdict["servers"]["names"]: + raise ValueError( + '[diff] target value "{}" must be listed in [servers] names'.format( + cdict["diff"]["target"] + ) + ) def cfg2dict_check_fields(cdict): """Check if all fields are known and that all have a weight assigned""" - unknown_criteria = set(cdict['diff']['criteria']) - ALL_FIELDS_SET + unknown_criteria = set(cdict["diff"]["criteria"]) - ALL_FIELDS_SET if unknown_criteria: - raise ValueError('[diff] criteria: unknown fields: {}'.format( - ', '.join(['"{}"'.format(field) for field in unknown_criteria]))) + raise ValueError( + "[diff] criteria: unknown fields: {}".format( + ", ".join(['"{}"'.format(field) for field in unknown_criteria]) + ) + ) - unknown_field_weights = set(cdict['report']['field_weights']) - ALL_FIELDS_SET + unknown_field_weights = set(cdict["report"]["field_weights"]) - ALL_FIELDS_SET if unknown_field_weights: - raise ValueError('[report] field_weights: unknown fields: {}'.format( - ', '.join(['"{}"'.format(field) for field in unknown_field_weights]))) + raise ValueError( + "[report] field_weights: unknown fields: {}".format( + ", ".join(['"{}"'.format(field) for field in unknown_field_weights]) + ) + ) - missing_field_weights = ALL_FIELDS_SET - set(cdict['report']['field_weights']) + missing_field_weights = ALL_FIELDS_SET - set(cdict["report"]["field_weights"]) if missing_field_weights: - raise ValueError('[report] field_weights: missing fields: {}'.format( - ', '.join(['"{}"'.format(field) for field in missing_field_weights]))) + raise ValueError( + "[report] field_weights: missing fields: {}".format( + ", ".join(['"{}"'.format(field) for field in missing_field_weights]) + ) + ) def read_cfg(filename): @@ -155,10 +187,11 @@ def read_cfg(filename): try: parser = configparser.ConfigParser( - delimiters='=', - comment_prefixes='#', + delimiters="=", + comment_prefixes="#", interpolation=None, - empty_lines_in_values=False) + empty_lines_in_values=False, + ) parser.read(filename) # parse things which must be present @@ -166,7 +199,7 @@ def read_cfg(filename): # parse variable server-specific data cfgfmt_servers = _CFGFMT.copy() - for server in cdict['servers']['names']: + for server in cdict["servers"]["names"]: cfgfmt_servers[server] = _CFGFMT_SERVER cdict = cfg2dict_convert(cfgfmt_servers, parser) @@ -177,13 +210,13 @@ def read_cfg(filename): # check fields (criteria, field_weights) cfg2dict_check_fields(cdict) except Exception as exc: - logging.critical('Failed to parse config: %s', exc) + logging.critical("Failed to parse config: %s", exc) raise ValueError(exc) from exc return cdict -if __name__ == '__main__': +if __name__ == "__main__": from pprint import pprint import sys diff --git a/respdiff/cli.py b/respdiff/cli.py index 
b9e7ed05c05cd98b0a55f090782244c524c72ef6..6e0a09804b366dedf6ac124d3f961c8abba2fcb8 100644 --- a/respdiff/cli.py +++ b/respdiff/cli.py @@ -20,10 +20,10 @@ ChangeStatsTuple = Tuple[int, Optional[float], Optional[int], Optional[float]] ChangeStatsTupleStr = Tuple[int, Optional[float], Optional[str], Optional[float]] LOGGING_LEVEL = logging.DEBUG -CONFIG_FILENAME = 'respdiff.cfg' -REPORT_FILENAME = 'report.json' -STATS_FILENAME = 'stats.json' -DNSVIZ_FILENAME = 'dnsviz.json' +CONFIG_FILENAME = "respdiff.cfg" +REPORT_FILENAME = "report.json" +STATS_FILENAME = "stats.json" +DNSVIZ_FILENAME = "dnsviz.json" DEFAULT_PRINT_QUERY_LIMIT = 10 @@ -36,7 +36,7 @@ def read_stats(filename: str) -> SummaryStatistics: def _handle_empty_report(exc: Exception, skip_empty: bool): if skip_empty: - logging.debug('%s Omitting...', exc) + logging.debug("%s Omitting...", exc) else: logging.error(str(exc)) raise ValueError(exc) @@ -51,77 +51,107 @@ def read_report(filename: str, skip_empty: bool = False) -> Optional[DiffReport] def load_summaries( - reports: Sequence[DiffReport], - skip_empty: bool = False - ) -> Sequence[Summary]: + reports: Sequence[DiffReport], skip_empty: bool = False +) -> Sequence[Summary]: summaries = [] for report in reports: if report.summary is None: _handle_empty_report( ValueError('Empty diffsum in "{}"!'.format(report.fileorigin)), - skip_empty) + skip_empty, + ) else: summaries.append(report.summary) return summaries def setup_logging(level: int = LOGGING_LEVEL) -> None: - logging.basicConfig(format='%(asctime)s %(levelname)8s %(message)s', level=level) - logger = logging.getLogger('matplotlib') + logging.basicConfig(format="%(asctime)s %(levelname)8s %(message)s", level=level) + logger = logging.getLogger("matplotlib") # set WARNING for Matplotlib logger.setLevel(logging.WARNING) def add_arg_config(parser: ArgumentParser) -> None: - parser.add_argument('-c', '--config', type=read_cfg, - default=CONFIG_FILENAME, dest='cfg', - help='config file (default: {})'.format(CONFIG_FILENAME)) + parser.add_argument( + "-c", + "--config", + type=read_cfg, + default=CONFIG_FILENAME, + dest="cfg", + help="config file (default: {})".format(CONFIG_FILENAME), + ) def add_arg_envdir(parser: ArgumentParser) -> None: - parser.add_argument('envdir', type=str, - help='LMDB environment to read/write queries, answers and diffs') + parser.add_argument( + "envdir", + type=str, + help="LMDB environment to read/write queries, answers and diffs", + ) def add_arg_datafile(parser: ArgumentParser) -> None: - parser.add_argument('-d', '--datafile', type=str, - help='JSON report file (default: <envdir>/{})'.format( - REPORT_FILENAME)) + parser.add_argument( + "-d", + "--datafile", + type=str, + help="JSON report file (default: <envdir>/{})".format(REPORT_FILENAME), + ) def add_arg_limit(parser: ArgumentParser) -> None: - parser.add_argument('-l', '--limit', type=int, - default=DEFAULT_PRINT_QUERY_LIMIT, - help='number of displayed mismatches in fields (default: {}; ' - 'use 0 to display all)'.format(DEFAULT_PRINT_QUERY_LIMIT)) + parser.add_argument( + "-l", + "--limit", + type=int, + default=DEFAULT_PRINT_QUERY_LIMIT, + help="number of displayed mismatches in fields (default: {}; " + "use 0 to display all)".format(DEFAULT_PRINT_QUERY_LIMIT), + ) def add_arg_stats(parser: ArgumentParser) -> None: - parser.add_argument('-s', '--stats', type=read_stats, - default=STATS_FILENAME, - help='statistics file (default: {})'.format(STATS_FILENAME)) - - -def add_arg_stats_filename(parser: ArgumentParser, default: str = 
STATS_FILENAME) -> None: - parser.add_argument('-s', '--stats', type=str, - default=default, dest='stats_filename', - help='statistics file (default: {})'.format(default)) + parser.add_argument( + "-s", + "--stats", + type=read_stats, + default=STATS_FILENAME, + help="statistics file (default: {})".format(STATS_FILENAME), + ) + + +def add_arg_stats_filename( + parser: ArgumentParser, default: str = STATS_FILENAME +) -> None: + parser.add_argument( + "-s", + "--stats", + type=str, + default=default, + dest="stats_filename", + help="statistics file (default: {})".format(default), + ) def add_arg_dnsviz(parser: ArgumentParser, default: str = DNSVIZ_FILENAME) -> None: - parser.add_argument('--dnsviz', type=str, default=default, - help='dnsviz grok output (default: {})'.format(default)) + parser.add_argument( + "--dnsviz", + type=str, + default=default, + help="dnsviz grok output (default: {})".format(default), + ) def add_arg_report(parser: ArgumentParser) -> None: - parser.add_argument('report', type=read_report, nargs='*', - help='JSON report file(s)') + parser.add_argument( + "report", type=read_report, nargs="*", help="JSON report file(s)" + ) -def add_arg_report_filename(parser: ArgumentParser, nargs='*') -> None: - parser.add_argument('report', type=str, nargs=nargs, - help='JSON report file(s)') +def add_arg_report_filename(parser: ArgumentParser, nargs="*") -> None: + parser.add_argument("report", type=str, nargs=nargs, help="JSON report file(s)") def get_reports_from_filenames(args: Namespace) -> Sequence[DiffReport]: @@ -133,7 +163,9 @@ def get_reports_from_filenames(args: Namespace) -> Sequence[DiffReport]: return reports -def get_datafile(args: Namespace, key: str = 'datafile', check_exists: bool = True) -> str: +def get_datafile( + args: Namespace, key: str = "datafile", check_exists: bool = True +) -> str: datafile = getattr(args, key, None) if datafile is None: datafile = os.path.join(args.envdir, REPORT_FILENAME) @@ -155,8 +187,8 @@ def check_metadb_servers_version(lmdb, servers: Sequence[str]) -> None: @contextlib.contextmanager def smart_open(filename=None): - if filename and filename != '-': - fh = open(filename, 'w', encoding='UTF-8') + if filename and filename != "-": + fh = open(filename, "w", encoding="UTF-8") else: fh = sys.stdout @@ -168,47 +200,47 @@ def smart_open(filename=None): def format_stats_line( - description: str, - number: int, - pct: float = None, - diff: int = None, - diff_pct: float = None, - additional: str = None - ) -> str: + description: str, + number: int, + pct: float = None, + diff: int = None, + diff_pct: float = None, + additional: str = None, +) -> str: s = {} # type: Dict[str, str] - s['description'] = '{:21s}'.format(description) - s['number'] = '{:8d}'.format(number) - s['pct'] = '{:6.2f} %'.format(pct) if pct is not None else ' ' * 8 - s['additional'] = '{:30s}'.format(additional) if additional is not None else ' ' * 30 - s['diff'] = '{:+6d}'.format(diff) if diff is not None else ' ' * 6 - s['diff_pct'] = '{:+7.2f} %'.format(diff_pct) if diff_pct is not None else ' ' * 9 + s["description"] = "{:21s}".format(description) + s["number"] = "{:8d}".format(number) + s["pct"] = "{:6.2f} %".format(pct) if pct is not None else " " * 8 + s["additional"] = ( + "{:30s}".format(additional) if additional is not None else " " * 30 + ) + s["diff"] = "{:+6d}".format(diff) if diff is not None else " " * 6 + s["diff_pct"] = "{:+7.2f} %".format(diff_pct) if diff_pct is not None else " " * 9 - return '{description} {number} {pct} {additional} {diff} 
{diff_pct}'.format(**s) + return "{description} {number} {pct} {additional} {diff} {diff_pct}".format(**s) def get_stats_data( - n: int, - total: int = None, - ref_n: int = None, - ) -> ChangeStatsTuple: + n: int, + total: int = None, + ref_n: int = None, +) -> ChangeStatsTuple: """ Return absolute and relative data statistics Optionally, the data is compared with a reference. """ - def percentage( - dividend: Number, - divisor: Optional[Number] - ) -> Optional[float]: + + def percentage(dividend: Number, divisor: Optional[Number]) -> Optional[float]: """Return dividend/divisor value in %""" if divisor is None: return None if divisor == 0: if dividend > 0: - return float('+inf') + return float("+inf") if dividend < 0: - return float('-inf') - return float('nan') + return float("-inf") + return float("nan") return dividend * 100.0 / divisor pct = percentage(n, total) @@ -223,25 +255,20 @@ def get_stats_data( def get_table_stats_row( - count: int, - total: int, - ref_count: Optional[int] = None - ) -> Union[StatsTuple, ChangeStatsTupleStr]: - n, pct, diff, diff_pct = get_stats_data( # type: ignore - count, - total, - ref_count) + count: int, total: int, ref_count: Optional[int] = None +) -> Union[StatsTuple, ChangeStatsTupleStr]: + n, pct, diff, diff_pct = get_stats_data(count, total, ref_count) # type: ignore if ref_count is None: return n, pct - s_diff = '{:+d}'.format(diff) if diff is not None else None + s_diff = "{:+d}".format(diff) if diff is not None else None return n, pct, s_diff, diff_pct def print_fields_overview( - field_counters: Mapping[FieldLabel, Counter], - n_disagreements: int, - ref_field_counters: Optional[Mapping[FieldLabel, Counter]] = None, - ) -> None: + field_counters: Mapping[FieldLabel, Counter], + n_disagreements: int, + ref_field_counters: Optional[Mapping[FieldLabel, Counter]] = None, +) -> None: rows = [] def get_field_count(counter: Counter) -> int: @@ -256,117 +283,164 @@ def print_fields_overview( if ref_field_counters is not None: ref_counter = ref_field_counters.get(field, Counter()) ref_field_count = get_field_count(ref_counter) - rows.append((field, *get_table_stats_row( - field_count, n_disagreements, ref_field_count))) + rows.append( + (field, *get_table_stats_row(field_count, n_disagreements, ref_field_count)) + ) - headers = ['Field', 'Count', '% of mismatches'] + headers = ["Field", "Count", "% of mismatches"] if ref_field_counters is not None: - headers.extend(['Change', 'Change (%)']) + headers.extend(["Change", "Change (%)"]) - print('== Target Disagreements') - print(tabulate( # type: ignore - sorted(rows, key=lambda data: data[1], reverse=True), # type: ignore - headers, - tablefmt='simple', - floatfmt=('s', 'd', '.2f', 's', '+.2f'))) # type: ignore - print('') + print("== Target Disagreements") + print( + tabulate( # type: ignore + sorted(rows, key=lambda data: data[1], reverse=True), # type: ignore + headers, + tablefmt="simple", + floatfmt=("s", "d", ".2f", "s", "+.2f"), + ) + ) # type: ignore + print("") def print_field_mismatch_stats( - field: FieldLabel, - counter: Counter, - n_disagreements: int, - ref_counter: Counter = None - ) -> None: + field: FieldLabel, + counter: Counter, + n_disagreements: int, + ref_counter: Counter = None, +) -> None: rows = [] ref_count = None for mismatch, count in counter.items(): if ref_counter is not None: ref_count = ref_counter[mismatch] - rows.append(( - DataMismatch.format_value(mismatch.exp_val), - DataMismatch.format_value(mismatch.got_val), - *get_table_stats_row( - count, n_disagreements, 
ref_count))) - - headers = ['Expected', 'Got', 'Count', '% of mimatches'] + rows.append( + ( + DataMismatch.format_value(mismatch.exp_val), + DataMismatch.format_value(mismatch.got_val), + *get_table_stats_row(count, n_disagreements, ref_count), + ) + ) + + headers = ["Expected", "Got", "Count", "% of mimatches"] if ref_counter is not None: - headers.extend(['Change', 'Change (%)']) + headers.extend(["Change", "Change (%)"]) print('== Field "{}" mismatch statistics'.format(field)) - print(tabulate( # type: ignore - sorted(rows, key=lambda data: data[2], reverse=True), # type: ignore - headers, - tablefmt='simple', - floatfmt=('s', 's', 'd', '.2f', 's', '+.2f'))) # type: ignore - print('') + print( + tabulate( # type: ignore + sorted(rows, key=lambda data: data[2], reverse=True), # type: ignore + headers, + tablefmt="simple", + floatfmt=("s", "s", "d", ".2f", "s", "+.2f"), + ) + ) # type: ignore + print("") def print_global_stats(report: DiffReport, reference: DiffReport = None) -> None: - ref_duration = getattr(reference, 'duration', None) - ref_total_answers = getattr(reference, 'total_answers', None) - ref_total_queries = getattr(reference, 'total_queries', None) - - if (report.duration is None - or report.total_answers is None - or report.total_queries is None): + ref_duration = getattr(reference, "duration", None) + ref_total_answers = getattr(reference, "total_answers", None) + ref_total_queries = getattr(reference, "total_queries", None) + + if ( + report.duration is None + or report.total_answers is None + or report.total_queries is None + ): raise RuntimeError("Report doesn't containt necassary data!") - print('== Global statistics') - print(format_stats_line('duration', *get_stats_data( - report.duration, ref_n=ref_duration), - additional='seconds')) - print(format_stats_line('queries', *get_stats_data( - report.total_queries, ref_n=ref_total_queries))) - print(format_stats_line('answers', *get_stats_data( - report.total_answers, report.total_queries, - ref_total_answers), - additional='of queries')) - print('') + print("== Global statistics") + print( + format_stats_line( + "duration", + *get_stats_data(report.duration, ref_n=ref_duration), + additional="seconds" + ) + ) + print( + format_stats_line( + "queries", *get_stats_data(report.total_queries, ref_n=ref_total_queries) + ) + ) + print( + format_stats_line( + "answers", + *get_stats_data( + report.total_answers, report.total_queries, ref_total_answers + ), + additional="of queries" + ) + ) + print("") def print_differences_stats(report: DiffReport, reference: DiffReport = None) -> None: - ref_summary = getattr(reference, 'summary', None) - ref_manual_ignore = getattr(ref_summary, 'manual_ignore', None) - ref_upstream_unstable = getattr(ref_summary, 'upstream_unstable', None) - ref_not_reproducible = getattr(ref_summary, 'not_reproducible', None) + ref_summary = getattr(reference, "summary", None) + ref_manual_ignore = getattr(ref_summary, "manual_ignore", None) + ref_upstream_unstable = getattr(ref_summary, "upstream_unstable", None) + ref_not_reproducible = getattr(ref_summary, "not_reproducible", None) ref_target_disagrees = len(ref_summary) if ref_summary is not None else None if report.summary is None: raise RuntimeError("Report doesn't containt necassary data!") - print('== Differences statistics') - print(format_stats_line('manually ignored', *get_stats_data( - report.summary.manual_ignore, report.total_answers, - ref_manual_ignore), - additional='of answers (ignoring)')) - print(format_stats_line('upstream unstable', 
*get_stats_data( - report.summary.upstream_unstable, report.total_answers, - ref_upstream_unstable), - additional='of answers (ignoring)')) - print(format_stats_line('not 100% reproducible', *get_stats_data( - report.summary.not_reproducible, report.total_answers, - ref_not_reproducible), - additional='of answers (ignoring)')) - print(format_stats_line('target disagrees', *get_stats_data( - len(report.summary), report.summary.usable_answers, - ref_target_disagrees), - additional='of not ignored answers')) - print('') + print("== Differences statistics") + print( + format_stats_line( + "manually ignored", + *get_stats_data( + report.summary.manual_ignore, report.total_answers, ref_manual_ignore + ), + additional="of answers (ignoring)" + ) + ) + print( + format_stats_line( + "upstream unstable", + *get_stats_data( + report.summary.upstream_unstable, + report.total_answers, + ref_upstream_unstable, + ), + additional="of answers (ignoring)" + ) + ) + print( + format_stats_line( + "not 100% reproducible", + *get_stats_data( + report.summary.not_reproducible, + report.total_answers, + ref_not_reproducible, + ), + additional="of answers (ignoring)" + ) + ) + print( + format_stats_line( + "target disagrees", + *get_stats_data( + len(report.summary), report.summary.usable_answers, ref_target_disagrees + ), + additional="of not ignored answers" + ) + ) + print("") def print_mismatch_queries( - field: FieldLabel, - mismatch: DataMismatch, - queries: Sequence[Tuple[str, int, str]], - limit: Optional[int] = DEFAULT_PRINT_QUERY_LIMIT - ) -> None: + field: FieldLabel, + mismatch: DataMismatch, + queries: Sequence[Tuple[str, int, str]], + limit: Optional[int] = DEFAULT_PRINT_QUERY_LIMIT, +) -> None: if limit == 0: limit = None def sort_key(data: Tuple[str, int, str]) -> Tuple[int, int]: - order = ['+', ' ', '-'] + order = ["+", " ", "-"] try: return order.index(data[0]), -data[1] except ValueError: @@ -379,13 +453,10 @@ def print_mismatch_queries( to_print = to_print[:limit] print('== Field "{}", mismatch "{}" query details'.format(field, mismatch)) - print(format_line('', 'Count', 'Query')) + print(format_line("", "Count", "Query")) for diff, count, query in to_print: print(format_line(diff, str(count), query)) if limit is not None and limit < len(queries): - print(format_line( - 'x', - str(len(queries) - limit), - 'queries omitted')) - print('') + print(format_line("x", str(len(queries) - limit), "queries omitted")) + print("") diff --git a/respdiff/database.py b/respdiff/database.py index 13abc18dbe2bfdd7a7a6f5c843ffb3900b1f1df0..b78be56ab873f45f116227cfed8f698c0e847050 100644 --- a/respdiff/database.py +++ b/respdiff/database.py @@ -4,7 +4,16 @@ import os import struct import time from typing import ( # noqa - Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Sequence) + Any, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Tuple, + Sequence, +) import dns.exception import dns.message @@ -13,53 +22,56 @@ import lmdb from .typing import ResolverID, QID, QKey, WireFormat -BIN_FORMAT_VERSION = '2018-05-21' +BIN_FORMAT_VERSION = "2018-05-21" def qid2key(qid: QID) -> QKey: - return struct.pack('<I', qid) + return struct.pack("<I", qid) def key2qid(key: QKey) -> QID: - return struct.unpack('<I', key)[0] + return struct.unpack("<I", key)[0] class LMDB: - ANSWERS = b'answers' - DIFFS = b'diffs' - QUERIES = b'queries' - META = b'meta' + ANSWERS = b"answers" + DIFFS = b"diffs" + QUERIES = b"queries" + META = b"meta" ENV_DEFAULTS = { - 'map_size': 10 * 1024**3, # 10 G - 
'max_readers': 384, - 'max_dbs': 5, - 'max_spare_txns': 64, + "map_size": 10 * 1024**3, # 10 G + "max_readers": 384, + "max_dbs": 5, + "max_spare_txns": 64, } # type: Dict[str, Any] DB_OPEN_DEFAULTS = { - 'integerkey': False, + "integerkey": False, # surprisingly, optimal configuration seems to be # native integer as database key *without* # integerkey support in LMDB } # type: Dict[str, Any] - def __init__(self, path: str, create: bool = False, - readonly: bool = False, fast: bool = False) -> None: + def __init__( + self, + path: str, + create: bool = False, + readonly: bool = False, + fast: bool = False, + ) -> None: self.path = path self.dbs = {} # type: Dict[bytes, Any] self.config = LMDB.ENV_DEFAULTS.copy() - self.config.update({ - 'path': path, - 'create': create, - 'readonly': readonly - }) + self.config.update({"path": path, "create": create, "readonly": readonly}) if fast: # unsafe on crashes, but faster - self.config.update({ - 'writemap': True, - 'sync': False, - 'map_async': True, - }) + self.config.update( + { + "writemap": True, + "sync": False, + "map_async": True, + } + ) if not os.path.exists(self.path): os.makedirs(self.path) @@ -71,17 +83,25 @@ class LMDB: def __exit__(self, exc_type, exc_val, exc_tb): self.env.close() - def open_db(self, dbname: bytes, create: bool = False, - check_notexists: bool = False, drop: bool = False): + def open_db( + self, + dbname: bytes, + create: bool = False, + check_notexists: bool = False, + drop: bool = False, + ): assert self.env is not None, "LMDB wasn't initialized!" if not create and not self.exists_db(dbname): msg = 'LMDB environment "{}" does not contain DB "{}"! '.format( - self.path, dbname.decode('utf-8')) + self.path, dbname.decode("utf-8") + ) raise RuntimeError(msg) if check_notexists and self.exists_db(dbname): - msg = ('LMDB environment "{}" already contains DB "{}"! ' - 'Overwritting it would invalidate data in the environment, ' - 'terminating.').format(self.path, dbname.decode('utf-8')) + msg = ( + 'LMDB environment "{}" already contains DB "{}"! ' + "Overwritting it would invalidate data in the environment, " + "terminating." 
+ ).format(self.path, dbname.decode("utf-8")) raise RuntimeError(msg) if drop: try: @@ -106,7 +126,9 @@ class LMDB: try: return self.dbs[dbname] except KeyError as e: - raise ValueError("Database {} isn't open!".format(dbname.decode('utf-8'))) from e + raise ValueError( + "Database {} isn't open!".format(dbname.decode("utf-8")) + ) from e def key_stream(self, dbname: bytes) -> Iterator[bytes]: """yield all keys from given db""" @@ -129,70 +151,69 @@ class DNSReply: TIMEOUT_INT = 4294967295 SIZEOF_INT = 4 SIZEOF_SHORT = 2 - WIREFORMAT_VALID = 'Valid' + WIREFORMAT_VALID = "Valid" def __init__(self, wire: Optional[WireFormat], time_: float = 0) -> None: if wire is None: - self.wire = b'' - self.time = float('+inf') + self.wire = b"" + self.time = float("+inf") else: self.wire = wire self.time = time_ @property def timeout(self) -> bool: - return self.time == float('+inf') + return self.time == float("+inf") def __eq__(self, other) -> bool: if self.timeout and other.timeout: return True # float equality comparison: use 10^-7 tolerance since it's less than available # resoltuion from the time_int integer value (which is 10^-6) - return self.wire == other.wire and \ - abs(self.time - other.time) < 10 ** -7 + return self.wire == other.wire and abs(self.time - other.time) < 10**-7 @property def time_int(self) -> int: - if self.time == float('+inf'): + if self.time == float("+inf"): return self.TIMEOUT_INT - value = round(self.time * (10 ** 6)) + value = round(self.time * (10**6)) if value > self.TIMEOUT_INT: raise ValueError( 'Maximum time value exceeded: (value: "{}", max: {})'.format( - value, self.TIMEOUT_INT)) + value, self.TIMEOUT_INT + ) + ) return value @property def binary(self) -> bytes: length = len(self.wire) - return struct.pack('<I', self.time_int) + struct.pack('<H', length) + self.wire + return struct.pack("<I", self.time_int) + struct.pack("<H", length) + self.wire @classmethod - def from_binary(cls, buff: bytes) -> Tuple['DNSReply', bytes]: + def from_binary(cls, buff: bytes) -> Tuple["DNSReply", bytes]: if len(buff) < (cls.SIZEOF_INT + cls.SIZEOF_SHORT): - raise ValueError('Missing data in binary format') + raise ValueError("Missing data in binary format") offset = 0 - time_int, = struct.unpack_from('<I', buff, offset) + (time_int,) = struct.unpack_from("<I", buff, offset) offset += cls.SIZEOF_INT - length, = struct.unpack_from('<H', buff, offset) + (length,) = struct.unpack_from("<H", buff, offset) offset += cls.SIZEOF_SHORT - wire = buff[offset:(offset+length)] + wire = buff[offset:(offset + length)] offset += length if len(wire) != length: - raise ValueError('Missing data in binary format') + raise ValueError("Missing data in binary format") if time_int == cls.TIMEOUT_INT: - time_ = float('+inf') + time_ = float("+inf") else: - time_ = time_int / (10 ** 6) + time_ = time_int / (10**6) reply = DNSReply(wire, time_) return reply, buff[offset:] - def parse_wire( - self - ) -> Tuple[Optional[dns.message.Message], str]: + def parse_wire(self) -> Tuple[Optional[dns.message.Message], str]: try: return dns.message.from_wire(self.wire), self.WIREFORMAT_VALID except dns.exception.FormError as exc: @@ -201,9 +222,10 @@ class DNSReply: class DNSRepliesFactory: """Thread-safe factory to parse DNSReply objects from binary blob.""" + def __init__(self, servers: Sequence[ResolverID]) -> None: if not servers: - raise ValueError('One or more servers have to be specified') + raise ValueError("One or more servers have to be specified") self.servers = servers def parse(self, buff: bytes) -> 
Dict[ResolverID, DNSReply]: @@ -212,25 +234,24 @@ class DNSRepliesFactory: reply, buff = DNSReply.from_binary(buff) replies[server] = reply if buff: - raise ValueError('Trailing data in buffer') + raise ValueError("Trailing data in buffer") return replies def serialize(self, replies: Mapping[ResolverID, DNSReply]) -> bytes: if len(replies) > len(self.servers): - raise ValueError('Extra unexpected data to serialize!') + raise ValueError("Extra unexpected data to serialize!") data = [] for server in self.servers: try: reply = replies[server] except KeyError as e: raise ValueError('Missing reply for server "{}"!'.format(server)) from e - else: - data.append(reply.binary) - return b''.join(data) + data.append(reply.binary) + return b"".join(data) class Database(ABC): - DB_NAME = b'' + DB_NAME = b"" def __init__(self, lmdb_, create: bool = False) -> None: self.lmdb = lmdb_ @@ -242,14 +263,16 @@ class Database(ABC): # ensure database is open if self.db is None: if not self.DB_NAME: - raise RuntimeError('No database to initialize!') + raise RuntimeError("No database to initialize!") try: self.db = self.lmdb.get_db(self.DB_NAME) except ValueError: try: self.db = self.lmdb.open_db(self.DB_NAME, create=self.create) except lmdb.Error as exc: - raise RuntimeError('Failed to open LMDB database: {}'.format(exc)) from exc + raise RuntimeError( + "Failed to open LMDB database: {}".format(exc) + ) from exc with self.lmdb.env.begin(self.db, write=write) as txn: yield txn @@ -258,8 +281,11 @@ class Database(ABC): with self.transaction() as txn: data = txn.get(key) if data is None: - raise KeyError("Missing '{}' key in '{}' database!".format( - key.decode('ascii'), self.DB_NAME.decode('ascii'))) + raise KeyError( + "Missing '{}' key in '{}' database!".format( + key.decode("ascii"), self.DB_NAME.decode("ascii") + ) + ) return data def write_key(self, key: bytes, value: bytes) -> None: @@ -269,18 +295,15 @@ class Database(ABC): class MetaDatabase(Database): DB_NAME = LMDB.META - KEY_VERSION = b'version' - KEY_START_TIME = b'start_time' - KEY_END_TIME = b'end_time' - KEY_SERVERS = b'servers' - KEY_NAME = b'name' + KEY_VERSION = b"version" + KEY_START_TIME = b"start_time" + KEY_END_TIME = b"end_time" + KEY_SERVERS = b"servers" + KEY_NAME = b"name" def __init__( - self, - lmdb_, - servers: Sequence[ResolverID], - create: bool = False - ) -> None: + self, lmdb_, servers: Sequence[ResolverID], create: bool = False + ) -> None: super().__init__(lmdb_, create) if create: self.write_servers(servers) @@ -291,38 +314,41 @@ class MetaDatabase(Database): def read_servers(self) -> List[ResolverID]: servers = [] ndata = self.read_key(self.KEY_SERVERS) - n, = struct.unpack('<I', ndata) + (n,) = struct.unpack("<I", ndata) for i in range(n): - key = self.KEY_NAME + str(i).encode('ascii') + key = self.KEY_NAME + str(i).encode("ascii") server = self.read_key(key) - servers.append(server.decode('ascii')) + servers.append(server.decode("ascii")) return servers def write_servers(self, servers: Sequence[ResolverID]) -> None: if not servers: raise ValueError("Empty list of servers!") - n = struct.pack('<I', len(servers)) + n = struct.pack("<I", len(servers)) self.write_key(self.KEY_SERVERS, n) for i, server in enumerate(servers): - key = self.KEY_NAME + str(i).encode('ascii') - self.write_key(key, server.encode('ascii')) + key = self.KEY_NAME + str(i).encode("ascii") + self.write_key(key, server.encode("ascii")) def check_servers(self, servers: Sequence[ResolverID]) -> None: db_servers = self.read_servers() if not servers == 
db_servers: raise NotImplementedError( 'Servers defined in config differ from the ones in "meta" database! ' - '(config: "{}", meta db: "{}")'.format(servers, db_servers)) + '(config: "{}", meta db: "{}")'.format(servers, db_servers) + ) def write_version(self) -> None: - self.write_key(self.KEY_VERSION, BIN_FORMAT_VERSION.encode('ascii')) + self.write_key(self.KEY_VERSION, BIN_FORMAT_VERSION.encode("ascii")) def check_version(self) -> None: - version = self.read_key(self.KEY_VERSION).decode('ascii') + version = self.read_key(self.KEY_VERSION).decode("ascii") if version != BIN_FORMAT_VERSION: raise NotImplementedError( 'LMDB version mismatch! (expected "{}", got "{}")'.format( - BIN_FORMAT_VERSION, version)) + BIN_FORMAT_VERSION, version + ) + ) def write_start_time(self, timestamp: Optional[int] = None) -> None: self._write_timestamp(self.KEY_START_TIME, timestamp) @@ -342,10 +368,10 @@ class MetaDatabase(Database): except KeyError: return None else: - return struct.unpack('<I', data)[0] + return struct.unpack("<I", data)[0] def _write_timestamp(self, key: bytes, timestamp: Optional[int]) -> None: if timestamp is None: timestamp = round(time.time()) - data = struct.pack('<I', timestamp) + data = struct.pack("<I", timestamp) self.write_key(key, data) diff --git a/respdiff/dataformat.py b/respdiff/dataformat.py index 19d20e74e7fe057d1c3a62a59e804f14917247bb..affab275d7c44facf69e69f0e74701ae355b971f 100644 --- a/respdiff/dataformat.py +++ b/respdiff/dataformat.py @@ -4,8 +4,20 @@ from collections import Counter import collections.abc import json from typing import ( # noqa - Any, Callable, Dict, Hashable, ItemsView, Iterator, KeysView, Mapping, - Optional, Set, Sequence, Tuple, Union) + Any, + Callable, + Dict, + Hashable, + ItemsView, + Iterator, + KeysView, + Mapping, + Optional, + Set, + Sequence, + Tuple, + Union, +) from .match import DataMismatch from .typing import FieldLabel, QID @@ -20,23 +32,26 @@ class InvalidFileFormat(Exception): class JSONDataObject: """Object class for (de)serialization into JSON-compatible dictionary.""" + _ATTRIBUTES = {} # type: Mapping[str, Tuple[RestoreFunction, SaveFunction]] def __init__(self, **kwargs): # pylint: disable=unused-argument - self.fileorigin = '' + self.fileorigin = "" def export_json(self, filename: str) -> None: json_string = json.dumps(self.save(), indent=2) - with open(filename, 'w', encoding='UTF-8') as f: + with open(filename, "w", encoding="UTF-8") as f: f.write(json_string) @classmethod def from_json(cls, filename: str): try: - with open(filename, encoding='UTF-8') as f: + with open(filename, encoding="UTF-8") as f: restore_dict = json.load(f) except json.decoder.JSONDecodeError as e: - raise InvalidFileFormat("Couldn't parse JSON file: {}".format(filename)) from e + raise InvalidFileFormat( + "Couldn't parse JSON file: {}".format(filename) + ) from e inst = cls(_restore_dict=restore_dict) inst.fileorigin = filename return inst @@ -54,11 +69,11 @@ class JSONDataObject: return restore_dict def _restore_attr( - self, - restore_dict: Mapping[str, Any], - key: str, - restore_func: RestoreFunction = None - ) -> None: + self, + restore_dict: Mapping[str, Any], + key: str, + restore_func: RestoreFunction = None, + ) -> None: """ Restore attribute from key in dictionary. If it's missing or None, don't call restore_func() and leave attribute's value default. 
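The `_ATTRIBUTES` table is what drives this (de)serialization: each entry maps an attribute name to a `(RestoreFunction, SaveFunction)` converter pair, `None` meaning the value is stored as-is, and a missing or `None` key leaves the attribute at its default. A minimal standalone sketch of that pattern, with simplified, hypothetical names rather than the actual respdiff classes:

```python
import json
from typing import Any, Callable, Dict, Mapping, Optional, Tuple  # noqa

Converter = Optional[Callable[[Any], Any]]


class SimpleJSONObject:
    """Simplified sketch of the _ATTRIBUTES-driven save/restore pattern."""

    # attribute name -> (restore converter, save converter); None means "store as-is"
    _ATTRIBUTES = {}  # type: Mapping[str, Tuple[Converter, Converter]]

    def save(self) -> Dict[str, Any]:
        data = {}  # type: Dict[str, Any]
        for key, (_, save_func) in self._ATTRIBUTES.items():
            value = getattr(self, key, None)
            data[key] = save_func(value) if (save_func and value is not None) else value
        return data

    def restore(self, data: Mapping[str, Any]) -> None:
        # missing or None keys are skipped, leaving the attribute's default in place
        for key, (restore_func, _) in self._ATTRIBUTES.items():
            value = data.get(key)
            if value is None:
                continue
            setattr(self, key, restore_func(value) if restore_func else value)


class CounterSketch(SimpleJSONObject):
    # mirrors e.g. DisagreementsCounter: keep a set of QIDs, serialize it as a JSON list
    _ATTRIBUTES = {"queries": (set, list)}

    def __init__(self) -> None:
        self.queries = set()  # type: set


counter = CounterSketch()
counter.queries.update({1, 2, 3})
blob = json.dumps(counter.save())   # e.g. '{"queries": [1, 2, 3]}'
restored = CounterSketch()
restored.restore(json.loads(blob))  # restored.queries == {1, 2, 3}
```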
@@ -72,11 +87,7 @@ class JSONDataObject: value = restore_func(value) setattr(self, key, value) - def _save_attr( - self, - key: str, - save_func: SaveFunction = None - ) -> Mapping[str, Any]: + def _save_attr(self, key: str, save_func: SaveFunction = None) -> Mapping[str, Any]: """ Save attribute as a key in dictionary. If the attribute is None, save it as such (without calling save_func()). @@ -89,6 +100,7 @@ class JSONDataObject: class Diff(collections.abc.Mapping): """Read-only representation of mismatches in each field for a single query""" + __setitem__ = None __delitem__ = None @@ -107,19 +119,17 @@ class Diff(collections.abc.Mapping): return iter(self._mismatches) def get_significant_field( - self, - field_weights: Sequence[FieldLabel] - ) -> Tuple[Optional[FieldLabel], Optional[DataMismatch]]: + self, field_weights: Sequence[FieldLabel] + ) -> Tuple[Optional[FieldLabel], Optional[DataMismatch]]: for significant_field in field_weights: if significant_field in self: return significant_field, self[significant_field] return None, None def __repr__(self) -> str: - return 'Diff({})'.format( - ', '.join([ - repr(mismatch) for mismatch in self.values() - ])) + return "Diff({})".format( + ", ".join([repr(mismatch) for mismatch in self.values()]) + ) def __eq__(self, other) -> bool: if len(self) != len(other): @@ -142,9 +152,9 @@ class Disagreements(collections.abc.Mapping, JSONDataObject): """ def __init__( - self, - _restore_dict: Optional[Mapping[str, Any]] = None, - ) -> None: + self, + _restore_dict: Optional[Mapping[str, Any]] = None, + ) -> None: """ `_restore_dict` is used to restore from JSON, minimal format: "fields": { @@ -161,40 +171,44 @@ class Disagreements(collections.abc.Mapping, JSONDataObject): """ super().__init__() self._fields = collections.defaultdict( - lambda: collections.defaultdict(set) - ) # type: Dict[FieldLabel, Dict[DataMismatch, Set[QID]]] + lambda: collections.defaultdict(set) + ) # type: Dict[FieldLabel, Dict[DataMismatch, Set[QID]]] if _restore_dict is not None: self.restore(_restore_dict) def restore(self, restore_dict: Mapping[str, Any]) -> None: super().restore(restore_dict) - for field_label, field_data in restore_dict['fields'].items(): - for mismatch_data in field_data['mismatches']: + for field_label, field_data in restore_dict["fields"].items(): + for mismatch_data in field_data["mismatches"]: mismatch = DataMismatch( - mismatch_data['exp_val'], - mismatch_data['got_val']) - self._fields[field_label][mismatch] = set(mismatch_data['queries']) + mismatch_data["exp_val"], mismatch_data["got_val"] + ) + self._fields[field_label][mismatch] = set(mismatch_data["queries"]) def save(self) -> Dict[str, Any]: fields = {} for field, field_data in self._fields.items(): mismatches = [] for mismatch, mismatch_data in field_data.items(): - mismatches.append({ - 'count': len(mismatch_data), - 'exp_val': mismatch.exp_val, - 'got_val': mismatch.got_val, - 'queries': list(mismatch_data), - }) + mismatches.append( + { + "count": len(mismatch_data), + "exp_val": mismatch.exp_val, + "got_val": mismatch.got_val, + "queries": list(mismatch_data), + } + ) fields[field] = { - 'count': len(mismatches), - 'mismatches': mismatches, + "count": len(mismatches), + "mismatches": mismatches, } restore_dict = super().save() or {} - restore_dict.update({ - 'count': self.count, - 'fields': fields, - }) + restore_dict.update( + { + "count": self.count, + "fields": fields, + } + ) return restore_dict def add_mismatch(self, field: FieldLabel, mismatch: DataMismatch, qid: QID) -> None: @@ 
-205,9 +219,8 @@ class Disagreements(collections.abc.Mapping, JSONDataObject): return self._fields.keys() def get_field_mismatches( - self, - field: FieldLabel - ) -> ItemsView[DataMismatch, Set[QID]]: + self, field: FieldLabel + ) -> ItemsView[DataMismatch, Set[QID]]: return self._fields[field].items() @property @@ -239,7 +252,7 @@ class Disagreements(collections.abc.Mapping, JSONDataObject): class DisagreementsCounter(JSONDataObject): _ATTRIBUTES = { - 'queries': (set, list), + "queries": (set, list), } def __init__(self, _restore_dict: Mapping[str, int] = None) -> None: @@ -257,17 +270,17 @@ class DisagreementsCounter(JSONDataObject): class Summary(Disagreements): """Disagreements, where each query has no more than one mismatch.""" + _ATTRIBUTES = { - 'upstream_unstable': (None, None), - 'usable_answers': (None, None), - 'not_reproducible': (None, None), - 'manual_ignore': (None, None), + "upstream_unstable": (None, None), + "usable_answers": (None, None), + "not_reproducible": (None, None), + "manual_ignore": (None, None), } def __init__( - self, - _restore_dict: Optional[Mapping[FieldLabel, Mapping[str, Any]]] = None - ) -> None: + self, _restore_dict: Optional[Mapping[FieldLabel, Mapping[str, Any]]] = None + ) -> None: self.usable_answers = 0 self.upstream_unstable = 0 self.not_reproducible = 0 @@ -276,17 +289,17 @@ class Summary(Disagreements): def add_mismatch(self, field: FieldLabel, mismatch: DataMismatch, qid: QID) -> None: if qid in self.keys(): - raise ValueError('QID {} already exists in Summary!'.format(qid)) + raise ValueError("QID {} already exists in Summary!".format(qid)) self._fields[field][mismatch].add(qid) @staticmethod def from_report( - report: 'DiffReport', - field_weights: Sequence[FieldLabel], - reproducibility_threshold: float = 1, - without_diffrepro: bool = False, - ignore_qids: Optional[Set[QID]] = None - ) -> 'Summary': + report: "DiffReport", + field_weights: Sequence[FieldLabel], + reproducibility_threshold: float = 1, + without_diffrepro: bool = False, + ignore_qids: Optional[Set[QID]] = None, + ) -> "Summary": """ Get summary of disagreements above the specified reproduciblity threshold [0, 1]. @@ -294,9 +307,11 @@ class Summary(Disagreements): Optionally, provide a list of known unstable and/or failing QIDs which will be ignored. 
""" - if (report.other_disagreements is None - or report.target_disagreements is None - or report.total_answers is None): + if ( + report.other_disagreements is None + or report.target_disagreements is None + or report.total_answers is None + ): raise RuntimeError("Report has insufficient data to create Summary") if ignore_qids is None: @@ -315,7 +330,9 @@ class Summary(Disagreements): if reprocounter.retries != reprocounter.upstream_stable: summary.upstream_unstable += 1 continue # filter unstable upstream - reproducibility = float(reprocounter.verified) / reprocounter.retries + reproducibility = ( + float(reprocounter.verified) / reprocounter.retries + ) if reproducibility < reproducibility_threshold: summary.not_reproducible += 1 continue # filter less reproducible than threshold @@ -324,7 +341,8 @@ class Summary(Disagreements): summary.add_mismatch(field, mismatch, qid) summary.usable_answers = ( - report.total_answers - summary.upstream_unstable - summary.not_reproducible) + report.total_answers - summary.upstream_unstable - summary.not_reproducible + ) return summary def get_field_counters(self) -> Mapping[FieldLabel, Counter]: @@ -339,20 +357,23 @@ class Summary(Disagreements): class ReproCounter(JSONDataObject): _ATTRIBUTES = { - 'retries': (None, None), # total amount of attempts to reproduce - 'upstream_stable': (None, None), # number of cases, where others disagree - 'verified': (None, None), # the query fails, and the diff is same (reproduced) - 'different_failure': (None, None) # the query fails, but the diff doesn't match + "retries": (None, None), # total amount of attempts to reproduce + "upstream_stable": (None, None), # number of cases, where others disagree + "verified": (None, None), # the query fails, and the diff is same (reproduced) + "different_failure": ( + None, + None, + ), # the query fails, but the diff doesn't match } def __init__( - self, - retries: int = 0, - upstream_stable: int = 0, - verified: int = 0, - different_failure: int = 0, - _restore_dict: Optional[Mapping[str, int]] = None - ) -> None: + self, + retries: int = 0, + upstream_stable: int = 0, + verified: int = 0, + different_failure: int = 0, + _restore_dict: Optional[Mapping[str, int]] = None, + ) -> None: super().__init__() self.retries = retries self.upstream_stable = upstream_stable @@ -371,13 +392,16 @@ class ReproCounter(JSONDataObject): self.retries == other.retries and self.upstream_stable == other.upstream_stable and self.verified == other.verified - and self.different_failure == other.different_failure) + and self.different_failure == other.different_failure + ) class ReproData(collections.abc.Mapping, JSONDataObject): def __init__(self, _restore_dict: Optional[Mapping[str, Any]] = None) -> None: super().__init__() - self._counters = collections.defaultdict(ReproCounter) # type: Dict[QID, ReproCounter] + self._counters = collections.defaultdict( + ReproCounter + ) # type: Dict[QID, ReproCounter] if _restore_dict is not None: self.restore(_restore_dict) @@ -406,41 +430,41 @@ class ReproData(collections.abc.Mapping, JSONDataObject): yield from self._counters.keys() -QueryData = collections.namedtuple('QueryData', 'total, others_disagree, target_disagrees') +QueryData = collections.namedtuple( + "QueryData", "total, others_disagree, target_disagrees" +) class DiffReport(JSONDataObject): # pylint: disable=too-many-instance-attributes _ATTRIBUTES = { - 'start_time': (None, None), - 'end_time': (None, None), - 'total_queries': (None, None), - 'total_answers': (None, None), - 
'other_disagreements': ( + "start_time": (None, None), + "end_time": (None, None), + "total_queries": (None, None), + "total_answers": (None, None), + "other_disagreements": ( lambda x: DisagreementsCounter(_restore_dict=x), - lambda x: x.save()), - 'target_disagreements': ( + lambda x: x.save(), + ), + "target_disagreements": ( lambda x: Disagreements(_restore_dict=x), - lambda x: x.save()), - 'summary': ( - lambda x: Summary(_restore_dict=x), - lambda x: x.save()), - 'reprodata': ( - lambda x: ReproData(_restore_dict=x), - lambda x: x.save()), + lambda x: x.save(), + ), + "summary": (lambda x: Summary(_restore_dict=x), lambda x: x.save()), + "reprodata": (lambda x: ReproData(_restore_dict=x), lambda x: x.save()), } def __init__( - self, - start_time: Optional[int] = None, - end_time: Optional[int] = None, - total_queries: Optional[int] = None, - total_answers: Optional[int] = None, - other_disagreements: Optional[DisagreementsCounter] = None, - target_disagreements: Optional[Disagreements] = None, - summary: Optional[Summary] = None, - reprodata: Optional[ReproData] = None, - _restore_dict: Optional[Mapping[str, Any]] = None - ) -> None: + self, + start_time: Optional[int] = None, + end_time: Optional[int] = None, + total_queries: Optional[int] = None, + total_answers: Optional[int] = None, + other_disagreements: Optional[DisagreementsCounter] = None, + target_disagreements: Optional[Disagreements] = None, + summary: Optional[Summary] = None, + reprodata: Optional[ReproData] = None, + _restore_dict: Optional[Mapping[str, Any]] = None, + ) -> None: super().__init__() self.start_time = start_time self.end_time = end_time diff --git a/respdiff/dnsviz.py b/respdiff/dnsviz.py index 7748512a5057a550898200492822bc9371e08df6..010037ce5c740f03bfa1123190cbeea802775221 100644 --- a/respdiff/dnsviz.py +++ b/respdiff/dnsviz.py @@ -6,19 +6,19 @@ import dns import dns.name -TYPES = 'A,AAAA,CNAME' -KEYS_ERROR = ['error', 'errors'] -KEYS_WARNING = ['warnings'] +TYPES = "A,AAAA,CNAME" +KEYS_ERROR = ["error", "errors"] +KEYS_WARNING = ["warnings"] class DnsvizDomainResult: def __init__( - self, - errors: Optional[Mapping[str, List]] = None, - warnings: Optional[Mapping[str, List]] = None, - ) -> None: + self, + errors: Optional[Mapping[str, List]] = None, + warnings: Optional[Mapping[str, List]] = None, + ) -> None: super().__init__() - self.errors = defaultdict(list) # type: Dict[str, List] + self.errors = defaultdict(list) # type: Dict[str, List] self.warnings = defaultdict(list) # type: Dict[str, List] if errors is not None: self.errors.update(errors) @@ -31,10 +31,10 @@ class DnsvizDomainResult: def _find_keys( - kv_iter: Iterable[Tuple[Any, Any]], - keys: Optional[List[Any]] = None, - path: Optional[List[Any]] = None - ) -> Iterator[Tuple[Any, Any]]: + kv_iter: Iterable[Tuple[Any, Any]], + keys: Optional[List[Any]] = None, + path: Optional[List[Any]] = None, +) -> Iterator[Tuple[Any, Any]]: assert isinstance(keys, list) if path is None: path = [] @@ -59,25 +59,26 @@ class DnsvizGrok(dict): domain = key_path[0] if domain not in self.domains: self.domains[domain] = DnsvizDomainResult() - path = '_'.join([str(value) for value in key_path]) + path = "_".join([str(value) for value in key_path]) if key_path[-1] in KEYS_ERROR: self.domains[domain].errors[path] += messages else: self.domains[domain].warnings[path] += messages @staticmethod - def from_json(filename: str) -> 'DnsvizGrok': - with open(filename, encoding='UTF-8') as f: + def from_json(filename: str) -> "DnsvizGrok": + with open(filename, 
encoding="UTF-8") as f: grok_data = json.load(f) if not isinstance(grok_data, dict): raise RuntimeError( - "File {} doesn't contain dnsviz grok json data".format(filename)) + "File {} doesn't contain dnsviz grok json data".format(filename) + ) return DnsvizGrok(grok_data) def error_domains(self) -> List[dns.name.Name]: err_domains = [] for domain, data in self.domains.items(): if data.is_error: - assert domain[-1] == '.' + assert domain[-1] == "." err_domains.append(dns.name.from_text(domain)) return err_domains diff --git a/respdiff/match.py b/respdiff/match.py index 80e429e3623c65774fd7f0bebdf7f2cc92f13fd8..f1dac687828271216e88cf4be5ba06415ff347c8 100644 --- a/respdiff/match.py +++ b/respdiff/match.py @@ -1,7 +1,15 @@ import collections import logging from typing import ( # noqa - Any, Dict, Hashable, Iterator, Mapping, Optional, Sequence, Tuple) + Any, + Dict, + Hashable, + Iterator, + Mapping, + Optional, + Sequence, + Tuple, +) import dns.rdatatype from dns.rrset import RRset @@ -21,8 +29,10 @@ class DataMismatch(Exception): if isinstance(val, dns.rrset.RRset): return str(val) logging.warning( - 'DataMismatch: unknown value type (%s), casting to str', type(val), - stack_info=True) + "DataMismatch: unknown value type (%s), casting to str", + type(val), + stack_info=True, + ) return str(val) exp_val = convert_val_type(exp_val) @@ -37,21 +47,23 @@ class DataMismatch(Exception): @staticmethod def format_value(value: MismatchValue) -> str: if isinstance(value, list): - value = ' '.join(value) + value = " ".join(value) return str(value) def __str__(self) -> str: return "expected '{}' got '{}'".format( - self.format_value(self.exp_val), - self.format_value(self.got_val)) + self.format_value(self.exp_val), self.format_value(self.got_val) + ) def __repr__(self) -> str: - return 'DataMismatch({}, {})'.format(self.exp_val, self.got_val) + return "DataMismatch({}, {})".format(self.exp_val, self.got_val) def __eq__(self, other) -> bool: - return (isinstance(other, DataMismatch) - and tuple(self.exp_val) == tuple(other.exp_val) - and tuple(self.got_val) == tuple(other.got_val)) + return ( + isinstance(other, DataMismatch) + and tuple(self.exp_val) == tuple(other.exp_val) + and tuple(self.got_val) == tuple(other.got_val) + ) @property def key(self) -> Tuple[Hashable, Hashable]: @@ -67,14 +79,14 @@ class DataMismatch(Exception): def compare_val(exp_val: MismatchValue, got_val: MismatchValue): - """ Compare values, throw exception if different. """ + """Compare values, throw exception if different.""" if exp_val != got_val: raise DataMismatch(str(exp_val), str(got_val)) return True def compare_rrs(expected: Sequence[RRset], got: Sequence[RRset]): - """ Compare lists of RR sets, throw exception if different. 
""" + """Compare lists of RR sets, throw exception if different.""" for rr in expected: if rr not in got: raise DataMismatch(expected, got) @@ -87,17 +99,17 @@ def compare_rrs(expected: Sequence[RRset], got: Sequence[RRset]): def compare_rrs_types( - exp_val: Sequence[RRset], - got_val: Sequence[RRset], - compare_rrsigs: bool): + exp_val: Sequence[RRset], got_val: Sequence[RRset], compare_rrsigs: bool +): """sets of RR types in both sections must match""" + def rr_ordering_key(rrset): return rrset.covers if compare_rrsigs else rrset.rdtype def key_to_text(rrtype): if not compare_rrsigs: return dns.rdatatype.to_text(rrtype) - return 'RRSIG(%s)' % dns.rdatatype.to_text(rrtype) + return "RRSIG(%s)" % dns.rdatatype.to_text(rrtype) def filter_by_rrsig(seq, rrsig): for el in seq: @@ -105,51 +117,60 @@ def compare_rrs_types( if el_rrsig == rrsig: yield el - exp_types = frozenset(rr_ordering_key(rrset) - for rrset in filter_by_rrsig(exp_val, compare_rrsigs)) - got_types = frozenset(rr_ordering_key(rrset) - for rrset in filter_by_rrsig(got_val, compare_rrsigs)) + exp_types = frozenset( + rr_ordering_key(rrset) for rrset in filter_by_rrsig(exp_val, compare_rrsigs) + ) + got_types = frozenset( + rr_ordering_key(rrset) for rrset in filter_by_rrsig(got_val, compare_rrsigs) + ) if exp_types != got_types: raise DataMismatch( tuple(key_to_text(i) for i in sorted(exp_types)), - tuple(key_to_text(i) for i in sorted(got_types))) + tuple(key_to_text(i) for i in sorted(got_types)), + ) def match_part( # pylint: disable=inconsistent-return-statements - exp_msg: dns.message.Message, - got_msg: dns.message.Message, - criteria: FieldLabel - ): - """ Compare scripted reply to given message using single criteria. """ - if criteria == 'opcode': + exp_msg: dns.message.Message, got_msg: dns.message.Message, criteria: FieldLabel +): + """Compare scripted reply to given message using single criteria.""" + if criteria == "opcode": return compare_val(exp_msg.opcode(), got_msg.opcode()) - elif criteria == 'flags': - return compare_val(dns.flags.to_text(exp_msg.flags), dns.flags.to_text(got_msg.flags)) - elif criteria == 'rcode': - return compare_val(dns.rcode.to_text(exp_msg.rcode()), dns.rcode.to_text(got_msg.rcode())) - elif criteria == 'question': + elif criteria == "flags": + return compare_val( + dns.flags.to_text(exp_msg.flags), dns.flags.to_text(got_msg.flags) + ) + elif criteria == "rcode": + return compare_val( + dns.rcode.to_text(exp_msg.rcode()), dns.rcode.to_text(got_msg.rcode()) + ) + elif criteria == "question": question_match = compare_rrs(exp_msg.question, got_msg.question) if not exp_msg.question: # 0 RRs, nothing else to compare return True - assert len(exp_msg.question) == 1, "multiple question in single DNS query unsupported" - case_match = compare_val(got_msg.question[0].name.labels, exp_msg.question[0].name.labels) + assert ( + len(exp_msg.question) == 1 + ), "multiple question in single DNS query unsupported" + case_match = compare_val( + got_msg.question[0].name.labels, exp_msg.question[0].name.labels + ) return question_match and case_match - elif criteria in ('answer', 'ttl'): + elif criteria in ("answer", "ttl"): return compare_rrs(exp_msg.answer, got_msg.answer) - elif criteria == 'answertypes': + elif criteria == "answertypes": return compare_rrs_types(exp_msg.answer, got_msg.answer, compare_rrsigs=False) - elif criteria == 'answerrrsigs': + elif criteria == "answerrrsigs": return compare_rrs_types(exp_msg.answer, got_msg.answer, compare_rrsigs=True) - elif criteria == 'authority': + elif 
criteria == "authority": return compare_rrs(exp_msg.authority, got_msg.authority) - elif criteria == 'additional': + elif criteria == "additional": return compare_rrs(exp_msg.additional, got_msg.additional) - elif criteria == 'edns': + elif criteria == "edns": if got_msg.edns != exp_msg.edns: raise DataMismatch(str(exp_msg.edns), str(got_msg.edns)) if got_msg.payload != exp_msg.payload: raise DataMismatch(str(exp_msg.payload), str(got_msg.payload)) - elif criteria == 'nsid': + elif criteria == "nsid": nsid_opt = None for opt in exp_msg.options: if opt.otype == dns.edns.NSID: @@ -159,23 +180,21 @@ def match_part( # pylint: disable=inconsistent-return-statements for opt in got_msg.options: if opt.otype == dns.edns.NSID: if not nsid_opt: - raise DataMismatch('', str(opt.data)) + raise DataMismatch("", str(opt.data)) if opt == nsid_opt: return True else: raise DataMismatch(str(nsid_opt.data), str(opt.data)) if nsid_opt: - raise DataMismatch(str(nsid_opt.data), '') + raise DataMismatch(str(nsid_opt.data), "") else: raise NotImplementedError('unknown match request "%s"' % criteria) def match( - expected: DNSReply, - got: DNSReply, - match_fields: Sequence[FieldLabel] - ) -> Iterator[Tuple[FieldLabel, DataMismatch]]: - """ Compare scripted reply to given message based on match criteria. """ + expected: DNSReply, got: DNSReply, match_fields: Sequence[FieldLabel] +) -> Iterator[Tuple[FieldLabel, DataMismatch]]: + """Compare scripted reply to given message based on match criteria.""" exp_msg, exp_res = expected.parse_wire() got_msg, got_res = got.parse_wire() exp_malformed = exp_res != DNSReply.WIREFORMAT_VALID @@ -183,15 +202,16 @@ def match( if expected.timeout or got.timeout: if not expected.timeout: - yield 'timeout', DataMismatch('answer', 'timeout') + yield "timeout", DataMismatch("answer", "timeout") if not got.timeout: - yield 'timeout', DataMismatch('timeout', 'answer') + yield "timeout", DataMismatch("timeout", "answer") elif exp_malformed or got_malformed: if exp_res == got_res: logging.warning( - 'match: DNS replies malformed in the same way! (%s)', exp_res) + "match: DNS replies malformed in the same way! (%s)", exp_res + ) else: - yield 'malformed', DataMismatch(exp_res, got_res) + yield "malformed", DataMismatch(exp_res, got_res) if expected.timeout or got.timeout or exp_malformed or got_malformed: return # don't attempt to match any other fields @@ -208,19 +228,19 @@ def match( def diff_pair( - answers: Mapping[ResolverID, DNSReply], - criteria: Sequence[FieldLabel], - name1: ResolverID, - name2: ResolverID - ) -> Iterator[Tuple[FieldLabel, DataMismatch]]: + answers: Mapping[ResolverID, DNSReply], + criteria: Sequence[FieldLabel], + name1: ResolverID, + name2: ResolverID, +) -> Iterator[Tuple[FieldLabel, DataMismatch]]: yield from match(answers[name1], answers[name2], criteria) def transitive_equality( - answers: Mapping[ResolverID, DNSReply], - criteria: Sequence[FieldLabel], - resolvers: Sequence[ResolverID] - ) -> bool: + answers: Mapping[ResolverID, DNSReply], + criteria: Sequence[FieldLabel], + resolvers: Sequence[ResolverID], +) -> bool: """ Compare answers from all resolvers. Optimization is based on transitivity of equivalence relation. 
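The transitivity optimization mentioned in the docstring above can be illustrated with a minimal sketch (the resolver names are only examples, mirroring the ones used in the test suite):

    resolvers = ["kresd", "bind", "unbound"]
    # transitive_equality() compares resolvers[0] against each of the others:
    compared = [(resolvers[0], other) for other in resolvers[1:]]
    # compared == [("kresd", "bind"), ("kresd", "unbound")]
    # The pair ("bind", "unbound") is never checked directly: if both compared
    # pairs show no mismatch, equality of the remaining pair follows from
    # transitivity of the equivalence relation.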
@@ -228,16 +248,19 @@ def transitive_equality( assert len(resolvers) >= 2 res_a = resolvers[0] # compare all others to this resolver res_others = resolvers[1:] - return all(map( - lambda res_b: not any(diff_pair(answers, criteria, res_a, res_b)), - res_others)) + return all( + map( + lambda res_b: not any(diff_pair(answers, criteria, res_a, res_b)), + res_others, + ) + ) def compare( - answers: Mapping[ResolverID, DNSReply], - criteria: Sequence[FieldLabel], - target: ResolverID - ) -> Tuple[bool, Optional[Mapping[FieldLabel, DataMismatch]]]: + answers: Mapping[ResolverID, DNSReply], + criteria: Sequence[FieldLabel], + target: ResolverID, +) -> Tuple[bool, Optional[Mapping[FieldLabel, DataMismatch]]]: others = list(answers.keys()) try: others.remove(target) diff --git a/respdiff/qstats.py b/respdiff/qstats.py index 94a155ab4728e4b1f6f5475267d2952a817cd31d..33b895f336bf915b2b09a82910f9637361396a16 100644 --- a/respdiff/qstats.py +++ b/respdiff/qstats.py @@ -7,8 +7,12 @@ from .dataformat import DiffReport, JSONDataObject, QueryData from .typing import QID -UPSTREAM_UNSTABLE_THRESHOLD = 0.1 # consider query unstable when 10 % of results are unstable -ALLOWED_FAIL_THRESHOLD = 0.05 # ignore up to 5 % of FAIL results for a given query (as noise) +UPSTREAM_UNSTABLE_THRESHOLD = ( + 0.1 # consider query unstable when 10 % of results are unstable +) +ALLOWED_FAIL_THRESHOLD = ( + 0.05 # ignore up to 5 % of FAIL results for a given query (as noise) +) class QueryStatus(Enum): @@ -27,16 +31,16 @@ def get_query_status(query_data: QueryData) -> QueryStatus: class QueryStatistics(JSONDataObject): _ATTRIBUTES = { - 'failing': (set, list), - 'unstable': (set, list), + "failing": (set, list), + "unstable": (set, list), } def __init__( - self, - failing: Optional[Set[QID]] = None, - unstable: Optional[Set[QID]] = None, - _restore_dict: Optional[Mapping[str, Any]] = None - ) -> None: + self, + failing: Optional[Set[QID]] = None, + unstable: Optional[Set[QID]] = None, + _restore_dict: Optional[Mapping[str, Any]] = None, + ) -> None: super().__init__() self.failing = failing if failing is not None else set() self.unstable = unstable if unstable is not None else set() @@ -51,9 +55,9 @@ class QueryStatistics(JSONDataObject): self.unstable.add(qid) @staticmethod - def from_reports(reports: Sequence[DiffReport]) -> 'QueryStatistics': + def from_reports(reports: Sequence[DiffReport]) -> "QueryStatistics": """Create query statistics from multiple reports - usually used as a reference""" - others_disagree = collections.Counter() # type: collections.Counter + others_disagree = collections.Counter() # type: collections.Counter target_disagrees = collections.Counter() # type: collections.Counter reprodata_present = False @@ -68,7 +72,9 @@ class QueryStatistics(JSONDataObject): for qid in report.target_disagreements: target_disagrees[qid] += 1 if reprodata_present: - logging.warning("reprodata ignored when creating query stability statistics") + logging.warning( + "reprodata ignored when creating query stability statistics" + ) # evaluate total = len(reports) @@ -77,5 +83,6 @@ class QueryStatistics(JSONDataObject): suspect_queries.update(target_disagrees.keys()) for qid in suspect_queries: query_statistics.add_query( - qid, QueryData(total, others_disagree[qid], target_disagrees[qid])) + qid, QueryData(total, others_disagree[qid], target_disagrees[qid]) + ) return query_statistics diff --git a/respdiff/query.py b/respdiff/query.py index 99b52b6b7605698c13e42159cdd794e926f40c37..9b6b0ef9435f775ea758eb67cc7ad65dbf399b10 
100644 --- a/respdiff/query.py +++ b/respdiff/query.py @@ -9,10 +9,7 @@ from .database import LMDB, qid2key from .typing import QID, WireFormat -def get_query_iterator( - lmdb_, - qids: Iterable[QID] - ) -> Iterator[Tuple[QID, WireFormat]]: +def get_query_iterator(lmdb_, qids: Iterable[QID]) -> Iterator[Tuple[QID, WireFormat]]: qdb = lmdb_.get_db(LMDB.QUERIES) with lmdb_.env.begin(qdb) as txn: for qid in qids: @@ -25,9 +22,9 @@ def qwire_to_qname(qwire: WireFormat) -> str: try: qmsg = dns.message.from_wire(qwire) except dns.exception.DNSException as exc: - raise ValueError('unable to parse qname from wire format') from exc + raise ValueError("unable to parse qname from wire format") from exc if not qmsg.question: - raise ValueError('no qname in wire format') + raise ValueError("no qname in wire format") return qmsg.question[0].name @@ -36,12 +33,12 @@ def qwire_to_qname_qtype(qwire: WireFormat) -> str: try: qmsg = dns.message.from_wire(qwire) except dns.exception.DNSException as exc: - raise ValueError('unable to parse qname from wire format') from exc + raise ValueError("unable to parse qname from wire format") from exc if not qmsg.question: - raise ValueError('no qname in wire format') - return '{} {}'.format( - qmsg.question[0].name, - dns.rdatatype.to_text(qmsg.question[0].rdtype)) + raise ValueError("no qname in wire format") + return "{} {}".format( + qmsg.question[0].name, dns.rdatatype.to_text(qmsg.question[0].rdtype) + ) def qwire_to_msgid_qname_qtype(qwire: WireFormat) -> str: @@ -49,46 +46,47 @@ def qwire_to_msgid_qname_qtype(qwire: WireFormat) -> str: try: qmsg = dns.message.from_wire(qwire) except dns.exception.DNSException as exc: - raise ValueError('unable to parse qname from wire format') from exc + raise ValueError("unable to parse qname from wire format") from exc if not qmsg.question: - raise ValueError('no qname in wire format') - return '[{:05d}] {} {}'.format( - qmsg.id, - qmsg.question[0].name, - dns.rdatatype.to_text(qmsg.question[0].rdtype)) + raise ValueError("no qname in wire format") + return "[{:05d}] {} {}".format( + qmsg.id, qmsg.question[0].name, dns.rdatatype.to_text(qmsg.question[0].rdtype) + ) def convert_queries( - query_iterator: Iterator[Tuple[QID, WireFormat]], - qwire_to_text_func: Callable[[WireFormat], str] = qwire_to_qname_qtype - ) -> Counter: + query_iterator: Iterator[Tuple[QID, WireFormat]], + qwire_to_text_func: Callable[[WireFormat], str] = qwire_to_qname_qtype, +) -> Counter: qcounter = Counter() # type: Counter for qid, qwire in query_iterator: try: text = qwire_to_text_func(qwire) except ValueError as exc: - logging.debug('Omitting QID %d: %s', qid, exc) + logging.debug("Omitting QID %d: %s", qid, exc) else: qcounter[text] += 1 return qcounter def get_printable_queries_format( - queries_mismatch: Counter, - queries_all: Counter = None, # all queries (needed for comparison with ref) - ref_queries_mismatch: Counter = None, # ref queries for the same mismatch - ref_queries_all: Counter = None # ref queries from all mismatches - ) -> Sequence[Tuple[str, int, str]]: + queries_mismatch: Counter, + queries_all: Counter = None, # all queries (needed for comparison with ref) + ref_queries_mismatch: Counter = None, # ref queries for the same mismatch + ref_queries_all: Counter = None, # ref queries from all mismatches +) -> Sequence[Tuple[str, int, str]]: def get_query_diff(query: str) -> str: - if (ref_queries_mismatch is None - or ref_queries_all is None - or queries_all is None): - return ' ' # no reference to compare to + if ( + 
ref_queries_mismatch is None + or ref_queries_all is None + or queries_all is None + ): + return " " # no reference to compare to if query in queries_mismatch and query not in ref_queries_all: - return '+' # previously unseen query has appeared + return "+" # previously unseen query has appeared if query in ref_queries_mismatch and query not in queries_all: - return '-' # query no longer appears in any mismatch category - return ' ' # no change, or query has moved to a different mismatch category + return "-" # query no longer appears in any mismatch category + return " " # no change, or query has moved to a different mismatch category query_set = set(queries_mismatch.keys()) if ref_queries_mismatch is not None: @@ -101,9 +99,9 @@ def get_printable_queries_format( for query in query_set: diff = get_query_diff(query) count = queries_mismatch[query] - if diff == ' ' and count == 0: + if diff == " " and count == 0: continue # omit queries that just moved between categories - if diff == '-': + if diff == "-": assert ref_queries_mismatch is not None count = ref_queries_mismatch[query] # show how many cases were removed queries.append((diff, count, query)) diff --git a/respdiff/repro.py b/respdiff/repro.py index e9af03042f5780d029057b51ed52ce8a20235ad7..8828887c3bdf6923f1f08cbea2aed505200b367a 100644 --- a/respdiff/repro.py +++ b/respdiff/repro.py @@ -4,11 +4,27 @@ from multiprocessing import pool import random import subprocess from typing import ( # noqa - AbstractSet, Any, Iterator, Iterable, Mapping, Optional, Sequence, Tuple, - TypeVar, Union) + AbstractSet, + Any, + Iterator, + Iterable, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) from .database import ( - DNSRepliesFactory, DNSReply, key2qid, ResolverID, qid2key, QKey, WireFormat) + DNSRepliesFactory, + DNSReply, + key2qid, + ResolverID, + qid2key, + QKey, + WireFormat, +) from .dataformat import Diff, DiffReport, FieldLabel from .match import compare from .query import get_query_iterator @@ -16,25 +32,25 @@ from .sendrecv import worker_perform_single_query from .typing import QID # noqa -T = TypeVar('T') +T = TypeVar("T") def restart_resolver(script_path: str) -> None: try: subprocess.check_call(script_path) except subprocess.CalledProcessError as exc: - logging.warning('Resolver restart failed (exit code %d): %s', - exc.returncode, script_path) + logging.warning( + "Resolver restart failed (exit code %d): %s", exc.returncode, script_path + ) except PermissionError: - logging.warning('Resolver restart failed (permission error): %s', - script_path) + logging.warning("Resolver restart failed (permission error): %s", script_path) def get_restart_scripts(config: Mapping[str, Any]) -> Mapping[ResolverID, str]: restart_scripts = {} - for resolver in config['servers']['names']: + for resolver in config["servers"]["names"]: try: - restart_scripts[resolver] = config[resolver]['restart_script'] + restart_scripts[resolver] = config[resolver]["restart_script"] except KeyError: logging.warning('No restart script available for "%s"!', resolver) return restart_scripts @@ -51,12 +67,12 @@ def chunker(iterable: Iterable[T], size: int) -> Iterator[Iterable[T]]: def process_answers( - qkey: QKey, - answers: Mapping[ResolverID, DNSReply], - report: DiffReport, - criteria: Sequence[FieldLabel], - target: ResolverID - ) -> None: + qkey: QKey, + answers: Mapping[ResolverID, DNSReply], + report: DiffReport, + criteria: Sequence[FieldLabel], + target: ResolverID, +) -> None: if report.target_disagreements is None or report.reprodata is 
None: raise RuntimeError("Report doesn't contain necessary data!") qid = key2qid(qkey) @@ -75,15 +91,17 @@ def process_answers( def query_stream_from_disagreements( - lmdb, - report: DiffReport, - skip_unstable: bool = True, - skip_non_reproducible: bool = True, - shuffle: bool = True - ) -> Iterator[Tuple[QKey, WireFormat]]: + lmdb, + report: DiffReport, + skip_unstable: bool = True, + skip_non_reproducible: bool = True, + shuffle: bool = True, +) -> Iterator[Tuple[QKey, WireFormat]]: if report.target_disagreements is None or report.reprodata is None: raise RuntimeError("Report doesn't contain necessary data!") - qids = report.target_disagreements.keys() # type: Union[Sequence[QID], AbstractSet[QID]] + qids = ( + report.target_disagreements.keys() + ) # type: Union[Sequence[QID], AbstractSet[QID]] if shuffle: # create a new, randomized list from disagreements qids = list(qids) @@ -94,23 +112,23 @@ def query_stream_from_disagreements( reprocounter = report.reprodata[qid] # verify if answers are stable if skip_unstable and reprocounter.retries != reprocounter.upstream_stable: - logging.debug('Skipping QID %7d: unstable upstream', diff.qid) + logging.debug("Skipping QID %7d: unstable upstream", diff.qid) continue if skip_non_reproducible and reprocounter.retries != reprocounter.verified: - logging.debug('Skipping QID %7d: not 100 %% reproducible', diff.qid) + logging.debug("Skipping QID %7d: not 100 %% reproducible", diff.qid) continue yield qid2key(qid), qwire def reproduce_queries( - query_stream: Iterator[Tuple[QKey, WireFormat]], - report: DiffReport, - dnsreplies_factory: DNSRepliesFactory, - criteria: Sequence[FieldLabel], - target: ResolverID, - restart_scripts: Optional[Mapping[ResolverID, str]] = None, - nproc: int = 1 - ) -> None: + query_stream: Iterator[Tuple[QKey, WireFormat]], + report: DiffReport, + dnsreplies_factory: DNSRepliesFactory, + criteria: Sequence[FieldLabel], + target: ResolverID, + restart_scripts: Optional[Mapping[ResolverID, str]] = None, + nproc: int = 1, +) -> None: if restart_scripts is None: restart_scripts = {} with pool.Pool(processes=nproc) as p: @@ -121,12 +139,14 @@ def reproduce_queries( restart_resolver(script) process_args = [args for args in process_args if args is not None] - for qkey, replies_data, in p.imap_unordered( - worker_perform_single_query, - process_args, - chunksize=1): + for ( + qkey, + replies_data, + ) in p.imap_unordered( + worker_perform_single_query, process_args, chunksize=1 + ): replies = dnsreplies_factory.parse(replies_data) process_answers(qkey, replies, report, criteria, target) done += len(process_args) - logging.debug('Processed {:4d} queries'.format(done)) + logging.debug("Processed {:4d} queries".format(done)) diff --git a/respdiff/sendrecv.py b/respdiff/sendrecv.py index 5920ecba61f370d61e664941bbe370c3d54f5add..9330c724eb2977d29c865030a0ca41b25b7a1e8f 100644 --- a/respdiff/sendrecv.py +++ b/respdiff/sendrecv.py @@ -8,8 +8,8 @@ The entire module keeps a global state, which enables its easy use with both threads or processes. Make sure not to break this compatibility. 
""" - from argparse import Namespace +import logging import random import signal import selectors @@ -18,7 +18,14 @@ import ssl import struct import time import threading -from typing import Any, Dict, List, Mapping, Sequence, Tuple # noqa: type hints +from typing import ( + Any, + Dict, + Mapping, + Optional, + Sequence, + Tuple, +) # noqa: type hints import dns.inet import dns.message @@ -40,7 +47,9 @@ ResolverSockets = Sequence[Tuple[ResolverID, Socket, IsStreamFlag]] # module-wide state variables __resolvers = [] # type: Sequence[Tuple[ResolverID, IP, Protocol, Port]] __worker_state = threading.local() -__max_timeouts = 10 # crash when N consecutive timeouts are received from a single resolver +__max_timeouts = ( + 10 # crash when N consecutive timeouts are received from a single resolver +) __ignore_timeout = False __timeout = 16 __time_delay_min = 0 @@ -56,6 +65,15 @@ class TcpDnsLengthError(ConnectionError): pass +class ResolverConnectionError(ConnectionError): + def __init__(self, resolver: ResolverID, message: str): + super().__init__(message) + self.resolver = resolver + + def __str__(self): + return f"[{self.resolver}] {super().__str__()}" + + def module_init(args: Namespace) -> None: global __resolvers global __max_timeouts @@ -66,11 +84,11 @@ def module_init(args: Namespace) -> None: global __dnsreplies_factory __resolvers = get_resolvers(args.cfg) - __timeout = args.cfg['sendrecv']['timeout'] - __time_delay_min = args.cfg['sendrecv']['time_delay_min'] - __time_delay_max = args.cfg['sendrecv']['time_delay_max'] + __timeout = args.cfg["sendrecv"]["timeout"] + __time_delay_min = args.cfg["sendrecv"]["time_delay_min"] + __time_delay_max = args.cfg["sendrecv"]["time_delay_max"] try: - __max_timeouts = args.cfg['sendrecv']['max_timeouts'] + __max_timeouts = args.cfg["sendrecv"]["max_timeouts"] except KeyError: pass try: @@ -125,7 +143,9 @@ def worker_perform_query(args: Tuple[QKey, WireFormat]) -> Tuple[QKey, RepliesBl return qkey, blob -def worker_perform_single_query(args: Tuple[QKey, WireFormat]) -> Tuple[QKey, RepliesBlob]: +def worker_perform_single_query( + args: Tuple[QKey, WireFormat] +) -> Tuple[QKey, RepliesBlob]: """Perform a single DNS query with setup and teardown of sockets. Used by diffrepro.""" qkey, qwire = args worker_reinit() @@ -140,12 +160,12 @@ def worker_perform_single_query(args: Tuple[QKey, WireFormat]) -> Tuple[QKey, Re def get_resolvers( - config: Mapping[str, Any] - ) -> Sequence[Tuple[ResolverID, IP, Protocol, Port]]: + config: Mapping[str, Any] +) -> Sequence[Tuple[ResolverID, IP, Protocol, Port]]: resolvers = [] - for resname in config['servers']['names']: + for resname in config["servers"]["names"]: rescfg = config[resname] - resolvers.append((resname, rescfg['ip'], rescfg['transport'], rescfg['port'])) + resolvers.append((resname, rescfg["ip"], rescfg["transport"], rescfg["port"])) return resolvers @@ -160,7 +180,9 @@ def _check_timeout(replies: Mapping[ResolverID, DNSReply]) -> None: raise RuntimeError( "Resolver '{}' timed-out {:d} times in a row. " "Use '--ignore-timeout' to supress this error.".format( - resolver, __max_timeouts)) + resolver, __max_timeouts + ) + ) def make_ssl_context(): @@ -169,19 +191,19 @@ def make_ssl_context(): # https://docs.python.org/3/library/ssl.html#tls-1-3 # NOTE forcing TLS v1.2 is hacky, because of different py3/openssl versions... 
- if getattr(ssl, 'PROTOCOL_TLS', None) is not None: + if getattr(ssl, "PROTOCOL_TLS", None) is not None: context = ssl.SSLContext(ssl.PROTOCOL_TLS) # pylint: disable=no-member else: context = ssl.SSLContext() - if getattr(ssl, 'maximum_version', None) is not None: + if getattr(ssl, "maximum_version", None) is not None: context.maximum_version = ssl.TLSVersion.TLSv1_2 # pylint: disable=no-member else: context.options |= ssl.OP_NO_SSLv2 context.options |= ssl.OP_NO_SSLv3 context.options |= ssl.OP_NO_TLSv1 context.options |= ssl.OP_NO_TLSv1_1 - if getattr(ssl, 'OP_NO_TLSv1_3', None) is not None: + if getattr(ssl, "OP_NO_TLSv1_3", None) is not None: context.options |= ssl.OP_NO_TLSv1_3 # pylint: disable=no-member # turn off certificate verification @@ -191,7 +213,9 @@ def make_ssl_context(): return context -def sock_init(retry: int = 3) -> Tuple[Selector, Sequence[Tuple[ResolverID, Socket, IsStreamFlag]]]: +def sock_init( + retry: int = 3, +) -> Tuple[Selector, Sequence[Tuple[ResolverID, Socket, IsStreamFlag]]]: sockets = [] selector = selectors.DefaultSelector() for name, ipaddr, transport, port in __resolvers: @@ -201,22 +225,22 @@ def sock_init(retry: int = 3) -> Tuple[Selector, Sequence[Tuple[ResolverID, Sock elif af == dns.inet.AF_INET6: destination = (ipaddr, port, 0, 0) else: - raise NotImplementedError('AF') + raise NotImplementedError("AF") - if transport in {'tcp', 'tls'}: + if transport in {"tcp", "tls"}: socktype = socket.SOCK_STREAM isstream = True - elif transport == 'udp': + elif transport == "udp": socktype = socket.SOCK_DGRAM isstream = False else: - raise NotImplementedError('transport: {}'.format(transport)) + raise NotImplementedError("transport: {}".format(transport)) # attempt to connect to socket attempt = 1 while attempt <= retry: sock = socket.socket(af, socktype, 0) - if transport == 'tls': + if transport == "tls": ctx = make_ssl_context() sock = ctx.wrap_socket(sock) try: @@ -224,7 +248,9 @@ def sock_init(retry: int = 3) -> Tuple[Selector, Sequence[Tuple[ResolverID, Sock except ConnectionRefusedError as e: # TCP socket is closed raise RuntimeError( "socket: Failed to connect to {dest[0]} port {dest[1]}".format( - dest=destination)) from e + dest=destination + ) + ) from e except OSError as exc: if exc.errno != 0 and not isinstance(exc, ConnectionResetError): raise @@ -235,7 +261,9 @@ def sock_init(retry: int = 3) -> Tuple[Selector, Sequence[Tuple[ResolverID, Sock if attempt > retry: raise RuntimeError( "socket: Failed to connect to {dest[0]} port {dest[1]}".format( - dest=destination)) from exc + dest=destination + ) + ) from exc else: break sock.setblocking(False) @@ -251,37 +279,46 @@ def _recv_msg(sock: Socket, isstream: IsStreamFlag) -> WireFormat: try: blength = sock.recv(2) except ssl.SSLWantReadError as e: - raise TcpDnsLengthError('failed to recv DNS packet length') from e - else: - if len(blength) != 2: # FIN / RST - raise TcpDnsLengthError('failed to recv DNS packet length') - (length, ) = struct.unpack('!H', blength) + raise TcpDnsLengthError("failed to recv DNS packet length") from e + if len(blength) != 2: # FIN / RST + raise TcpDnsLengthError("failed to recv DNS packet length") + (length,) = struct.unpack("!H", blength) else: length = 65535 # max. 
UDP message size, no IPv6 jumbograms return sock.recv(length) -def _send_recv_parallel( - dgram: WireFormat, # DNS message suitable for UDP transport - selector: Selector, - sockets: ResolverSockets, - timeout: float - ) -> Tuple[Mapping[ResolverID, DNSReply], ReinitFlag]: - replies = {} # type: Dict[ResolverID, DNSReply] - streammsg = None +def _create_sendbuf(dnsdata: WireFormat, isstream: IsStreamFlag) -> bytes: + if isstream: # prepend length, RFC 1035 section 4.2.2 + length = len(dnsdata) + return struct.pack("!H", length) + dnsdata + return dnsdata + + +def _get_resolver_from_sock( + sockets: ResolverSockets, sock: Socket +) -> Optional[ResolverID]: + for resolver, resolver_sock, _ in sockets: + if sock == resolver_sock: + return resolver + return None + + +def _recv_from_resolvers( + selector: Selector, sockets: ResolverSockets, msgid: bytes, timeout: float +) -> Tuple[Dict[ResolverID, DNSReply], bool]: + + def raise_resolver_exc(sock, exc): + resolver = _get_resolver_from_sock(sockets, sock) + if resolver is not None: + raise ResolverConnectionError(resolver, str(exc)) from exc + raise exc + start_time = time.perf_counter() end_time = start_time + timeout - for _, sock, isstream in sockets: - if isstream: # prepend length, RFC 1035 section 4.2.2 - if not streammsg: - length = len(dgram) - streammsg = struct.pack('!H', length) + dgram - sock.sendall(streammsg) - else: - sock.sendall(dgram) - - # receive replies + replies = {} # type: Dict[ResolverID, DNSReply] reinit = False + while len(replies) != len(sockets): remaining_time = end_time - time.perf_counter() if remaining_time <= 0: @@ -293,17 +330,40 @@ def _send_recv_parallel( sock = key.fileobj try: wire = _recv_msg(sock, isstream) - except TcpDnsLengthError: + except TcpDnsLengthError as exc: if name in replies: # we have a reply already, most likely TCP FIN reinit = True selector.unregister(sock) continue # receive answers from other parties - raise # no reply -> raise error - # assert len(wire) > 14 - if dgram[0:2] != wire[0:2]: + # no reply -> raise error + raise_resolver_exc(sock, exc) + except ConnectionError as exc: + raise_resolver_exc(sock, exc) + if msgid != wire[0:2]: continue # wrong msgid, this might be a delayed answer - ignore it replies[name] = DNSReply(wire, time.perf_counter() - start_time) + return replies, reinit + + +def _send_recv_parallel( + dgram: WireFormat, # DNS message suitable for UDP transport + selector: Selector, + sockets: ResolverSockets, + timeout: float, +) -> Tuple[Mapping[ResolverID, DNSReply], ReinitFlag]: + # send queries + for resolver, sock, isstream in sockets: + sendbuf = _create_sendbuf(dgram, isstream) + try: + sock.sendall(sendbuf) + except ConnectionError as exc: + raise ResolverConnectionError(resolver, str(exc)) from exc + + # receive replies + msgid = dgram[0:2] + replies, reinit = _recv_from_resolvers(selector, sockets, msgid, timeout) + # set missing replies as timeout for resolver, *_ in sockets: # type: ignore # python/mypy#465 if resolver not in replies: @@ -313,10 +373,11 @@ def _send_recv_parallel( def send_recv_parallel( - dgram: WireFormat, # DNS message suitable for UDP transport - timeout: float, - reinit_on_tcpfin: bool = True - ) -> Mapping[ResolverID, DNSReply]: + dgram: WireFormat, # DNS message suitable for UDP transport + timeout: float, + reinit_on_tcpfin: bool = True, +) -> Mapping[ResolverID, DNSReply]: + problematic = [] for _ in range(CONN_RESET_RETRIES + 1): try: # get sockets and selector selector = __worker_state.selector @@ -331,9 +392,19 @@ def 
send_recv_parallel( worker_deinit() worker_reinit() return replies - except (TcpDnsLengthError, ConnectionError): # most likely TCP RST + # The following exception handling is typically triggered by TCP RST, + # but it could also indicate some issue with one of the resolvers. + except ResolverConnectionError as exc: + problematic.append(exc.resolver) + logging.debug(exc) + worker_deinit() # re-establish connection + worker_reinit() + except ConnectionError as exc: # most likely TCP RST + logging.debug(exc) worker_deinit() # re-establish connection worker_reinit() raise RuntimeError( - 'ConnectionError received {} times in a row, exiting!'.format( - CONN_RESET_RETRIES + 1)) + "ConnectionError received {} times in a row ({}), exiting!".format( + CONN_RESET_RETRIES + 1, ", ".join(problematic) + ) + ) diff --git a/respdiff/stats.py b/respdiff/stats.py index f3820fb8d4950c5d87929f728d24148dfcb4ffa0..4a2ceb110c18f8158cffadfabfa0f38fbf7e3e08 100644 --- a/respdiff/stats.py +++ b/respdiff/stats.py @@ -23,24 +23,26 @@ class Stats(JSONDataObject): Example: samples = [1540, 1613, 1489] """ + _ATTRIBUTES = { - 'samples': (None, None), - 'threshold': (None, None), + "samples": (None, None), + "threshold": (None, None), } class SamplePosition(Enum): """Position of a single sample against the rest of the distribution.""" + ABOVE_REF = 1 ABOVE_THRESHOLD = 2 NORMAL = 3 BELOW_REF = 4 def __init__( - self, - samples: Sequence[float] = None, - threshold: Optional[float] = None, - _restore_dict: Optional[Mapping[str, float]] = None - ) -> None: + self, + samples: Sequence[float] = None, + threshold: Optional[float] = None, + _restore_dict: Optional[Mapping[str, float]] = None, + ) -> None: """ samples contain the entire data set of reference values of this parameter. If no custom threshold is provided, it is calculated automagically. 
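A minimal usage sketch of the Stats API described in the docstring above, using the sample values from the class docstring (the evaluated samples are made up for illustration):

    from respdiff.stats import Stats

    ref = Stats(samples=[1540, 1613, 1489])  # threshold is derived automatically
    ref.get_percentile_rank(1540)            # weak percentile rank via scipy.stats
    ref.evaluate_sample(1700)                # Stats.SamplePosition.ABOVE_REF (above the reference maximum)
    ref.evaluate_sample(1400)                # Stats.SamplePosition.BELOW_REF (below the reference minimum)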
@@ -73,9 +75,9 @@ class Stats(JSONDataObject): return max(self.samples) def get_percentile_rank(self, sample: float) -> float: - return scipy.stats.percentileofscore(self.samples, sample, kind='weak') + return scipy.stats.percentileofscore(self.samples, sample, kind="weak") - def evaluate_sample(self, sample: float) -> 'Stats.SamplePosition': + def evaluate_sample(self, sample: float) -> "Stats.SamplePosition": if sample < self.min: return Stats.SamplePosition.BELOW_REF elif sample > self.max: @@ -107,17 +109,15 @@ class MismatchStatistics(dict, JSONDataObject): """ _ATTRIBUTES = { - 'total': ( - lambda x: Stats(_restore_dict=x), - lambda x: x.save()), + "total": (lambda x: Stats(_restore_dict=x), lambda x: x.save()), } def __init__( - self, - mismatch_counters_list: Optional[Sequence[Counter]] = None, - sample_size: Optional[int] = None, - _restore_dict: Optional[Mapping[str, Any]] = None - ) -> None: + self, + mismatch_counters_list: Optional[Sequence[Counter]] = None, + sample_size: Optional[int] = None, + _restore_dict: Optional[Mapping[str, Any]] = None, + ) -> None: super().__init__() self.total = None if mismatch_counters_list is not None and sample_size is not None: @@ -128,15 +128,15 @@ class MismatchStatistics(dict, JSONDataObject): n += count mismatch_key = str(mismatch.key) samples[mismatch_key].append(count) - samples['total'].append(n) + samples["total"].append(n) # fill in missing samples for seq in samples.values(): seq.extend([0] * (sample_size - len(seq))) # create stats from samples - self.total = Stats(samples['total']) - del samples['total'] + self.total = Stats(samples["total"]) + del samples["total"] for mismatch_key, stats_seq in samples.items(): self[mismatch_key] = Stats(stats_seq) elif _restore_dict is not None: @@ -166,17 +166,18 @@ class FieldStatistics(dict, JSONDataObject): """ def __init__( - self, - summaries_list: Optional[Sequence[Summary]] = None, - _restore_dict: Optional[Mapping[str, Any]] = None - ) -> None: + self, + summaries_list: Optional[Sequence[Summary]] = None, + _restore_dict: Optional[Mapping[str, Any]] = None, + ) -> None: super().__init__() if summaries_list is not None: field_counters_list = [d.get_field_counters() for d in summaries_list] for field in ALL_FIELDS: mismatch_counters_list = [fc[field] for fc in field_counters_list] self[field] = MismatchStatistics( - mismatch_counters_list, len(summaries_list)) + mismatch_counters_list, len(summaries_list) + ) elif _restore_dict is not None: self.restore(_restore_dict) @@ -194,32 +195,20 @@ class FieldStatistics(dict, JSONDataObject): class SummaryStatistics(JSONDataObject): _ATTRIBUTES = { - 'sample_size': (None, None), - 'upstream_unstable': ( - lambda x: Stats(_restore_dict=x), - lambda x: x.save()), - 'usable_answers': ( - lambda x: Stats(_restore_dict=x), - lambda x: x.save()), - 'not_reproducible': ( - lambda x: Stats(_restore_dict=x), - lambda x: x.save()), - 'target_disagreements': ( - lambda x: Stats(_restore_dict=x), - lambda x: x.save()), - 'fields': ( - lambda x: FieldStatistics(_restore_dict=x), - lambda x: x.save()), - 'queries': ( - lambda x: QueryStatistics(_restore_dict=x), - lambda x: x.save()), + "sample_size": (None, None), + "upstream_unstable": (lambda x: Stats(_restore_dict=x), lambda x: x.save()), + "usable_answers": (lambda x: Stats(_restore_dict=x), lambda x: x.save()), + "not_reproducible": (lambda x: Stats(_restore_dict=x), lambda x: x.save()), + "target_disagreements": (lambda x: Stats(_restore_dict=x), lambda x: x.save()), + "fields": (lambda x: 
FieldStatistics(_restore_dict=x), lambda x: x.save()), + "queries": (lambda x: QueryStatistics(_restore_dict=x), lambda x: x.save()), } def __init__( - self, - reports: Sequence[DiffReport] = None, - _restore_dict: Mapping[str, Any] = None - ) -> None: + self, + reports: Sequence[DiffReport] = None, + _restore_dict: Mapping[str, Any] = None, + ) -> None: super().__init__() self.sample_size = None self.upstream_unstable = None @@ -233,15 +222,18 @@ class SummaryStatistics(JSONDataObject): usable_reports = [] for report in reports: if report.summary is None: - logging.warning('Empty diffsum in %s Omitting...', report.fileorigin) + logging.warning( + "Empty diffsum in %s Omitting...", report.fileorigin + ) else: usable_reports.append(report) summaries = [ - report.summary for report in reports if report.summary is not None] + report.summary for report in reports if report.summary is not None + ] assert len(summaries) == len(usable_reports) if not summaries: - raise ValueError('No summaries found in reports!') + raise ValueError("No summaries found in reports!") self.sample_size = len(summaries) self.upstream_unstable = Stats([s.upstream_unstable for s in summaries]) diff --git a/statcmp.py b/statcmp.py index c82adf3aad4e96fac53e0541219d6c4aa8d62fa3..3039a68e8ba4a6c05a1647ba6426a895ec4d670b 100755 --- a/statcmp.py +++ b/statcmp.py @@ -14,16 +14,17 @@ from respdiff.typing import FieldLabel import matplotlib import matplotlib.axes import matplotlib.ticker -matplotlib.use('Agg') + +matplotlib.use("Agg") import matplotlib.pyplot as plt # noqa -COLOR_OK = 'tab:blue' -COLOR_GOOD = 'tab:green' -COLOR_BAD = 'xkcd:bright red' -COLOR_THRESHOLD = 'tab:orange' -COLOR_BG = 'tab:gray' -COLOR_LABEL = 'black' +COLOR_OK = "tab:blue" +COLOR_GOOD = "tab:green" +COLOR_BAD = "xkcd:bright red" +COLOR_THRESHOLD = "tab:orange" +COLOR_BG = "tab:gray" +COLOR_LABEL = "black" VIOLIN_FIGSIZE = (3, 6) @@ -36,7 +37,9 @@ SAMPLE_COLORS = { class AxisMarker: - def __init__(self, position: float, width: float = 0.7, color: str = COLOR_BG) -> None: + def __init__( + self, position: float, width: float = 0.7, color: str = COLOR_BG + ) -> None: self.position = position self.width = width self.color = color @@ -48,19 +51,20 @@ class AxisMarker: def plot_violin( - ax: matplotlib.axes.Axes, - violin_data: Sequence[float], - markers: Sequence[AxisMarker], - label: str, - color: str = COLOR_LABEL - ) -> None: - ax.set_title(label, fontdict={'fontsize': 14}, color=color) + ax: matplotlib.axes.Axes, + violin_data: Sequence[float], + markers: Sequence[AxisMarker], + label: str, + color: str = COLOR_LABEL, +) -> None: + ax.set_title(label, fontdict={"fontsize": 14}, color=color) # plot violin graph - violin_parts = ax.violinplot(violin_data, bw_method=0.07, - showmedians=False, showextrema=False) + violin_parts = ax.violinplot( + violin_data, bw_method=0.07, showmedians=False, showextrema=False + ) # set violin background color - for pc in violin_parts['bodies']: + for pc in violin_parts["bodies"]: pc.set_facecolor(COLOR_BG) pc.set_edgecolor(COLOR_BG) @@ -69,10 +73,10 @@ def plot_violin( marker.draw(ax) # turn off axis spines - for sp in ['right', 'top', 'bottom']: - ax.spines[sp].set_color('none') + for sp in ["right", "top", "bottom"]: + ax.spines[sp].set_color("none") # move the left ax spine to center - ax.spines['left'].set_position(('data', 1)) + ax.spines["left"].set_position(("data", 1)) # customize axis ticks ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator()) @@ -80,8 +84,11 @@ def plot_violin( if max(violin_data) == 0: # 
fix tick at 0 when there's no data ax.yaxis.set_major_locator(matplotlib.ticker.FixedLocator([0])) else: - ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator( - nbins='auto', steps=[1, 2, 4, 5, 10], integer=True)) + ax.yaxis.set_major_locator( + matplotlib.ticker.MaxNLocator( + nbins="auto", steps=[1, 2, 4, 5, 10], integer=True + ) + ) ax.yaxis.set_minor_locator(matplotlib.ticker.NullLocator()) ax.tick_params(labelsize=14) @@ -99,34 +106,41 @@ def _axes_iter(axes, width: int): def eval_and_plot_single( - ax: matplotlib.axes.Axes, - stats: Stats, - label: str, - samples: Sequence[float] - ) -> bool: + ax: matplotlib.axes.Axes, stats: Stats, label: str, samples: Sequence[float] +) -> bool: markers = [] below_min = False above_thr = False for sample in samples: result = stats.evaluate_sample(sample) markers.append(AxisMarker(sample, color=SAMPLE_COLORS[result])) - if result in (Stats.SamplePosition.ABOVE_REF, - Stats.SamplePosition.ABOVE_THRESHOLD): + if result in ( + Stats.SamplePosition.ABOVE_REF, + Stats.SamplePosition.ABOVE_THRESHOLD, + ): above_thr = True logging.error( - ' %s: threshold exceeded! sample: %d / %4.2f%% vs threshold: %d / %4.2f%%', - label, sample, stats.get_percentile_rank(sample), - stats.threshold, stats.get_percentile_rank(stats.threshold)) + " %s: threshold exceeded! sample: %d / %4.2f%% vs threshold: %d / %4.2f%%", + label, + sample, + stats.get_percentile_rank(sample), + stats.threshold, + stats.get_percentile_rank(stats.threshold), + ) elif result == Stats.SamplePosition.BELOW_REF: below_min = True logging.info( - ' %s: new minimum found! new: %d vs prev: %d', - label, sample, stats.min) + " %s: new minimum found! new: %d vs prev: %d", label, sample, stats.min + ) else: logging.info( - ' %s: ok! sample: %d / %4.2f%% vs threshold: %d / %4.2f%%', - label, sample, stats.get_percentile_rank(sample), - stats.threshold, stats.get_percentile_rank(stats.threshold)) + " %s: ok! sample: %d / %4.2f%% vs threshold: %d / %4.2f%%", + label, + sample, + stats.get_percentile_rank(sample), + stats.threshold, + stats.get_percentile_rank(stats.threshold), + ) # add min/med/max markers markers.append(AxisMarker(stats.min, 0.5, COLOR_BG)) @@ -148,11 +162,11 @@ def eval_and_plot_single( def plot_overview( - sumstats: SummaryStatistics, - fields: Sequence[FieldLabel], - summaries: Optional[Sequence[Summary]] = None, - label: str = 'fields_overview' - ) -> bool: + sumstats: SummaryStatistics, + fields: Sequence[FieldLabel], + summaries: Optional[Sequence[Summary]] = None, + label: str = "fields_overview", +) -> bool: """ Plot an overview of all fields using violing graphs. If any summaries are provided, they are drawn in the graphs and also evaluated. 
If any sample in any field exceeds @@ -169,26 +183,33 @@ def plot_overview( fig, axes = plt.subplots( OVERVIEW_Y_FIG, OVERVIEW_X_FIG, - figsize=(OVERVIEW_X_FIG*VIOLIN_FIGSIZE[0], OVERVIEW_Y_FIG*VIOLIN_FIGSIZE[1])) + figsize=( + OVERVIEW_X_FIG * VIOLIN_FIGSIZE[0], + OVERVIEW_Y_FIG * VIOLIN_FIGSIZE[1], + ), + ) ax_it = _axes_iter(axes, OVERVIEW_X_FIG) # target disagreements assert sumstats.target_disagreements is not None samples = [len(summary) for summary in summaries] passed &= eval_and_plot_single( - next(ax_it), sumstats.target_disagreements, 'target_disagreements', samples) + next(ax_it), sumstats.target_disagreements, "target_disagreements", samples + ) # upstream unstable assert sumstats.upstream_unstable is not None samples = [summary.upstream_unstable for summary in summaries] passed &= eval_and_plot_single( - next(ax_it), sumstats.upstream_unstable, 'upstream_unstable', samples) + next(ax_it), sumstats.upstream_unstable, "upstream_unstable", samples + ) # not 100% reproducible assert sumstats.not_reproducible is not None samples = [summary.not_reproducible for summary in summaries] passed &= eval_and_plot_single( - next(ax_it), sumstats.not_reproducible, 'not_reproducible', samples) + next(ax_it), sumstats.not_reproducible, "not_reproducible", samples + ) # fields assert sumstats.fields is not None @@ -201,7 +222,8 @@ def plot_overview( next(ax_it), sumstats.fields[field].total, field, - [len(list(fc[field].elements())) for fc in fcs]) + [len(list(fc[field].elements())) for fc in fcs], + ) # hide unused axis for ax in ax_it: @@ -209,15 +231,21 @@ def plot_overview( # display sample size fig.text( - 0.95, 0.95, - 'stat sample size: {}'.format(len(sumstats.target_disagreements.samples)), - fontsize=18, color=COLOR_BG, ha='right', va='bottom', alpha=0.7) + 0.95, + 0.95, + "stat sample size: {}".format(len(sumstats.target_disagreements.samples)), + fontsize=18, + color=COLOR_BG, + ha="right", + va="bottom", + alpha=0.7, + ) # save image plt.tight_layout() plt.subplots_adjust(top=0.9) fig.suptitle(label, fontsize=22) - plt.savefig('{}.png'.format(label)) + plt.savefig("{}.png".format(label)) plt.close() return passed @@ -226,25 +254,32 @@ def plot_overview( def main(): cli.setup_logging() parser = argparse.ArgumentParser( - description=("Plot and compare reports against statistical data. " - "Returns non-zero exit code if any threshold is exceeded.")) + description=( + "Plot and compare reports against statistical data. " + "Returns non-zero exit code if any threshold is exceeded." + ) + ) cli.add_arg_stats(parser) cli.add_arg_report(parser) cli.add_arg_config(parser) - parser.add_argument('-l', '--label', default='fields_overview', - help='Set plot label. It is also used for the filename.') + parser.add_argument( + "-l", + "--label", + default="fields_overview", + help="Set plot label. 
It is also used for the filename.", + ) args = parser.parse_args() sumstats = args.stats - field_weights = args.cfg['report']['field_weights'] + field_weights = args.cfg["report"]["field_weights"] try: summaries = cli.load_summaries(args.report) except ValueError: sys.exit(1) - logging.info('Start Comparison: %s', args.label) + logging.info("Start Comparison: %s", args.label) passed = plot_overview(sumstats, field_weights, summaries, args.label) if not passed: @@ -253,5 +288,5 @@ def main(): sys.exit(0) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/sumcmp.py b/sumcmp.py index f73e003248906deb2c3ba5f5eaf4ee43704955f2..d8d1b48a9a171332768eab19177a77ca2861455d 100755 --- a/sumcmp.py +++ b/sumcmp.py @@ -10,7 +10,10 @@ from respdiff import cli from respdiff.database import LMDB from respdiff.dataformat import DiffReport from respdiff.query import ( - convert_queries, get_printable_queries_format, get_query_iterator) + convert_queries, + get_printable_queries_format, + get_query_iterator, +) ANSWERS_DIFFERENCE_THRESHOLD_WARNING = 0.05 @@ -25,27 +28,33 @@ def check_report_summary(report: DiffReport): def check_usable_answers(report: DiffReport, ref_report: DiffReport): if report.summary is None or ref_report.summary is None: raise RuntimeError("Report doesn't contain necessary data!") - answers_difference = math.fabs( - report.summary.usable_answers - ref_report.summary.usable_answers - ) / ref_report.summary.usable_answers + answers_difference = ( + math.fabs(report.summary.usable_answers - ref_report.summary.usable_answers) + / ref_report.summary.usable_answers + ) if answers_difference >= ANSWERS_DIFFERENCE_THRESHOLD_WARNING: - logging.warning('Number of usable answers changed by {:.1f} %!'.format( - answers_difference * 100.0)) + logging.warning( + "Number of usable answers changed by {:.1f} %!".format( + answers_difference * 100.0 + ) + ) def main(): cli.setup_logging() - parser = argparse.ArgumentParser(description='compare two diff summaries') + parser = argparse.ArgumentParser(description="compare two diff summaries") cli.add_arg_config(parser) - parser.add_argument('old_datafile', type=str, help='report to compare against') - parser.add_argument('new_datafile', type=str, help='report to compare evaluate') - cli.add_arg_envdir(parser) # TODO remove when we no longer need to read queries from lmdb + parser.add_argument("old_datafile", type=str, help="report to compare against") + parser.add_argument("new_datafile", type=str, help="report to compare evaluate") + cli.add_arg_envdir( + parser + ) # TODO remove when we no longer need to read queries from lmdb cli.add_arg_limit(parser) args = parser.parse_args() - report = DiffReport.from_json(cli.get_datafile(args, key='new_datafile')) - field_weights = args.cfg['report']['field_weights'] - ref_report = DiffReport.from_json(cli.get_datafile(args, key='old_datafile')) + report = DiffReport.from_json(cli.get_datafile(args, key="new_datafile")) + field_weights = args.cfg["report"]["field_weights"] + ref_report = DiffReport.from_json(cli.get_datafile(args, key="old_datafile")) check_report_summary(report) check_report_summary(ref_report) @@ -63,7 +72,9 @@ def main(): if field not in field_counters: field_counters[field] = Counter() - cli.print_fields_overview(field_counters, len(report.summary), ref_field_counters) + cli.print_fields_overview( + field_counters, len(report.summary), ref_field_counters + ) for field in field_weights: if field in field_counters: @@ -76,22 +87,27 @@ def main(): counter[mismatch] = 0 
cli.print_field_mismatch_stats( - field, counter, len(report.summary), ref_counter) + field, counter, len(report.summary), ref_counter + ) # query details with LMDB(args.envdir, readonly=True) as lmdb: lmdb.open_db(LMDB.QUERIES) queries_all = convert_queries( - get_query_iterator(lmdb, report.summary.keys())) + get_query_iterator(lmdb, report.summary.keys()) + ) ref_queries_all = convert_queries( - get_query_iterator(lmdb, ref_report.summary.keys())) + get_query_iterator(lmdb, ref_report.summary.keys()) + ) for field in field_weights: if field in field_counters: # ensure "disappeared" mismatches are shown field_mismatches = dict(report.summary.get_field_mismatches(field)) - ref_field_mismatches = dict(ref_report.summary.get_field_mismatches(field)) + ref_field_mismatches = dict( + ref_report.summary.get_field_mismatches(field) + ) mismatches = set(field_mismatches.keys()) mismatches.update(ref_field_mismatches.keys()) @@ -99,17 +115,19 @@ def main(): qids = field_mismatches.get(mismatch, set()) queries = convert_queries(get_query_iterator(lmdb, qids)) ref_queries = convert_queries( - get_query_iterator(lmdb, ref_field_mismatches.get(mismatch, set()))) + get_query_iterator( + lmdb, ref_field_mismatches.get(mismatch, set()) + ) + ) cli.print_mismatch_queries( field, mismatch, get_printable_queries_format( - queries, - queries_all, - ref_queries, - ref_queries_all), - args.limit) + queries, queries_all, ref_queries, ref_queries_all + ), + args.limit, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/sumstat.py b/sumstat.py index d17fe9cf09f509ee05b01f607e3e276b0ff6bb79..5e30d47a45073dc623c8fa50c2823dae9e51f9bd 100755 --- a/sumstat.py +++ b/sumstat.py @@ -10,12 +10,14 @@ from respdiff.stats import SummaryStatistics def _log_threshold(stats, label): percentile_rank = stats.get_percentile_rank(stats.threshold) - logging.info(' %s: %4.2f percentile rank', label, percentile_rank) + logging.info(" %s: %4.2f percentile rank", label, percentile_rank) def main(): cli.setup_logging() - parser = argparse.ArgumentParser(description='generate statistics file from reports') + parser = argparse.ArgumentParser( + description="generate statistics file from reports" + ) cli.add_arg_report_filename(parser) cli.add_arg_stats_filename(parser) @@ -28,16 +30,16 @@ def main(): logging.critical(exc) sys.exit(1) - logging.info('Total sample size: %d', sumstats.sample_size) - logging.info('Upper boundaries:') - _log_threshold(sumstats.target_disagreements, 'target_disagreements') - _log_threshold(sumstats.upstream_unstable, 'upstream_unstable') - _log_threshold(sumstats.not_reproducible, 'not_reproducible') + logging.info("Total sample size: %d", sumstats.sample_size) + logging.info("Upper boundaries:") + _log_threshold(sumstats.target_disagreements, "target_disagreements") + _log_threshold(sumstats.upstream_unstable, "upstream_unstable") + _log_threshold(sumstats.not_reproducible, "not_reproducible") for field_name, mismatch_stats in sumstats.fields.items(): _log_threshold(mismatch_stats.total, field_name) sumstats.export_json(args.stats_filename) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/lmdb/create_test_lmdb.py b/tests/lmdb/create_test_lmdb.py index 1ce25df82de86b31a7e5f67c6f8ffa6a9e9f1980..03df81ccd8d303906c6b3e70b181161badfa012b 100755 --- a/tests/lmdb/create_test_lmdb.py +++ b/tests/lmdb/create_test_lmdb.py @@ -12,13 +12,13 @@ import lmdb VERSION = "2018-05-21" CREATE_ENVS = { "2018-05-21": [ - 'answers_single_server', - 
'answers_multiple_servers', + "answers_single_server", + "answers_multiple_servers", ], } -BIN_INT_3000000000 = b'\x00^\xd0\xb2' +BIN_INT_3000000000 = b"\x00^\xd0\xb2" class LMDBExistsError(Exception): @@ -27,68 +27,65 @@ class LMDBExistsError(Exception): def open_env(version, name): path = os.path.join(version, name) - if os.path.exists(os.path.join(path, 'data.mdb')): + if os.path.exists(os.path.join(path, "data.mdb")): raise LMDBExistsError if not os.path.exists(path): os.makedirs(path) - return lmdb.Environment( - path=path, - max_dbs=5, - create=True) + return lmdb.Environment(path=path, max_dbs=5, create=True) def create_answers_single_server(): - env = open_env(VERSION, 'answers_single_server') - mdb = env.open_db(key=b'meta', create=True) + env = open_env(VERSION, "answers_single_server") + mdb = env.open_db(key=b"meta", create=True) with env.begin(mdb, write=True) as txn: - txn.put(b'version', VERSION.encode('ascii')) - txn.put(b'start_time', BIN_INT_3000000000) - txn.put(b'end_time', BIN_INT_3000000000) - txn.put(b'servers', struct.pack('<I', 1)) - txn.put(b'name0', 'kresd'.encode('ascii')) + txn.put(b"version", VERSION.encode("ascii")) + txn.put(b"start_time", BIN_INT_3000000000) + txn.put(b"end_time", BIN_INT_3000000000) + txn.put(b"servers", struct.pack("<I", 1)) + txn.put(b"name0", "kresd".encode("ascii")) - adb = env.open_db(key=b'answers', create=True) + adb = env.open_db(key=b"answers", create=True) with env.begin(adb, write=True) as txn: answer = BIN_INT_3000000000 - answer += struct.pack('<H', 1) - answer += b'a' + answer += struct.pack("<H", 1) + answer += b"a" txn.put(BIN_INT_3000000000, answer) def create_answers_multiple_servers(): - env = open_env(VERSION, 'answers_multiple_servers') - mdb = env.open_db(key=b'meta', create=True) + env = open_env(VERSION, "answers_multiple_servers") + mdb = env.open_db(key=b"meta", create=True) with env.begin(mdb, write=True) as txn: - txn.put(b'version', VERSION.encode('ascii')) - txn.put(b'servers', struct.pack('<I', 3)) - txn.put(b'name0', 'kresd'.encode('ascii')) - txn.put(b'name1', 'bind'.encode('ascii')) - txn.put(b'name2', 'unbound'.encode('ascii')) + txn.put(b"version", VERSION.encode("ascii")) + txn.put(b"servers", struct.pack("<I", 3)) + txn.put(b"name0", "kresd".encode("ascii")) + txn.put(b"name1", "bind".encode("ascii")) + txn.put(b"name2", "unbound".encode("ascii")) - adb = env.open_db(key=b'answers', create=True) + adb = env.open_db(key=b"answers", create=True) with env.begin(adb, write=True) as txn: # kresd answer = BIN_INT_3000000000 - answer += struct.pack('<H', 0) + answer += struct.pack("<H", 0) # bind answer += BIN_INT_3000000000 - answer += struct.pack('<H', 2) - answer += b'ab' + answer += struct.pack("<H", 2) + answer += b"ab" # unbound answer += BIN_INT_3000000000 - answer += struct.pack('<H', 1) - answer += b'a' + answer += struct.pack("<H", 1) + answer += b"a" txn.put(BIN_INT_3000000000, answer) def main(): for env_name in CREATE_ENVS[VERSION]: try: - globals()['create_{}'.format(env_name)]() + globals()["create_{}".format(env_name)]() except LMDBExistsError: - print('{} exists, skipping'.format(env_name)) + print("{} exists, skipping".format(env_name)) continue -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/test_cli.py b/tests/test_cli.py index 613c100d2a162179cac63ff7e26734e02c7bed28..3a441867dc2a92749aa06578aac56071930405a1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,18 +6,21 @@ from pytest import approx from respdiff.cli import get_stats_data 
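As a reading aid for the parametrized cases below (an interpretation inferred from the test vectors, not stated explicitly in the patch): get_stats_data(n, total, ref_n) returns a (count, percentage of total, difference from reference, percentage difference) tuple. Using the get_stats_data imported above and the values from one of the cases:

    n, pct, diff, diff_pct = get_stats_data(9, 10, 10)
    # -> (9, 90, -1, -10): 9 of 10 queries (90 %), one fewer than the reference (-10 %)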
-@pytest.mark.parametrize('n, total, ref_n, expected', [
-    (9, None, None, (9, None, None, None)),
-    (9, 10, None, (9, 90, None, None)),
-    (11, None, 10, (11, None, 1, 10)),
-    (9, None, 10, (9, None, -1, -10)),
-    (10, None, 10, (10, None, 0, 0)),
-    (10, None, 0, (10, None, +10, float('inf'))),
-    (0, None, 0, (0, None, 0, float('nan'))),
-    (9, 10, 10, (9, 90, -1, -10)),
-    (9, 10, 90, (9, 90, -81, -81*100.0/90)),
-    (90, 100, 9, (90, 90, 81, 81*100.0/9)),
-])
+@pytest.mark.parametrize(
+    "n, total, ref_n, expected",
+    [
+        (9, None, None, (9, None, None, None)),
+        (9, 10, None, (9, 90, None, None)),
+        (11, None, 10, (11, None, 1, 10)),
+        (9, None, 10, (9, None, -1, -10)),
+        (10, None, 10, (10, None, 0, 0)),
+        (10, None, 0, (10, None, +10, float("inf"))),
+        (0, None, 0, (0, None, 0, float("nan"))),
+        (9, 10, 10, (9, 90, -1, -10)),
+        (9, 10, 90, (9, 90, -81, -81 * 100.0 / 90)),
+        (90, 100, 9, (90, 90, 81, 81 * 100.0 / 9)),
+    ],
+)
 def test_get_stats_data(n, total, ref_n, expected):
     got_n, got_pct, got_diff, got_diff_pct = get_stats_data(n, total, ref_n)
     assert got_n == expected[0]
diff --git a/tests/test_data.py b/tests/test_data.py
index 7fd1b92de6f47988802e601304b93479f8a2b213..3bd6ff5b5fcacb8994cc5397fe1cab131191fc8a 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -4,32 +4,43 @@ import json
 import pytest

 from respdiff.dataformat import (
-    Diff, DiffReport, Disagreements, DisagreementsCounter, JSONDataObject, ReproCounter,
-    ReproData, Summary)
+    Diff,
+    DiffReport,
+    Disagreements,
+    DisagreementsCounter,
+    JSONDataObject,
+    ReproCounter,
+    ReproData,
+    Summary,
+)
 from respdiff.match import DataMismatch


 MISMATCH_DATA = [
-    ('timeout', 'answer'),
-    (['A'], ['A', 'CNAME']),
-    (['A'], ['A', 'RRSIG(A)']),
+    ("timeout", "answer"),
+    (["A"], ["A", "CNAME"]),
+    (["A"], ["A", "RRSIG(A)"]),
 ]

 DIFF_DATA = [
-    ('timeout', MISMATCH_DATA[0]),
-    ('answertypes', MISMATCH_DATA[1]),
-    ('answerrrsigs', MISMATCH_DATA[2]),
-    ('answerrrsigs', MISMATCH_DATA[1]),
+    ("timeout", MISMATCH_DATA[0]),
+    ("answertypes", MISMATCH_DATA[1]),
+    ("answerrrsigs", MISMATCH_DATA[2]),
+    ("answerrrsigs", MISMATCH_DATA[1]),
 ]

-QUERY_DIFF_DATA = list(enumerate([
-    (),
-    (DIFF_DATA[0],),
-    (DIFF_DATA[0], DIFF_DATA[1]),
-    (DIFF_DATA[1], DIFF_DATA[0]),
-    (DIFF_DATA[0], DIFF_DATA[1], DIFF_DATA[2]),
-    (DIFF_DATA[0], DIFF_DATA[3], DIFF_DATA[1]),
-]))
+QUERY_DIFF_DATA = list(
+    enumerate(
+        [
+            (),
+            (DIFF_DATA[0],),
+            (DIFF_DATA[0], DIFF_DATA[1]),
+            (DIFF_DATA[1], DIFF_DATA[0]),
+            (DIFF_DATA[0], DIFF_DATA[1], DIFF_DATA[2]),
+            (DIFF_DATA[0], DIFF_DATA[3], DIFF_DATA[1]),
+        ]
+    )
+)

 QUERY_DIFF_JSON = """
 {
@@ -140,14 +151,20 @@ REPRODATA_JSON = """

 def test_data_mismatch_init():
     with pytest.raises(Exception):
-        DataMismatch(1, 1)
-
-
-@pytest.mark.parametrize('mismatch_data, expected_key', zip(MISMATCH_DATA, [
-    ('timeout', 'answer'),
-    (('A',), ('A', 'CNAME')),
-    (('A',), ('A', 'RRSIG(A)')),
-]))
+        DataMismatch(1, 1)  # pylint: disable=pointless-exception-statement
+
+
+@pytest.mark.parametrize(
+    "mismatch_data, expected_key",
+    zip(
+        MISMATCH_DATA,
+        [
+            ("timeout", "answer"),
+            (("A",), ("A", "CNAME")),
+            (("A",), ("A", "RRSIG(A)")),
+        ],
+    ),
+)
 def test_data_mismatch(mismatch_data, expected_key):
     mismatch1 = DataMismatch(*mismatch_data)
     mismatch2 = DataMismatch(*mismatch_data)
@@ -162,8 +179,9 @@ def test_data_mismatch(mismatch_data, expected_key):
     assert mismatch1.key == expected_key


-@pytest.mark.parametrize('mismatch1_data, mismatch2_data',
-                         itertools.combinations(MISMATCH_DATA, 2))
+@pytest.mark.parametrize(
+    "mismatch1_data, mismatch2_data", itertools.combinations(MISMATCH_DATA, 2)
+)
 def test_data_mismatch_differnet_key_hash(mismatch1_data, mismatch2_data):
     if mismatch1_data == mismatch2_data:
         return
@@ -181,35 +199,38 @@ def test_json_data_object():
     assert empty.save() is None

     # simple scalar, list or dict -- no restore/save callbacks
-    attrs = {'a': (None, None)}
+    attrs = {"a": (None, None)}
     basic = JSONDataObject()
     basic.a = 1
     basic._ATTRIBUTES = attrs
     data = basic.save()
-    assert data['a'] == 1
+    assert data["a"] == 1
     restored = JSONDataObject()
     restored._ATTRIBUTES = attrs
     restored.restore(data)
     assert restored.a == basic.a

     # with save/restore callback
-    attrs = {'b': (lambda x: x + 1, lambda x: x - 1)}
+    attrs = {"b": (lambda x: x + 1, lambda x: x - 1)}
     complex_obj = JSONDataObject()
     complex_obj.b = 1
     complex_obj._ATTRIBUTES = attrs
     data = complex_obj.save()
-    assert data['b'] == 0
+    assert data["b"] == 0
     restored = JSONDataObject()
     restored._ATTRIBUTES = attrs
     restored.restore(data)
     assert restored.b == complex_obj.b
+
+    # pylint: enable=protected-access


-@pytest.mark.parametrize('qid, diff_data', QUERY_DIFF_DATA)
+@pytest.mark.parametrize("qid, diff_data", QUERY_DIFF_DATA)
 def test_diff(qid, diff_data):
-    mismatches = {field: DataMismatch(*mismatch_data)
-                  for field, mismatch_data in diff_data}
+    mismatches = {
+        field: DataMismatch(*mismatch_data) for field, mismatch_data in diff_data
+    }
     diff = Diff(qid, mismatches)

     fields = []
@@ -224,28 +245,38 @@ def test_diff(qid, diff_data):
         if not field_weights:
             continue
         field_weights = list(field_weights)
-        assert diff.get_significant_field(field_weights) == \
-            (field_weights[0], mismatches[field_weights[0]])
-        field_weights.append('custom')
-        assert diff.get_significant_field(field_weights) == \
-            (field_weights[0], mismatches[field_weights[0]])
-        field_weights.insert(0, 'custom2')
-        assert diff.get_significant_field(field_weights) == \
-            (field_weights[1], mismatches[field_weights[1]])
-        assert diff.get_significant_field(['non_existent']) == (None, None)
+        assert diff.get_significant_field(field_weights) == (
+            field_weights[0],
+            mismatches[field_weights[0]],
+        )
+        field_weights.append("custom")
+        assert diff.get_significant_field(field_weights) == (
+            field_weights[0],
+            mismatches[field_weights[0]],
+        )
+        field_weights.insert(0, "custom2")
+        assert diff.get_significant_field(field_weights) == (
+            field_weights[1],
+            mismatches[field_weights[1]],
+        )
+        assert diff.get_significant_field(["non_existent"]) == (None, None)

     # adding or removing items isn't possible
     with pytest.raises(Exception):
-        diff['non_existent'] = None  # pylint: disable=unsupported-assignment-operation
+        diff["non_existent"] = None  # pylint: disable=unsupported-assignment-operation
     with pytest.raises(Exception):
         del diff[list(diff.keys())[0]]  # pylint: disable=unsupported-delete-operation


 def test_diff_equality():
-    mismatches_tuple = {'timeout': DataMismatch('answer', 'timeout'),
-                        'answertypes': DataMismatch(('A',), ('CNAME',))}
-    mismatches_list = {'timeout': DataMismatch('answer', 'timeout'),
-                       'answertypes': DataMismatch(['A'], ['CNAME'])}
+    mismatches_tuple = {
+        "timeout": DataMismatch("answer", "timeout"),
+        "answertypes": DataMismatch(("A",), ("CNAME",)),
+    }
+    mismatches_list = {
+        "timeout": DataMismatch("answer", "timeout"),
+        "answertypes": DataMismatch(["A"], ["CNAME"]),
+    }

     # tuple or list doesn't matter
     assert Diff(1, mismatches_tuple) == Diff(1, mismatches_list)
@@ -254,7 +285,7 @@ def test_diff_equality():
     assert Diff(1, mismatches_tuple) == Diff(2, mismatches_tuple)

     # different mismatches
-    mismatches_tuple['answerrrsigs'] = DataMismatch(('RRSIG(A)',), ('',))
+    mismatches_tuple["answerrrsigs"] = DataMismatch(("RRSIG(A)",), ("",))
     assert Diff(1, mismatches_tuple) != Diff(1, mismatches_list)
@@ -327,8 +358,8 @@ def test_diff_report():

     # report with some missing fields
     partial_data = report.save()
-    del partial_data['other_disagreements']
-    del partial_data['target_disagreements']
+    del partial_data["other_disagreements"]
+    del partial_data["target_disagreements"]
     partial_report = DiffReport(_restore_dict=partial_data)
     assert partial_report.other_disagreements is None
     assert partial_report.target_disagreements is None
@@ -337,7 +368,7 @@ def test_diff_report():


 def test_summary():
-    field_weights = ['timeout', 'answertypes', 'aswerrrsigs']
+    field_weights = ["timeout", "answertypes", "aswerrrsigs"]
     report = DiffReport(_restore_dict=json.loads(DIFF_REPORT_JSON))

     # no reprodata -- no queries are missing
@@ -390,17 +421,17 @@ def test_repro_counter():
     assert rc.verified == 1
     assert rc.different_failure == 1

-    rc = ReproCounter(_restore_dict={'retries': 4})
+    rc = ReproCounter(_restore_dict={"retries": 4})
     assert rc.retries == 4
     assert rc.upstream_stable == 0
     assert rc.verified == 0
     assert rc.different_failure == 0

     data = rc.save()
-    assert data['retries'] == 4
-    assert data['upstream_stable'] == 0
-    assert data['verified'] == 0
-    assert data['different_failure'] == 0
+    assert data["retries"] == 4
+    assert data["upstream_stable"] == 0
+    assert data["verified"] == 0
+    assert data["different_failure"] == 0

     assert ReproCounter().save() is None
diff --git a/tests/test_database.py b/tests/test_database.py
index d70b4e39e3d6bc336c52558fc10c3d614adb805d..481186602316d2d806a7d7a0a5f70671adaeb698 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -2,11 +2,18 @@ import os
 import pytest

 from respdiff.database import (
-    DNSReply, DNSRepliesFactory, LMDB, MetaDatabase, BIN_FORMAT_VERSION, qid2key)
+    DNSReply,
+    DNSRepliesFactory,
+    LMDB,
+    MetaDatabase,
+    BIN_FORMAT_VERSION,
+    qid2key,
+)


 LMDB_DIR = os.path.join(
-    os.path.abspath(os.path.dirname(__file__)), 'lmdb', BIN_FORMAT_VERSION)
+    os.path.abspath(os.path.dirname(__file__)), "lmdb", BIN_FORMAT_VERSION
+)


 def create_reply(wire, time):
@@ -18,71 +25,83 @@ def create_reply(wire, time):
 def test_dns_reply_timeout():
     reply = DNSReply(None)
     assert reply.timeout
-    assert reply.time == float('+inf')
-
-
-@pytest.mark.parametrize('wire1, time1, wire2, time2, equals', [
-    (None, None, None, None, True),
-    (None, None, None, 1, True),
-    (b'', None, b'', None, True),
-    (b'', None, b'', 1, False),
-    (b'a', None, b'a', None, True),
-    (b'a', None, b'b', None, False),
-    (b'a', None, b'aa', None, False),
-])
+    assert reply.time == float("+inf")
+
+
+@pytest.mark.parametrize(
+    "wire1, time1, wire2, time2, equals",
+    [
+        (None, None, None, None, True),
+        (None, None, None, 1, True),
+        (b"", None, b"", None, True),
+        (b"", None, b"", 1, False),
+        (b"a", None, b"a", None, True),
+        (b"a", None, b"b", None, False),
+        (b"a", None, b"aa", None, False),
+    ],
+)
 def test_dns_reply_equals(wire1, time1, wire2, time2, equals):
     r1 = create_reply(wire1, time1)
     r2 = create_reply(wire2, time2)
     assert (r1 == r2) == equals


-@pytest.mark.parametrize('time, time_int', [
-    (None, 0),
-    (0, 0),
-    (1.43, 1430000),
-    (0.4591856, 459186),
-])
+@pytest.mark.parametrize(
+    "time, time_int",
+    [
+        (None, 0),
+        (0, 0),
+        (1.43, 1430000),
+        (0.4591856, 459186),
+    ],
+)
 def test_dns_reply_time_int(time, time_int):
-    reply = create_reply(b'', time)
+    reply = create_reply(b"", time)
     assert reply.time_int == time_int


 DR_TIMEOUT = DNSReply(None)
-DR_TIMEOUT_BIN = b'\xff\xff\xff\xff\x00\x00'
-DR_EMPTY_0 = DNSReply(b'')
-DR_EMPTY_0_BIN = b'\x00\x00\x00\x00\x00\x00'
-DR_EMPTY_1 = DNSReply(b'', 1)
-DR_EMPTY_1_BIN = b'\x40\x42\x0f\x00\x00\x00'
-DR_A_0 = DNSReply(b'a')
-DR_A_0_BIN = b'\x00\x00\x00\x00\x01\x00a'
-DR_A_1 = DNSReply(b'a', 1)
-DR_A_1_BIN = b'\x40\x42\x0f\x00\x01\x00a'
-DR_ABCD_1 = DNSReply(b'abcd', 1)
-DR_ABCD_1_BIN = b'\x40\x42\x0f\x00\x04\x00abcd'
-
-
-@pytest.mark.parametrize('reply, binary', [
-    (DR_TIMEOUT, DR_TIMEOUT_BIN),
-    (DR_EMPTY_0, DR_EMPTY_0_BIN),
-    (DR_EMPTY_1, DR_EMPTY_1_BIN),
-    (DR_A_0, DR_A_0_BIN),
-    (DR_A_1, DR_A_1_BIN),
-    (DR_ABCD_1, DR_ABCD_1_BIN),
-])
+DR_TIMEOUT_BIN = b"\xff\xff\xff\xff\x00\x00"
+DR_EMPTY_0 = DNSReply(b"")
+DR_EMPTY_0_BIN = b"\x00\x00\x00\x00\x00\x00"
+DR_EMPTY_1 = DNSReply(b"", 1)
+DR_EMPTY_1_BIN = b"\x40\x42\x0f\x00\x00\x00"
+DR_A_0 = DNSReply(b"a")
+DR_A_0_BIN = b"\x00\x00\x00\x00\x01\x00a"
+DR_A_1 = DNSReply(b"a", 1)
+DR_A_1_BIN = b"\x40\x42\x0f\x00\x01\x00a"
+DR_ABCD_1 = DNSReply(b"abcd", 1)
+DR_ABCD_1_BIN = b"\x40\x42\x0f\x00\x04\x00abcd"
+
+
+@pytest.mark.parametrize(
+    "reply, binary",
+    [
+        (DR_TIMEOUT, DR_TIMEOUT_BIN),
+        (DR_EMPTY_0, DR_EMPTY_0_BIN),
+        (DR_EMPTY_1, DR_EMPTY_1_BIN),
+        (DR_A_0, DR_A_0_BIN),
+        (DR_A_1, DR_A_1_BIN),
+        (DR_ABCD_1, DR_ABCD_1_BIN),
+    ],
+)
 def test_dns_reply_serialization(reply, binary):
     assert reply.binary == binary


-@pytest.mark.parametrize('binary, reply, remaining', [
-    (DR_TIMEOUT_BIN, DR_TIMEOUT, b''),
-    (DR_EMPTY_0_BIN, DR_EMPTY_0, b''),
-    (DR_EMPTY_1_BIN, DR_EMPTY_1, b''),
-    (DR_A_0_BIN, DR_A_0, b''),
-    (DR_A_1_BIN, DR_A_1, b''),
-    (DR_ABCD_1_BIN, DR_ABCD_1, b''),
-    (DR_A_1_BIN + b'a', DR_A_1, b'a'),
-    (DR_ABCD_1_BIN + b'bcd', DR_ABCD_1, b'bcd'),
-])
+@pytest.mark.parametrize(
+    "binary, reply, remaining",
+    [
+        (DR_TIMEOUT_BIN, DR_TIMEOUT, b""),
+        (DR_EMPTY_0_BIN, DR_EMPTY_0, b""),
+        (DR_EMPTY_1_BIN, DR_EMPTY_1, b""),
+        (DR_A_0_BIN, DR_A_0, b""),
+        (DR_A_1_BIN, DR_A_1, b""),
+        (DR_ABCD_1_BIN, DR_ABCD_1, b""),
+        (DR_A_1_BIN + b"a", DR_A_1, b"a"),
+        (DR_ABCD_1_BIN + b"bcd", DR_ABCD_1, b"bcd"),
+    ],
+)
 def test_dns_reply_deserialization(binary, reply, remaining):
     got_reply, buff = DNSReply.from_binary(binary)
     assert reply == got_reply
@@ -93,18 +112,18 @@ def test_dns_replies_factory():
     with pytest.raises(ValueError):
         DNSRepliesFactory([])

-    rf = DNSRepliesFactory(['a'])
+    rf = DNSRepliesFactory(["a"])
     replies = rf.parse(DR_TIMEOUT_BIN)
-    assert replies['a'] == DR_TIMEOUT
+    assert replies["a"] == DR_TIMEOUT

-    rf2 = DNSRepliesFactory(['a', 'b'])
+    rf2 = DNSRepliesFactory(["a", "b"])
     bin_data = DR_A_0_BIN + DR_ABCD_1_BIN
     replies = rf2.parse(bin_data)
-    assert replies['a'] == DR_A_0
-    assert replies['b'] == DR_ABCD_1
+    assert replies["a"] == DR_A_0
+    assert replies["b"] == DR_ABCD_1

     with pytest.raises(ValueError):
-        rf2.parse(DR_A_0_BIN + b'a')
+        rf2.parse(DR_A_0_BIN + b"a")

     assert rf2.serialize(replies) == bin_data
@@ -114,38 +133,38 @@ TIME_3M = 3000.0


 def test_lmdb_answers_single_server():
-    envdir = os.path.join(LMDB_DIR, 'answers_single_server')
+    envdir = os.path.join(LMDB_DIR, "answers_single_server")
     with LMDB(envdir) as lmdb:
         adb = lmdb.open_db(LMDB.ANSWERS)
-        meta = MetaDatabase(lmdb, ['kresd'])
+        meta = MetaDatabase(lmdb, ["kresd"])
         assert meta.read_start_time() == INT_3M
         assert meta.read_end_time() == INT_3M

         servers = meta.read_servers()
         assert len(servers) == 1
-        assert servers[0] == 'kresd'
+        assert servers[0] == "kresd"

         with lmdb.env.begin(adb) as txn:
             data = txn.get(qid2key(INT_3M))
         df = DNSRepliesFactory(servers)
         replies = df.parse(data)
         assert len(replies) == 1
-        assert replies[servers[0]] == DNSReply(b'a', TIME_3M)
+        assert replies[servers[0]] == DNSReply(b"a", TIME_3M)


 def test_lmdb_answers_multiple_servers():
-    envdir = os.path.join(LMDB_DIR, 'answers_multiple_servers')
+    envdir = os.path.join(LMDB_DIR, "answers_multiple_servers")
     with LMDB(envdir) as lmdb:
         adb = lmdb.open_db(LMDB.ANSWERS)
-        meta = MetaDatabase(lmdb, ['kresd', 'bind', 'unbound'])
+        meta = MetaDatabase(lmdb, ["kresd", "bind", "unbound"])
         assert meta.read_start_time() is None
         assert meta.read_end_time() is None

         servers = meta.read_servers()
         assert len(servers) == 3
-        assert servers[0] == 'kresd'
-        assert servers[1] == 'bind'
-        assert servers[2] == 'unbound'
+        assert servers[0] == "kresd"
+        assert servers[1] == "bind"
+        assert servers[2] == "unbound"

         df = DNSRepliesFactory(servers)
@@ -154,6 +173,6 @@ def test_lmdb_answers_multiple_servers():
             replies = df.parse(data)

         assert len(replies) == 3
-        assert replies[servers[0]] == DNSReply(b'', TIME_3M)
-        assert replies[servers[1]] == DNSReply(b'ab', TIME_3M)
-        assert replies[servers[2]] == DNSReply(b'a', TIME_3M)
+        assert replies[servers[0]] == DNSReply(b"", TIME_3M)
+        assert replies[servers[1]] == DNSReply(b"ab", TIME_3M)
+        assert replies[servers[2]] == DNSReply(b"a", TIME_3M)
diff --git a/tests/test_qprep_pcap.py b/tests/test_qprep_pcap.py
index 16b1badafabd3dcd443216c7a8c1e37bc26c28e1..157d964960ded5f91e6660e461f9794d4e677cc4 100644
--- a/tests/test_qprep_pcap.py
+++ b/tests/test_qprep_pcap.py
@@ -5,52 +5,73 @@ import pytest
 from qprep import wrk_process_frame, wrk_process_wire_packet


-@pytest.mark.parametrize('wire', [
-    b'',
-    b'x',
-    b'xx',
-])
+@pytest.mark.parametrize(
+    "wire",
+    [
+        b"",
+        b"x",
+        b"xx",
+    ],
+)
 def test_wire_input_invalid(wire):
-    assert wrk_process_wire_packet(1, wire, 'invalid') == (1, wire)
-    assert wrk_process_wire_packet(1, wire, 'invalid') == (1, wire)
+    assert wrk_process_wire_packet(1, wire, "invalid") == (1, wire)
+    assert wrk_process_wire_packet(1, wire, "invalid") == (1, wire)


-@pytest.mark.parametrize('wire_hex', [
-    # www.audioweb.cz A
-    'ed21010000010000000000010377777708617564696f77656202637a00000100010000291000000080000000',
-])
+@pytest.mark.parametrize(
+    "wire_hex",
+    [
+        # www.audioweb.cz A
+        "ed21010000010000000000010377777708617564696f77656202637a00000100010000291000000080000000",
+    ],
+)
 def test_wire_input_valid(wire_hex):
     wire_in = binascii.unhexlify(wire_hex)
-    qid, wire_out = wrk_process_wire_packet(1, wire_in, 'qid 1')
+    qid, wire_out = wrk_process_wire_packet(1, wire_in, "qid 1")
     assert wire_in == wire_out
     assert qid == 1


-@pytest.mark.parametrize('wire_hex', [
-    # test.dotnxdomain.net. A
-    ('ce970120000100000000000104746573740b646f746e78646f6d61696e036e657400000'
-     '10001000029100000000000000c000a00084a69fef0f174d87e'),
-    # 0es-u2af5c077-c56-s1492621913-i00000000.eue.dotnxdomain.net A
-    ('d72f01000001000000000001273065732d7532616635633037372d6335362d733134393'
-     '23632313931332d693030303030303030036575650b646f746e78646f6d61696e036e65'
-     '7400000100010000291000000080000000'),
-])
+@pytest.mark.parametrize(
+    "wire_hex",
+    [
+        # test.dotnxdomain.net. A
+        (
+            "ce970120000100000000000104746573740b646f746e78646f6d61696e036e657400000"
+            "10001000029100000000000000c000a00084a69fef0f174d87e"
+        ),
+        # 0es-u2af5c077-c56-s1492621913-i00000000.eue.dotnxdomain.net A
+        (
+            "d72f01000001000000000001273065732d7532616635633037372d6335362d733134393"
+            "23632313931332d693030303030303030036575650b646f746e78646f6d61696e036e65"
+            "7400000100010000291000000080000000"
+        ),
+    ],
+)
 def test_pcap_input_blacklist(wire_hex):
     wire = binascii.unhexlify(wire_hex)
-    assert wrk_process_wire_packet(1, wire, 'qid 1') == (None, None)
-
-
-@pytest.mark.parametrize('frame_hex, wire_hex', [
-    # UPD nic.cz A
-    ('deadbeefcafecafebeefbeef08004500004bf9d000004011940d0202020201010101b533003500375520',
-     'b90001200001000000000001036e696302637a0000010001000029100000000000000c000a00081491f8'
-     '93b0c90b2f'),
-    # TCP nic.cz A
-    ('deadbeefcafebeefbeefcafe080045000059e2f2400040066ae80202020201010101ace7003557b51707'
-     '47583400501800e5568c0000002f', '49e501200001000000000001036e696302637a00000100010000'
-     '29100000000000000c000a0008a1db546e1d6fa39f'),
-])
+    assert wrk_process_wire_packet(1, wire, "qid 1") == (None, None)
+
+
+@pytest.mark.parametrize(
+    "frame_hex, wire_hex",
+    [
+        # UPD nic.cz A
+        (
+            "deadbeefcafecafebeefbeef08004500004bf9d000004011940d0202020201010101b533003500375520",
+            "b90001200001000000000001036e696302637a0000010001000029100000000000000c000a00081491f8"
+            "93b0c90b2f",
+        ),
+        # TCP nic.cz A
+        (
+            "deadbeefcafebeefbeefcafe080045000059e2f2400040066ae80202020201010101ace7003557b51707"
+            "47583400501800e5568c0000002f",
+            "49e501200001000000000001036e696302637a00000100010000"
+            "29100000000000000c000a0008a1db546e1d6fa39f",
+        ),
+    ],
+)
 def test_wrk_process_frame(frame_hex, wire_hex):
     data = binascii.unhexlify(frame_hex + wire_hex)
     wire = binascii.unhexlify(wire_hex)
-    assert wrk_process_frame((1, data, 'qid 1')) == (1, wire)
+    assert wrk_process_frame((1, data, "qid 1")) == (1, wire)
diff --git a/tests/test_qprep_text.py b/tests/test_qprep_text.py
index 553b1457b32fd1fd2ef1c71d44f5f30bb8594f62..ac61864ac84f92cf69014f45fe120f70b8b09cfc 100644
--- a/tests/test_qprep_text.py
+++ b/tests/test_qprep_text.py
@@ -7,25 +7,31 @@ import pytest
 from qprep import wrk_process_line


-@pytest.mark.parametrize('line', [
-    '',
-    'x'*256 + ' A',
-    '\123x.test. 65536',
-    '\321.test. 1',
-    'test. A,AAAA',
-    'test. A, AAAA',
-])
+@pytest.mark.parametrize(
+    "line",
+    [
+        "",
+        "x" * 256 + " A",
+        "\123x.test. 65536",
+        "\321.test. 1",
+        "test. A,AAAA",
+        "test. A, AAAA",
+    ],
+)
 def test_text_input_invalid(line):
     assert wrk_process_line((1, line, line)) == (None, None)


-@pytest.mark.parametrize('qname, qtype', [
-    ('x', 'A'),
-    ('x', 1),
-    ('blabla.test.', 'TSIG'),
-])
+@pytest.mark.parametrize(
+    "qname, qtype",
+    [
+        ("x", "A"),
+        ("x", 1),
+        ("blabla.test.", "TSIG"),
+    ],
+)
 def test_text_input_valid(qname, qtype):
-    line = '{} {}'.format(qname, qtype)
+    line = "{} {}".format(qname, qtype)

     if isinstance(qtype, int):
         rdtype = qtype
@@ -39,12 +45,15 @@ def test_text_input_valid(qname, qtype):
     assert qid == 1


-@pytest.mark.parametrize('line', [
-    'test. ANY',
-    'test. RRSIG',
-    'dotnxdomain.net. 28',
-    'something.dotnxdomain.net. A',
-    'something.dashnxdomain.net. AAAA',
-])
+@pytest.mark.parametrize(
+    "line",
+    [
+        "test. ANY",
+        "test. RRSIG",
+        "dotnxdomain.net. 28",
+        "something.dotnxdomain.net. A",
+        "something.dashnxdomain.net. AAAA",
+    ],
+)
 def test_text_input_blacklist(line):
     assert wrk_process_line((1, line, line)) == (None, None)
diff --git a/utils/dns2txt.py b/utils/dns2txt.py
index 57ad8e51db15801c466827c93cd0e1fe5bf3c223..d0271dbabb513832da1e2fde5bf4119b717f3d85 100644
--- a/utils/dns2txt.py
+++ b/utils/dns2txt.py
@@ -4,6 +4,6 @@ import sys
 import dns.message

-with open(sys.argv[1], 'rb') as f:
+with open(sys.argv[1], "rb") as f:
     m = dns.message.from_wire(f.read())
     print(str(m))
diff --git a/utils/normalize_names.py b/utils/normalize_names.py
index 603a288100dce1784e722d56968d59d8d638b38a..1d7fe0bfab2c33fec298c55293f7b09f571f8050 100644
--- a/utils/normalize_names.py
+++ b/utils/normalize_names.py
@@ -1,13 +1,13 @@
 import string
 import sys

-allowed_bytes = (string.ascii_letters + string.digits + '.-_').encode('ascii')
+allowed_bytes = (string.ascii_letters + string.digits + ".-_").encode("ascii")
 trans = {}
 for i in range(0, 256):
     if i in allowed_bytes:
-        trans[i] = bytes(chr(i), encoding='ascii')
+        trans[i] = bytes(chr(i), encoding="ascii")
     else:
-        trans[i] = (r'\%03i' % i).encode('ascii')
+        trans[i] = (r"\%03i" % i).encode("ascii")
 # pprint(trans)

 while True:
@@ -16,7 +16,7 @@ while True:
         break
     line = line[:-1]  # newline

-    typestart = line.rfind(b' ') + 1  # rightmost space
+    typestart = line.rfind(b" ") + 1  # rightmost space
     if not typestart:
         continue  # no RR type!?
     typetext = line[typestart:]
@@ -24,10 +24,10 @@ while True:
         continue

     # normalize name
-    normalized = b''
-    for nb in line[:typestart - 1]:
+    normalized = b""
+    for nb in line[: typestart - 1]:
         normalized += trans[nb]
     sys.stdout.buffer.write(normalized)
-    sys.stdout.buffer.write(b' ')
+    sys.stdout.buffer.write(b" ")
     sys.stdout.buffer.write(typetext)
-    sys.stdout.buffer.write(b'\n')
+    sys.stdout.buffer.write(b"\n")