Commit 95a49bca authored by Tomas Krizek's avatar Tomas Krizek

dbhelper: move DNSReply, WireFormat

parent 93fa182b
......@@ -13,18 +13,11 @@ import dns.rrset
# replace Any with 'MismatchValue' once nested types are supported with mypy
MismatchValue = Union[str, dns.rrset.RRset, Sequence[Any]]
QID = int
WireFormat = bytes
FieldLabel = str
RestoreFunction = Optional[Callable[[Any], Any]]
SaveFunction = Optional[Callable[[Any], Any]]
class Reply:
def __init__(self, wire: Optional[WireFormat], duration: float) -> None:
self.wire = wire
self.duration = duration
class DataMismatch(Exception):
def __init__(self, exp_val: MismatchValue, got_val: MismatchValue) -> None:
def convert_val_type(val: Any) -> MismatchValue:
......
import os
import struct
from typing import Any, Dict, Iterator, Tuple # NOQA: needed for type hint in comment
from typing import Any, Dict, Iterator, Optional, Tuple # noqa
import lmdb
......@@ -8,6 +8,7 @@ from dataformat import QID
QKey = bytes
WireFormat = bytes
def qid2key(qid: QID) -> QKey:
......@@ -124,3 +125,9 @@ class LMDB:
cur = txn.cursor(db)
for key, blob in cur:
yield key, blob
class DNSReply:
def __init__(self, wire: Optional[WireFormat], duration: float) -> None:
self.wire = wire
self.duration = duration
......@@ -12,9 +12,9 @@ from typing import ( # noqa
Union)
import cli
from dbhelper import LMDB, qid2key, key2qid, QKey
from dbhelper import LMDB, qid2key, key2qid, QKey, WireFormat
import diffsum
from dataformat import Diff, DiffReport, FieldLabel, ReproData, WireFormat, QID # noqa
from dataformat import Diff, DiffReport, FieldLabel, ReproData, QID # noqa
import msgdiff
import sendrecv
from sendrecv import ResolverID, RepliesBlob
......
......@@ -13,10 +13,8 @@ import dns.rdatatype
from tabulate import tabulate
import cli
from dbhelper import LMDB, qid2key
from dataformat import (
DataMismatch, DiffReport, FieldLabel, Summary,
QID, WireFormat)
from dbhelper import LMDB, qid2key, WireFormat
from dataformat import DataMismatch, DiffReport, FieldLabel, Summary, QID
DEFAULT_LIMIT = 10
......
......@@ -12,9 +12,9 @@ from dns.rrset import RRset
import cli
from dataformat import (
DataMismatch, DiffReport, Disagreements, DisagreementsCounter, FieldLabel, MismatchValue,
Reply, QID)
from dbhelper import LMDB, key2qid
DataMismatch, DiffReport, Disagreements, DisagreementsCounter,
FieldLabel, MismatchValue, QID)
from dbhelper import DNSReply, LMDB, key2qid
from sendrecv import ResolverID
......@@ -143,7 +143,7 @@ def match(
def decode_wire_dict(
wire_dict: Mapping[ResolverID, Reply]
wire_dict: Mapping[ResolverID, DNSReply]
) -> Mapping[ResolverID, dns.message.Message]:
answers = {} # type: Dict[ResolverID, dns.message.Message]
for k, v in wire_dict.items():
......
......@@ -23,8 +23,7 @@ from typing import Any, Dict, List, Mapping, Sequence, Tuple # noqa: type hints
import dns.inet
import dns.message
from dataformat import Reply, WireFormat
from dbhelper import QKey
from dbhelper import DNSReply, QKey, WireFormat
ResolverID = str
......@@ -47,7 +46,7 @@ __ignore_timeout = False
__timeout = 10
__time_delay_min = 0
__time_delay_max = 0
__timeout_replies = {} # type: Dict[float, Reply]
__timeout_replies = {} # type: Dict[float, DNSReply]
def module_init(args: Namespace) -> None:
......@@ -140,7 +139,7 @@ def get_resolvers(
return resolvers
def _check_timeout(replies: Mapping[ResolverID, Reply]) -> None:
def _check_timeout(replies: Mapping[ResolverID, DNSReply]) -> None:
for resolver, reply in replies.items():
timeouts = __worker_state.timeouts
if reply.wire is not None:
......@@ -202,11 +201,11 @@ def send_recv_parallel(
selector: Selector,
sockets: ResolverSockets,
timeout: float
) -> Tuple[Mapping[ResolverID, Reply], ReinitFlag]:
replies = {} # type: Dict[ResolverID, Reply]
) -> Tuple[Mapping[ResolverID, DNSReply], ReinitFlag]:
replies = {} # type: Dict[ResolverID, DNSReply]
streammsg = None
# optimization: create only one timeout_reply object per timeout value
timeout_reply = __timeout_replies.setdefault(timeout, Reply(None, timeout))
timeout_reply = __timeout_replies.setdefault(timeout, DNSReply(None, timeout))
start_time = time.perf_counter()
end_time = start_time + timeout
for _, sock, isstream in sockets:
......@@ -238,7 +237,7 @@ def send_recv_parallel(
# assert len(wire) > 14
if dgram[0:2] != wire[0:2]:
continue # wrong msgid, this might be a delayed answer - ignore it
replies[name] = Reply(wire, time.perf_counter() - start_time)
replies[name] = DNSReply(wire, time.perf_counter() - start_time)
# set missing replies as timeout
for resolver, *_ in sockets: # type: ignore # python/mypy#465
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment