Skip to content
Snippets Groups Projects
Commit 5303da69 authored by Grigorii Demidov's avatar Grigorii Demidov
Browse files

tests: raw query support

parent 94d11c7d
Branches
Tags
No related merge requests found
......@@ -2,6 +2,7 @@ import dns.message
import dns.rrset
import dns.rcode
import dns.dnssec
import binascii
class Entry:
"""
......@@ -20,6 +21,9 @@ class Entry:
self.origin = '.'
self.message = dns.message.Message()
self.sections = []
self.is_raw_data_entry = False
self.raw_data_pending = False
self.raw_data = None
def match_part(self, code, msg):
""" Compare scripted reply to given message using single criteria. """
......@@ -66,6 +70,19 @@ class Entry:
except Exception as e:
raise Exception("%s: %s" % (code, str(e)))
def cmp_raw(self, raw_value):
if self.is_raw_data_entry is False:
raise Exception("entry.cmp_raw() misuse")
expected = None
if self.raw_data is not None:
expected = binascii.hexlify(self.raw_data)
got = None
if raw_value is not None:
got = binascii.hexlify(raw_value)
if expected != got:
print "expected '",expected,"', got '",got,"'"
raise Exception("comparsion failed")
def set_match(self, fields):
""" Set conditions for message comparison [all, flags, question, answer, authority, additional] """
self.match_fields = fields
......@@ -101,6 +118,10 @@ class Entry:
self.message.ednsflags = dns.flags.edns_from_text(' '.join(eflags))
self.message.set_rcode(rcode)
def begin_raw(self):
""" Set raw data pending flag. """
self.raw_data_pending = True
def begin_section(self, section):
""" Begin packet section. """
self.section = section
......@@ -108,17 +129,28 @@ class Entry:
def add_record(self, owner, args):
""" Add record to current packet section. """
rr = self.__rr_from_str(owner, args)
if self.section == 'QUESTION':
self.message.question.append(rr)
elif self.section == 'ANSWER':
self.message.answer.append(rr)
elif self.section == 'AUTHORITY':
self.message.authority.append(rr)
elif self.section == 'ADDITIONAL':
self.message.additional.append(rr)
if self.raw_data_pending is True:
if self.raw_data == None:
if owner == 'NULL':
self.raw_data = None
else:
self.raw_data = binascii.unhexlify(owner)
else:
raise Exception('raw data already set in this entry')
self.raw_data_pending = False
self.is_raw_data_entry = True
else:
raise Exception('bad section %s' % self.section)
rr = self.__rr_from_str(owner, args)
if self.section == 'QUESTION':
self.message.question.append(rr)
elif self.section == 'ANSWER':
self.message.answer.append(rr)
elif self.section == 'AUTHORITY':
self.message.authority.append(rr)
elif self.section == 'ADDITIONAL':
self.message.additional.append(rr)
else:
raise Exception('bad section %s' % self.section)
def __rr_from_str(self, owner, args):
""" Parse RR from tokenized string. """
......@@ -217,6 +249,8 @@ class Step:
self.args = extra_args
self.data = []
self.has_data = self.type not in Step.require_data
self.answer = None
self.raw_answer = None
def add(self, entry):
""" Append a data entry to this step. """
......@@ -241,21 +275,33 @@ class Step:
""" Compare answer from previously resolved query. """
if len(self.data) == 0:
raise Exception("response definition required")
if ctx.last_answer is None:
raise Exception("no answer from preceding query")
expected = self.data[0]
expected.match(ctx.last_answer)
if expected.is_raw_data_entry is True:
expected.cmp_raw(ctx.last_raw_answer);
else:
if ctx.last_answer is None :
raise Exception("no answer from preceding query")
expected.match(ctx.last_answer)
def __query(self, ctx):
""" Resolve a query. """
if len(self.data) == 0:
raise Exception("query definition required")
msg = self.data[0].message
msg.use_edns(edns = 1)
self.answer = ctx.resolve(msg.to_wire())
if self.answer is not None:
self.answer = dns.message.from_wire(self.answer)
ctx.last_answer = self.answer
if self.data[0].is_raw_data_entry is True:
data_to_wire = self.data[0].raw_data
else:
msg = self.data[0].message
msg.use_edns(edns = 1)
data_to_wire = msg.to_wire()
self.raw_answer = ctx.resolve(data_to_wire)
ctx.last_raw_answer = self.raw_answer
if self.raw_answer is not None:
self.answer = dns.message.from_wire(self.raw_answer)
else:
self.answer = None
ctx.last_answer = self.answer
def __time_passes(self, ctx):
""" Modify system time. """
......@@ -282,18 +328,24 @@ class Scenario:
# Find current valid query response range
for rng in self.ranges:
if rng.eligible(step_id, address):
return rng.reply(query)
return (rng.reply(query), False)
# Find any prescripted one-shot replies
for step in self.steps:
if step.id <= step_id or step.type != 'REPLY':
continue
try:
candidate = step.data[0]
candidate.match(query)
step.data.remove(candidate)
return candidate.adjust_reply(query)
if candidate.is_raw_data_entry is False:
candidate.match(query)
step.data.remove(candidate)
answer = candidate.adjust_reply(query)
return (answer, False)
else:
answer = candidate.raw_data
return (answer, True)
except:
pass
return (None, True)
def play(self, ctx):
""" Play given scenario. """
......
......@@ -2,6 +2,7 @@ import threading
import select, socket, struct, sys, os, time
import dns.message
import test
import binascii
# Test debugging
TEST_DEBUG = 0
......@@ -27,7 +28,6 @@ def sendto_message(stream, message, addr):
""" Send DNS/UDP message. """
if TEST_DEBUG > 0:
syn_message("outgoing data")
message = message.to_wire()
stream.sendto(message, addr)
if TEST_DEBUG > 0:
syn_message("[Python] sent", len(message), "bytes to", addr)
......@@ -109,16 +109,20 @@ class TestServer:
syn_message("Empty query")
return False
response = dns.message.make_response(query)
is_raw_data = False
if self.scenario is not None:
if TEST_DEBUG > 0:
syn_message("get scenario reply")
response = self.scenario.reply(query, client_address)
response, is_raw_data = self.scenario.reply(query, client_address)
if response:
if TEST_DEBUG > 0:
syn_message("sending answer")
if TEST_DEBUG > 1:
syn_message("=========\n",response,"=========")
sendto_message(client, response, addr)
if is_raw_data is False:
sendto_message(client, response.to_wire(), addr)
else:
sendto_message(client, response, addr)
return True
else:
if TEST_DEBUG > 0:
......
......@@ -125,7 +125,8 @@ static PyObject* resolve(PyObject *self, PyObject *args)
const char *query_wire = NULL;
size_t query_size = 0;
if (!PyArg_ParseTuple(args, "s#", &query_wire, &query_size)) {
return NULL;
PyObject *out = Py_BuildValue("s#", NULL, 0);
return out;
}
/* Prepare input */
......@@ -133,8 +134,9 @@ static PyObject* resolve(PyObject *self, PyObject *args)
assert(query);
int ret = knot_pkt_parse(query, 0);
if (ret != KNOT_EOK) {
PyObject *out = Py_BuildValue("s#", NULL, 0);
knot_pkt_free(&query);
return NULL;
return out;
}
/* Resolve query */
......
......@@ -41,6 +41,8 @@ def parse_entry(op, args, file_in):
out.set_adjust(args)
elif op == 'SECTION':
out.begin_section(args[0])
elif op == 'RAW':
out.begin_raw()
else:
out.add_record(op, args)
return out
......@@ -141,7 +143,8 @@ def play_object(path):
try:
mock_ctx.set_server(server)
if TEST_DEBUG > 0:
print('--- server listening at %s ---' % str(server.address()))
print('--- UDP test server started at')
print(server.address())
print('--- scenario parsed, any key to continue ---')
sys.stdin.readline()
scenario.play(mock_ctx)
......@@ -149,7 +152,6 @@ def play_object(path):
server.stop()
mock_ctx.deinit()
def test_platform(*args):
if sys.platform == 'windows':
raise Exception('not supported at all on Windows')
......
This diff is collapsed.
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment