Commit 81927fd2 authored by Tomas Krizek's avatar Tomas Krizek

dbhelper.DNSReply: support binary format

parent ee3d740c
import os
import struct
import sys
from typing import Any, Dict, Iterator, Optional, Tuple # noqa
import lmdb
......@@ -7,6 +8,8 @@ import lmdb
from dataformat import QID
ResolverID = str
RepliesBlob = bytes
QKey = bytes
WireFormat = bytes
......@@ -128,6 +131,10 @@ class LMDB:
class DNSReply:
TIMEOUT_INT = 4294967295
SIZEOF_INT = 4
SIZEOF_SHORT = 2
def __init__(self, wire: Optional[WireFormat], time: float = 0) -> None:
if wire is None:
self.wire = b''
......@@ -137,5 +144,49 @@ class DNSReply:
self.time = time
@property
def timeout(self):
def timeout(self) -> bool:
return self.time == float('+inf')
def __eq__(self, other) -> bool:
if self.timeout and other.timeout:
return True
return self.wire == other.wire and \
abs(self.time - other.time) < 10 ** -7
@property
def time_int(self) -> int:
if self.time == float('+inf'):
return self.TIMEOUT_INT
value = round(self.time * (10 ** 6))
if value > self.TIMEOUT_INT:
raise ValueError('Maximum time value exceeded')
return value
@property
def binary(self) -> bytes:
length = len(self.wire)
assert length < 2**(self.SIZEOF_SHORT*8), 'Maximum wire format length exceeded'
return struct.pack('<I', self.time_int) + struct.pack('<H', length) + self.wire
@classmethod
def from_binary(cls, buff: bytes) -> Tuple['DNSReply', bytes]:
assert len(buff) >= (cls.SIZEOF_INT + cls.SIZEOF_SHORT), "Malformed bin format"
offset = 0
time_int, = struct.unpack_from('<I', buff, offset)
offset += cls.SIZEOF_INT
length, = struct.unpack_from('<H', buff, offset)
offset += cls.SIZEOF_SHORT
wire = buff[offset:(offset+length)]
offset += length
if time_int == cls.TIMEOUT_INT:
time = float('+inf')
else:
time = time_int / (10 ** 6)
reply = DNSReply(wire, time)
return reply, buff[offset:]
# upon import, check we're on a little endian platform
assert sys.byteorder == 'little', 'Big endian platforms are not supported'
......@@ -12,12 +12,11 @@ from typing import ( # noqa
Union)
import cli
from dbhelper import LMDB, qid2key, key2qid, QKey, WireFormat
from dbhelper import LMDB, key2qid, ResolverID, RepliesBlob, qid2key, QKey, WireFormat
import diffsum
from dataformat import Diff, DiffReport, FieldLabel, ReproData, QID # noqa
import msgdiff
import sendrecv
from sendrecv import ResolverID, RepliesBlob
T = TypeVar('T')
......
......@@ -15,8 +15,7 @@ import cli
from dataformat import (
DataMismatch, DiffReport, Disagreements, DisagreementsCounter,
FieldLabel, MismatchValue, QID)
from dbhelper import DNSReply, LMDB, key2qid
from sendrecv import ResolverID
from dbhelper import DNSReply, LMDB, key2qid, ResolverID
lmdb = None
......
......@@ -23,14 +23,12 @@ from typing import Any, Dict, List, Mapping, Sequence, Tuple # noqa: type hints
import dns.inet
import dns.message
from dbhelper import DNSReply, QKey, WireFormat
from dbhelper import DNSReply, RepliesBlob, ResolverID, QKey, WireFormat
ResolverID = str
IP = str
Protocol = str
Port = int
RepliesBlob = bytes
IsStreamFlag = bool # Is message preceeded by RFC 1035 section 4.2.2 length?
ReinitFlag = bool
Selector = selectors.BaseSelector
......
import pytest
from respdiff.dbhelper import DNSReply
def create_reply(wire, time):
if time is not None:
return DNSReply(wire, time)
return DNSReply(wire)
def test_dns_reply_timeout():
reply = DNSReply(None)
assert reply.timeout
assert reply.time == float('+inf')
@pytest.mark.parametrize('wire1, time1, wire2, time2, equals', [
(None, None, None, None, True),
(None, None, None, 1, True),
(b'', None, b'', None, True),
(b'', None, b'', 1, False),
(b'a', None, b'a', None, True),
(b'a', None, b'b', None, False),
(b'a', None, b'aa', None, False),
])
def test_dns_reply_equals(wire1, time1, wire2, time2, equals):
r1 = create_reply(wire1, time1)
r2 = create_reply(wire2, time2)
assert (r1 == r2) == equals
@pytest.mark.parametrize('time, time_int', [
(None, 0),
(0, 0),
(1.43, 1430000),
(0.4591856, 459186),
])
def test_dns_reply_time_int(time, time_int):
reply = create_reply(b'', time)
assert reply.time_int == time_int
DR_TIMEOUT = DNSReply(None, None)
DR_TIMEOUT_BIN = b'\xff\xff\xff\xff\x00\x00'
DR_EMPTY_0 = DNSReply(b'')
DR_EMPTY_0_BIN = b'\x00\x00\x00\x00\x00\x00'
DR_EMPTY_1 = DNSReply(b'', 1)
DR_EMPTY_1_BIN = b'\x40\x42\x0f\x00\x00\x00'
DR_A_0 = DNSReply(b'a')
DR_A_0_BIN = b'\x00\x00\x00\x00\x01\x00a'
DR_A_1 = DNSReply(b'a', 1)
DR_A_1_BIN = b'\x40\x42\x0f\x00\x01\x00a'
DR_ABCD_1 = DNSReply(b'abcd', 1)
DR_ABCD_1_BIN = b'\x40\x42\x0f\x00\x04\x00abcd'
@pytest.mark.parametrize('reply, binary', [
(DR_TIMEOUT, DR_TIMEOUT_BIN),
(DR_EMPTY_0, DR_EMPTY_0_BIN),
(DR_EMPTY_1, DR_EMPTY_1_BIN),
(DR_A_0, DR_A_0_BIN),
(DR_A_1, DR_A_1_BIN),
(DR_ABCD_1, DR_ABCD_1_BIN),
])
def test_dns_reply_serialization(reply, binary):
assert reply.binary == binary
@pytest.mark.parametrize('binary, reply, remaining', [
(DR_TIMEOUT_BIN, DR_TIMEOUT, b''),
(DR_EMPTY_0_BIN, DR_EMPTY_0, b''),
(DR_EMPTY_1_BIN, DR_EMPTY_1, b''),
(DR_A_0_BIN, DR_A_0, b''),
(DR_A_1_BIN, DR_A_1, b''),
(DR_ABCD_1_BIN, DR_ABCD_1, b''),
(DR_A_1_BIN + b'a', DR_A_1, b'a'),
(DR_ABCD_1_BIN + b'bcd', DR_ABCD_1, b'bcd'),
])
def test_dns_reply_deserialization(binary, reply, remaining):
got_reply, buff = DNSReply.from_binary(binary)
assert reply == got_reply
assert buff == remaining
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment