Commit 0f5c6d5d authored by Tomas Krizek's avatar Tomas Krizek

dbhelper.DNSRepliesFactory: create factory and add tests

parent 81927fd2
import logging
import os
import struct
import sys
from typing import Any, Dict, Iterator, Optional, Tuple # noqa
from typing import Any, Dict, Iterator, Optional, Tuple, Sequence # noqa
import lmdb
......@@ -170,7 +171,8 @@ class DNSReply:
@classmethod
def from_binary(cls, buff: bytes) -> Tuple['DNSReply', bytes]:
assert len(buff) >= (cls.SIZEOF_INT + cls.SIZEOF_SHORT), "Malformed bin format"
if len(buff) < (cls.SIZEOF_INT + cls.SIZEOF_SHORT):
raise ValueError('Missing data in binary format')
offset = 0
time_int, = struct.unpack_from('<I', buff, offset)
offset += cls.SIZEOF_INT
......@@ -179,6 +181,9 @@ class DNSReply:
wire = buff[offset:(offset+length)]
offset += length
if len(wire) != length:
raise ValueError('Missing data in binary format')
if time_int == cls.TIMEOUT_INT:
time = float('+inf')
else:
......@@ -188,5 +193,21 @@ class DNSReply:
return reply, buff[offset:]
class DNSRepliesFactory:
def __init__(self, servers: Sequence[ResolverID]) -> None:
if not servers:
raise ValueError('One or more servers have to be specified')
self.servers = servers
def parse(self, buff: bytes) -> Dict[ResolverID, DNSReply]:
replies = {}
for server in self.servers:
reply, buff = DNSReply.from_binary(buff)
replies[server] = reply
if buff:
logging.warning('Trailing data in buffer')
return replies
# upon import, check we're on a little endian platform
assert sys.byteorder == 'little', 'Big endian platforms are not supported'
import pytest
from respdiff.dbhelper import DNSReply
from respdiff.dbhelper import DNSReply, DNSRepliesFactory
def create_reply(wire, time):
......@@ -81,3 +81,20 @@ def test_dns_reply_deserialization(binary, reply, remaining):
got_reply, buff = DNSReply.from_binary(binary)
assert reply == got_reply
assert buff == remaining
def test_dns_replies_factory_init():
with pytest.raises(ValueError):
DNSRepliesFactory([])
rf = DNSRepliesFactory(['a'])
replies = rf.parse(DR_TIMEOUT_BIN)
assert replies['a'] == DR_TIMEOUT
rf2 = DNSRepliesFactory(['a', 'b'])
replies = rf2.parse(DR_A_0_BIN + DR_ABCD_1_BIN)
assert replies['a'] == DR_A_0
assert replies['b'] == DR_ABCD_1
with pytest.raises(ValueError):
rf2.parse(DR_A_0_BIN + b'a')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment