Commit 87c68ce6 authored by Tomas Krizek's avatar Tomas Krizek

match: handle malformed DNS replies

Closes #5
parent 842e14a8
......@@ -83,7 +83,7 @@ def chunker(iterable: Iterable[T], size: int) -> Iterator[Iterable[T]]:
def process_answers(
qkey: QKey,
replies: Mapping[ResolverID, DNSReply],
answers: Mapping[ResolverID, DNSReply],
report: DiffReport,
criteria: Sequence[FieldLabel],
target: ResolverID
......@@ -92,7 +92,6 @@ def process_answers(
raise RuntimeError("Report doesn't contain necessary data!")
qid = key2qid(qkey)
reprocounter = report.reprodata[qid]
answers = DNSRepliesFactory.decode_parsed(replies)
others_agree, mismatches = compare(answers, criteria, target)
reprocounter.retries += 1
......@@ -9,13 +9,10 @@ import pickle
from typing import Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple # noqa
import sys
import dns.exception
import dns.message
from respdiff import cli
from respdiff.dataformat import (
DiffReport, Disagreements, DisagreementsCounter, FieldLabel, QID)
from respdiff.database import DNSRepliesFactory, key2qid, LMDB, MetaDatabase
from respdiff.database import DNSRepliesFactory, DNSReply, key2qid, LMDB, MetaDatabase
from respdiff.match import compare
from respdiff.typing import ResolverID
......@@ -26,14 +23,13 @@ lmdb = None
def read_answers_lmdb(
dnsreplies_factory: DNSRepliesFactory,
qid: QID
) -> Mapping[ResolverID, dns.message.Message]:
) -> Mapping[ResolverID, DNSReply]:
assert lmdb is not None, "LMDB wasn't initialized!"
adb = lmdb.get_db(LMDB.ANSWERS)
with lmdb.env.begin(adb) as txn:
replies_blob = txn.get(qid)
assert replies_blob
replies = dnsreplies_factory.parse(replies_blob)
return dnsreplies_factory.decode_parsed(replies)
return dnsreplies_factory.parse(replies_blob)
def compare_lmdb_wrapper(
......@@ -47,4 +47,4 @@ criteria = opcode, rcode, flags, question, qname, qtype, answertypes, answerrrsi
# diffsum reports mismatches in field values in this order
# if particular message has multiple mismatches, it is counted only once into category with highest weight
field_weights = timeout, opcode, qcase, qtype, rcode, flags, answertypes, answerrrsigs, answer, authority, additional, edns, nsid
field_weights = timeout, malformed, opcode, qcase, qtype, rcode, flags, answertypes, answerrrsigs, answer, authority, additional, edns, nsid
from abc import ABC
from contextlib import contextmanager
import logging
import os
import struct
import time
from typing import ( # noqa
Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Sequence)
import dns.exception
import dns.message
import lmdb
from .dataformat import QID
from .typing import ResolverID, QKey, WireFormat
from .typing import ResolverID, QID, QKey, WireFormat
BIN_FORMAT_VERSION = '2018-05-21'
......@@ -137,6 +136,7 @@ class DNSReply:
TIMEOUT_INT = 4294967295
def __init__(self, wire: Optional[WireFormat], time_: float = 0) -> None:
if wire is None:
......@@ -197,6 +197,14 @@ class DNSReply:
return reply, buff[offset:]
def parse_wire(
) -> Tuple[Optional[dns.message.Message], Optional[str]]:
return dns.message.from_wire(self.wire), self.WIREFORMAT_VALID
except dns.exception.FormError as exc:
return None, type(exc).__name__
class DNSRepliesFactory:
"""Thread-safe factory to parse DNSReply objects from binary blob."""
......@@ -227,22 +235,6 @@ class DNSRepliesFactory:
return b''.join(data)
def decode_parsed(
replies: Mapping[ResolverID, DNSReply]
) -> Mapping[ResolverID, dns.message.Message]:
answers = {} # type: Dict[ResolverID, dns.message.Message]
for resolver, reply in replies.items():
if reply.timeout:
answers[resolver] = None
answers[resolver] = dns.message.from_wire(reply.wire)
except Exception as exc:
logging.warning('Failed to decode DNS message from wire format: %s', exc)
return answers
class Database(ABC):
DB_NAME = b''
......@@ -7,6 +7,7 @@ import dns.rdatatype
from dns.rrset import RRset
import dns.message
from .database import DNSReply
from .typing import FieldLabel, MismatchValue, ResolverID
......@@ -111,7 +112,11 @@ def compare_rrs_types(exp_val: RRset, got_val: RRset, compare_rrsigs: bool):
tuple(key_to_text(i) for i in sorted(got_types)))
def match_part(exp_msg, got_msg, criteria): # pylint: disable=inconsistent-return-statements
def match_part( # pylint: disable=inconsistent-return-statements
exp_msg: dns.message.Message,
got_msg: dns.message.Message,
criteria: FieldLabel
""" Compare scripted reply to given message using single criteria. """
if criteria == 'opcode':
return compare_val(exp_msg.opcode(), got_msg.opcode())
......@@ -168,26 +173,40 @@ def match_part(exp_msg, got_msg, criteria): # pylint: disable=inconsistent-retu
def match(
expected: dns.message.Message,
got: dns.message.Message,
expected: DNSReply,
got: DNSReply,
match_fields: Sequence[FieldLabel]
) -> Iterator[Tuple[FieldLabel, DataMismatch]]:
""" Compare scripted reply to given message based on match criteria. """
if expected is None or got is None:
if expected is not None:
exp_msg, exp_res = expected.parse_wire()
got_msg, got_res = got.parse_wire()
exp_malformed = exp_res != DNSReply.WIREFORMAT_VALID
got_malformed = got_res != DNSReply.WIREFORMAT_VALID
if expected.timeout or got.timeout:
if not expected.timeout:
yield 'timeout', DataMismatch('answer', 'timeout')
if got is not None:
if not got.timeout:
yield 'timeout', DataMismatch('timeout', 'answer')
return # don't attempt to match any other fields if one answer is timeout
elif exp_malformed or got_malformed:
if exp_res == got_res:
'match: DNS replies malformed in the same way! (%s)', exp_res)
yield 'malformed', DataMismatch(exp_res, got_res)
if expected.timeout or got.timeout or exp_malformed or got_malformed:
return # don't attempt to match any other fields
for criteria in match_fields:
match_part(expected, got, criteria)
match_part(exp_msg, got_msg, criteria)
except DataMismatch as ex:
yield criteria, ex
def diff_pair(
answers: Mapping[ResolverID, dns.message.Message],
answers: Mapping[ResolverID, DNSReply],
criteria: Sequence[FieldLabel],
name1: ResolverID,
name2: ResolverID
......@@ -196,7 +215,7 @@ def diff_pair(
def transitive_equality(
answers: Mapping[ResolverID, dns.message.Message],
answers: Mapping[ResolverID, DNSReply],
criteria: Sequence[FieldLabel],
resolvers: Sequence[ResolverID]
) -> bool:
......@@ -213,7 +232,7 @@ def transitive_equality(
def compare(
answers: Mapping[ResolverID, dns.message.Message],
answers: Mapping[ResolverID, DNSReply],
criteria: Sequence[FieldLabel],
target: ResolverID
) -> Tuple[bool, Optional[Mapping[FieldLabel, DataMismatch]]]:
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment