Commit 5aff589a authored by Tomas Krizek's avatar Tomas Krizek


parent 65b48167
from abc import ABC
from contextlib import contextmanager
import logging
import os
import struct
import sys
......@@ -160,6 +159,8 @@ class DNSReply:
def __eq__(self, other) -> bool:
if self.timeout and other.timeout:
return True
# float equality comparison: use 10^-7 tolerance since it's less than available
# resoltuion from the time_int integer value (which is 10^-6)
return self.wire == other.wire and \
abs(self.time - other.time) < 10 ** -7
......@@ -169,13 +170,14 @@ class DNSReply:
return self.TIMEOUT_INT
value = round(self.time * (10 ** 6))
if value > self.TIMEOUT_INT:
raise ValueError('Maximum time value exceeded')
raise ValueError(
'Maximum time value exceeded: (value: "{}", max: {})'.format(
value, self.TIMEOUT_INT))
return value
def binary(self) -> bytes:
length = len(self.wire)
assert length < 2**(self.SIZEOF_SHORT*8), 'Maximum wire format length exceeded'
return struct.pack('<I', self.time_int) + struct.pack('<H', length) + self.wire
......@@ -215,10 +217,12 @@ class DNSRepliesFactory:
reply, buff = DNSReply.from_binary(buff)
replies[server] = reply
if buff:
logging.warning('Trailing data in buffer')
raise ValueError('Trailing data in buffer')
return replies
def serialize(self, replies: Mapping[ResolverID, DNSReply]) -> bytes:
if len(replies) > len(self.servers):
raise ValueError('Extra unexpected data to serialize!')
data = []
for server in self.servers:
......@@ -227,8 +231,6 @@ class DNSRepliesFactory:
raise ValueError('Missing reply for server "{}"!'.format(server))
if len(replies) > len(self.servers):
raise ValueError('Extra unexpected data to serialize!')
return b''.join(data)
#!/usr/bin/env python3
# NOTE: Due to a weird bug, numpy is detected as a 3rd party module, while lmdb
# is not and pylint complains about wrong-import-order.
# Since these checks have to be disabled for matplotlib imports anyway, they
# were moved a bit higher up to avoid the issue.
# pylint: disable=wrong-import-order,wrong-import-position
import argparse
import logging
......@@ -259,7 +259,8 @@ def export_json(filename: str, report: DiffReport):
for field, mismatch in diff.items():
report.target_disagreements.add_mismatch(field, mismatch, qid)
# it doesn't make sense to use existing report.json
# NOTE: msgdiff is the first tool in the toolchain to generate report.json
# thus it doesn't make sense to re-use existing report.json file
if os.path.exists(filename):
backup_filename = filename + '.bak'
os.rename(filename, backup_filename)
......@@ -321,9 +322,10 @@ def main():
qid_stream = lmdb.key_stream(LMDB.ANSWERS)
dnsreplies_factory = DNSRepliesFactory(servers)
func = partial(compare_lmdb_wrapper, criteria, target, dnsreplies_factory)
compare_func = partial(
compare_lmdb_wrapper, criteria, target, dnsreplies_factory)
with pool.Pool() as p:
for _ in p.imap_unordered(func, qid_stream, chunksize=10):
for _ in p.imap_unordered(compare_func, qid_stream, chunksize=10):
export_json(datafile, report)
......@@ -46,7 +46,7 @@ def test_dns_reply_time_int(time, time_int):
assert reply.time_int == time_int
DR_TIMEOUT = DNSReply(None, None)
DR_TIMEOUT_BIN = b'\xff\xff\xff\xff\x00\x00'
DR_EMPTY_0 = DNSReply(b'')
DR_EMPTY_0_BIN = b'\x00\x00\x00\x00\x00\x00'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment