Skip to content
Snippets Groups Projects
Commit 62117a3a authored by Marek Vavruša's avatar Marek Vavruša
Browse files

tests/integration: support range address query

parent 14c05ee2
Branches
Tags
No related merge requests found
......@@ -5,7 +5,7 @@ import dns.rcode
class Entry:
"""
Data entry represents prescripted message and extra metadata, notably match criteria and reply adjustments.
Data entry represents scripted message and extra metadata, notably match criteria and reply adjustments.
"""
# Globals
......@@ -20,7 +20,7 @@ class Entry:
self.message = dns.message.Message()
def match_part(self, code, msg):
""" Compare prescripted reply to given message using single criteria. """
""" Compare scripted reply to given message using single criteria. """
if code not in self.match_fields and 'all' not in self.match_fields:
return True
expected = self.message
......@@ -44,7 +44,7 @@ class Entry:
raise Exception('unknown match request "%s"' % code)
def match(self, msg):
""" Compare prescripted reply to given message based on match criteria. """
""" Compare scripted reply to given message based on match criteria. """
match_fields = self.match_fields
if 'all' in match_fields:
match_fields = ('flags', 'question', 'answer', 'authority', 'additional')
......@@ -59,7 +59,7 @@ class Entry:
self.match_fields = fields
def adjust_reply(self, query):
""" Copy prescripted reply and adjust to received query. """
""" Copy scripted reply and adjust to received query. """
answer = self.message
if 'copy_id' in self.adjust_fields:
answer.id = query.id
......@@ -141,19 +141,26 @@ class Entry:
class Range:
"""
Range represents a set of prescripted queries valid for given step range.
Range represents a set of scripted queries valid for given step range.
"""
def __init__(self, a, b):
""" Initialize reply range. """
self.a = a
self.b = b
self.address = None
self.stored = []
def add(self, entry):
""" Append a prescripted response to the range"""
""" Append a scripted response to the range"""
self.stored.append(entry)
def eligible(self, id, address):
""" Return true if this range is eligible for fetching reply. """
if self.a <= id <= self.b:
return None in (self.address, address) or (self.address == address)
return False
def reply(self, query):
""" Find matching response to given query. """
for candidate in self.stored:
......@@ -226,14 +233,14 @@ class Scenario:
self.steps = []
self.current_step = None
def reply(self, query):
def reply(self, query, address = None):
""" Attempt to find a range reply for a query. """
id = 0
step_id = 0
if self.current_step is not None:
id = self.current_step.id
step_id = self.current_step.id
# Find current valid query response range
for rng in self.ranges:
if id >= rng.a and id <= rng.b:
if rng.eligible(step_id, address):
return rng.reply(query)
def play(self, ctx):
......
......@@ -16,12 +16,13 @@ def send_message(stream, message):
stream.send(struct.pack('!H', len(message)) + message)
class TestServer:
""" This simulates TCP DNS server returning prescripted or mirror DNS responses. """
""" This simulates TCP DNS server returning scripted or mirror DNS responses. """
def __init__(self, scenario, type = socket.AF_UNIX, address = '.test_server.sock', port = 0):
""" Initialize server instance. """
self.is_active = False
self.thread = None
self.client_address = None
self.sock = socket.socket(type, socket.SOCK_STREAM)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if type == socket.AF_UNIX:
......@@ -46,7 +47,7 @@ class TestServer:
return False
response = dns.message.make_response(query)
if self.scenario is not None:
response = self.scenario.reply(query)
response = self.scenario.reply(query, self.client_address)
if response:
send_message(client, response)
return True
......@@ -69,7 +70,8 @@ class TestServer:
to_read, _, to_error = select.select(clients, [], clients, 0.1)
for sock in to_read:
if sock == self.sock:
clients.append(sock.accept()[0])
client_sock, _ = sock.accept()
clients.append(client_sock)
else:
if not self.handle(sock):
to_error.append(sock)
......@@ -88,8 +90,9 @@ class TestServer:
if os.path.exists(address):
os.remove(address)
def client(self):
def client(self, dst_address = None):
""" Return connected client. """
self.client_address = dst_address
sock = socket.socket(self.sock_type, socket.SOCK_STREAM)
sock.connect(self.sock.getsockname())
return sock
......
......@@ -137,14 +137,15 @@ static PyObject* set_server(PyObject *self, PyObject *args)
static PyObject* test_connect(PyObject *self, PyObject *args)
{
/* Fetch a new client */
PyObject *result = PyObject_CallMethod(mock_server, "client", "");
if (result == NULL) {
struct sockaddr_storage addr;
sockaddr_set(&addr, AF_INET, "127.0.0.1", 0);
int sock = net_connected_socket(SOCK_STREAM, &addr, NULL, 0);
if (sock < 0) {
return NULL;
}
int ret = 0;
bool test_passed = true;
int sock = dup(PyObject_AsFileDescriptor(result));
knot_pkt_t *query = NULL, *reply = NULL;
/* Send and receive a query. */
......@@ -164,7 +165,7 @@ static PyObject* test_connect(PyObject *self, PyObject *args)
}
finish:
Py_DECREF(result);
close(sock);
knot_pkt_free(&query);
knot_pkt_free(&reply);
if (test_passed) {
......@@ -248,12 +249,10 @@ int udp_send_msg(int fd, const uint8_t *msg, size_t msglen,
int net_connected_socket(int type, const struct sockaddr_storage *dst_addr,
const struct sockaddr_storage *src_addr, unsigned flags)
{
char dst_addr_str[SOCKADDR_STRLEN], src_addr_str[SOCKADDR_STRLEN];
sockaddr_tostr(dst_addr_str, sizeof(dst_addr_str), dst_addr);
sockaddr_tostr(src_addr_str, sizeof(src_addr_str), src_addr);
fprintf(stderr, "%s (%d, %s, %s, %u)\n", __func__, type, dst_addr_str, src_addr_str, flags);
char addr_str[SOCKADDR_STRLEN];
sockaddr_tostr(addr_str, sizeof(addr_str), dst_addr);
PyObject *result = PyObject_CallMethod(mock_server, "client", "");
PyObject *result = PyObject_CallMethod(mock_server, "client", "s", addr_str);
if (result == NULL) {
return -1;
}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment