Commit 54edfc0f authored by Tomas Krizek's avatar Tomas Krizek

msgdiff: convert exp_val/got_val to str before passing it to DataMismatch

Fixes knot/resolver-benchmarking#33
parent 54ac6ef5
......@@ -6,8 +6,9 @@ import multiprocessing.pool as pool
import pickle
from typing import Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple # noqa
import dns.message
import dns.exception
import dns.message
from dns.rrset import RRset
import cli
from dataformat import (
......@@ -23,24 +24,24 @@ lmdb = None
def compare_val(exp_val: MismatchValue, got_val: MismatchValue):
""" Compare values, throw exception if different. """
if exp_val != got_val:
raise DataMismatch(exp_val, got_val)
raise DataMismatch(str(exp_val), str(got_val))
return True
def compare_rrs(expected, got):
def compare_rrs(expected: RRset, got: RRset):
""" Compare lists of RR sets, throw exception if different. """
for rr in expected:
if rr not in got:
raise DataMismatch(expected, got)
raise DataMismatch(str(expected), str(got))
for rr in got:
if rr not in expected:
raise DataMismatch(expected, got)
raise DataMismatch(str(expected), str(got))
if len(expected) != len(got):
raise DataMismatch(expected, got)
raise DataMismatch(str(expected), str(got))
return True
def compare_rrs_types(exp_val, got_val, compare_rrsigs):
def compare_rrs_types(exp_val: RRset, got_val: RRset, compare_rrsigs: bool):
"""sets of RR types in both sections must match"""
def rr_ordering_key(rrset):
return rrset.covers if compare_rrsigs else rrset.rdtype
......@@ -61,9 +62,9 @@ def compare_rrs_types(exp_val, got_val, compare_rrsigs):
got_types = frozenset(rr_ordering_key(rrset)
for rrset in filter_by_rrsig(got_val, compare_rrsigs))
if exp_types != got_types:
exp_types = tuple(key_to_text(i) for i in sorted(exp_types))
got_types = tuple(key_to_text(i) for i in sorted(got_types))
raise DataMismatch(exp_types, got_types)
raise DataMismatch(
tuple(key_to_text(i) for i in sorted(exp_types)),
tuple(key_to_text(i) for i in sorted(got_types)))
def match_part(exp_msg, got_msg, code): # pylint: disable=inconsistent-return-statements
......@@ -98,9 +99,9 @@ def match_part(exp_msg, got_msg, code): # pylint: disable=inconsistent-return-s
return compare_rrs(exp_msg.additional, got_msg.additional)
elif code == 'edns':
if got_msg.edns != exp_msg.edns:
raise DataMismatch(exp_msg.edns, got_msg.edns)
raise DataMismatch(str(exp_msg.edns), str(got_msg.edns))
if got_msg.payload != exp_msg.payload:
raise DataMismatch(exp_msg.payload, got_msg.payload)
raise DataMismatch(str(exp_msg.payload), str(got_msg.payload))
elif code == 'nsid':
nsid_opt = None
for opt in exp_msg.options:
......@@ -111,13 +112,13 @@ def match_part(exp_msg, got_msg, code): # pylint: disable=inconsistent-return-s
for opt in got_msg.options:
if opt.otype == dns.edns.NSID:
if not nsid_opt:
raise DataMismatch(None, opt.data)
raise DataMismatch('', str(opt.data))
if opt == nsid_opt:
return True
else:
raise DataMismatch(nsid_opt.data, opt.data)
raise DataMismatch(str(nsid_opt.data), str(opt.data))
if nsid_opt:
raise DataMismatch(nsid_opt.data, None)
raise DataMismatch(str(nsid_opt.data), '')
else:
raise NotImplementedError('unknown match request "%s"' % code)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment