scenario.py 32.9 KB
Newer Older
Marek Vavruša's avatar
Marek Vavruša committed
1 2 3
import dns.rrset
import dns.rcode
import dns.dnssec
4
import dns.tsigkeyring
Marek Vavruša's avatar
Marek Vavruša committed
5
import binascii
6
import socket, struct
7
import os, sys, errno
8
import itertools, random, string
Marek Vavruša's avatar
Marek Vavruša committed
9 10
import time
from datetime import datetime
11
from dprint import dprint
12
from testserver import recvfrom_msg, sendto_msg
Marek Vavruša's avatar
Marek Vavruša committed
13

14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
# If PCAP is pointed to a file, queries/responses from the test are captured
g_pcap = None
if 'PCAP' in os.environ:
    import atexit
    import dpkt
    g_pcap = dpkt.pcap.Writer(open(os.environ['PCAP'], 'wb'))
    # The writer owns the open file; close it at exit so buffered packets
    # are flushed to disk instead of relying on interpreter finalization.
    atexit.register(g_pcap.close)
def log_packet(sock, buf, query = True):
    """
    Fake underlying layers and store packet in a pcap.

    Does nothing unless PCAP capture is enabled (g_pcap is set). The peer
    side is always recorded with port 53; *query* selects the direction
    (True: local socket -> peer, False: peer -> local socket). UDP framing
    is synthesised regardless of the socket type.
    """
    if not g_pcap:
        return
    local = sock.getsockname()
    remote = (sock.getpeername()[0], 53)
    src, dst = (local, remote) if query else (remote, local)
    # Synthesise UDP/IP/Eth layers
    transport = dpkt.udp.UDP(data = buf, dport = dst[1], sport = src[1])
    transport.ulen = len(transport)
    src_ip = socket.inet_pton(sock.family, src[0])
    dst_ip = socket.inet_pton(sock.family, dst[0])
    if sock.family == socket.AF_INET6:
        # dpkt.ip.IP has 4-byte address fields and would silently truncate
        # IPv6 addresses; build a plain 40-byte IPv6 header instead
        # (version 6, payload length, next header UDP, hop limit 64).
        payload = transport.pack()
        ip = struct.pack('!IHBB16s16s', 0x60000000, len(payload),
                         dpkt.ip.IP_PROTO_UDP, 64, src_ip, dst_ip) + payload
        eth = dpkt.ethernet.Ethernet(data = ip, type = dpkt.ethernet.ETH_TYPE_IP6)
    else:
        ip = dpkt.ip.IP(src = src_ip, dst = dst_ip, p = dpkt.ip.IP_PROTO_UDP)
        ip.data = transport
        ip.len = len(ip)
        eth = dpkt.ethernet.Ethernet(data = ip)
    g_pcap.writepkt(eth.pack())

36 37 38 39
# Global statistics, updated in Step.__query: summed round-trip time in
# milliseconds and the number of queries that sum covers.
g_rtt, g_nqueries = 0.0, 0

40 41 42 43
#
# Element comparators
#

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
def create_rr(owner, args, ttl = 3600, rdclass = 'IN', origin = '.'):
    """ Parse RR from tokenized string. """
    if not owner.endswith('.'):
        owner += origin
    try:
        ttl = dns.ttl.from_text(args[0])
        args.pop(0)
    except:
        pass  # optional
    try:
        rdclass = dns.rdataclass.from_text(args[0])
        args.pop(0)
    except:
        pass  # optional
    rdtype = args.pop(0)
    rr = dns.rrset.from_text(owner, ttl, rdclass, rdtype)
    if len(args) > 0:
        if (rr.rdtype == dns.rdatatype.DS):
            # convert textual algorithm identifier to number
            args[1] = str(dns.dnssec.algorithm_from_text(args[1]))
        rd = dns.rdata.from_text(rr.rdclass, rr.rdtype, ' '.join(args), origin=dns.name.from_text(origin), relativize=False)
        rr.add(rd)
    return rr

68 69 70 71 72 73 74 75
def compare_rrs(expected, got):
    """
    Check that two lists of RR sets hold the same records.

    Raises an Exception for the first expected record missing from *got*,
    otherwise for the first record of *got* that was not expected, otherwise
    when the list lengths differ (only possible with duplicates).
    Returns True when the lists match.
    """
    checks = ((expected, got, "expected record '%s'"),
              (got, expected, "unexpected record '%s'"))
    for records, others, template in checks:
        for record in records:
            if record not in others:
                raise Exception(template % record.to_text())
    n_expected, n_got = len(expected), len(got)
    if n_expected != n_got:
        raise Exception("expected %s records but got %s records "
                        "(a duplicate RR somewhere?)"
                        % (n_expected, n_got))
    return True

Petr Špaček's avatar
Petr Špaček committed
82

83 84 85 86 87 88 89 90 91 92 93 94
def compare_val(expected, got):
    """ Return True when *got* equals *expected*, otherwise raise an Exception describing both. """
    mismatch = expected != got
    if mismatch:
        raise Exception("expected '%s', got '%s'" % (expected, got))
    return True

def compare_sub(got, expected):
    """
    Check that *expected* is a subdomain of *got*, throw exception if not.

    The parameter names are historical: the matcher passes the scripted
    (parent) name as *got* and the received query name as *expected*.
    The error message therefore names *got* as the required parent and
    *expected* as the offending name.
    """
    if not expected.is_subdomain(got):
        raise Exception("expected subdomain of '%s', got '%s'" % (got, expected))
    return True

95
def replay_rrs(rrs, nqueries, destination, args = None):
    """
    Replay list of queries and report statistics.

    :param rrs: RR sets whose name/type/class become the query questions
    :param nqueries: number of queries to send; prepared queries are cycled
    :param destination: (address, port) of the server to query over UDP
    :param args: options; 'RAND' prefixes each name with a random 8-char
                 label (preparing *nqueries* distinct queries), 'DO' sets the
                 DNSSEC OK bit
    :return: (nsent, nrcvd) tuple; (0, 0) when there is nothing to send
    """
    import select
    if args is None:
        args = []
    if not rrs:
        # Nothing to replay; also avoids modulo by zero below.
        return 0, 0
    navail, queries = len(rrs), []
    chunksize = 16  # maximum number of queries in flight
    for i in range(nqueries if 'RAND' in args else navail):
        rr = rrs[i % navail]
        name = rr.name
        if 'RAND' in args:
            prefix = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(8)])
            name = prefix + '.' + rr.name.to_text()
        msg = dns.message.make_query(name, rr.rdtype, rr.rdclass)
        if 'DO' in args:
            msg.want_dnssec(True)
        queries.append(msg.to_wire())
    if not queries:
        return 0, 0
    # Make a UDP connected socket to the destination
    family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
    sock = socket.socket(family, socket.SOCK_DGRAM)
    try:
        sock.connect(destination)
        sock.setblocking(False)
        # Play the query set
        # @NOTE: this is only good for relative low-speed replay
        rcvbuf = bytearray(512)
        nsent, nrcvd, nwait, navail = 0, 0, 0, len(queries)
        fdset = [sock]
        while nsent - nwait < nqueries:
            to_read, to_write, _ = select.select(fdset, fdset if nwait < chunksize else [], [], 0.5)
            if to_write:
                try:
                    while nsent < nqueries and nwait < chunksize:
                        sock.send(queries[nsent % navail])
                        nwait += 1
                        nsent += 1
                except socket.error:
                    pass  # EAGAIN/EINVAL: retry after the next select()
            if to_read:
                try:
                    while nwait > 0:
                        sock.recv_into(rcvbuf)
                        nwait -= 1
                        nrcvd += 1
                except socket.error:
                    pass  # EAGAIN: no more responses pending right now
            if not to_write and not to_read:
                nwait = 0  # Timeout, started dropping packets
                break
    finally:
        sock.close()
    return nsent, nrcvd

Marek Vavruša's avatar
Marek Vavruša committed
144 145 146 147 148 149 150 151 152 153
class Entry:
    """
    Data entry represents scripted message and extra metadata, notably match criteria and reply adjustments.
    """

    # Globals
    default_ttl = 3600
    default_cls = 'IN'
    default_rc = 'NOERROR'

    def __init__(self, lineno = 0):
        """ Initialize data entry; *lineno* is the scenario line reported in match errors. """
        self.match_fields = ['opcode', 'qtype', 'qname']
        self.adjust_fields = ['copy_id']
        self.origin = '.'
        self.message = dns.message.Message()
        self.message.use_edns(edns = 0, payload = 4096)
        self.section = None  # set by begin_section(); add_record() rejects None
        self.sections = []
        self.is_raw_data_entry = False
        self.raw_data_pending = False
        self.raw_data = None
        self.lineno = lineno
        self.mandatory = False
        self.fired = 0  # incremented by Range.reply() each time this entry answers

    def match_part(self, code, msg):
        """
        Compare scripted reply to given message using single criteria.

        Criteria not listed in match_fields (and without 'all') always match.
        Returns a true value or None on a match; raises an Exception otherwise.
        """
        if code not in self.match_fields and 'all' not in self.match_fields:
            return True
        expected = self.message
        if code == 'opcode':
            return compare_val(expected.opcode(), msg.opcode())
        elif code == 'qtype':
            if len(expected.question) == 0:
                return True
            return compare_val(expected.question[0].rdtype, msg.question[0].rdtype)
        elif code == 'qname':
            if len(expected.question) == 0:
                return True
            # Normalize the received name to lowercase before comparing
            qname = dns.name.from_text(msg.question[0].name.to_text().lower())
            return compare_val(expected.question[0].name, qname)
        elif code == 'subdomain':
            if len(expected.question) == 0:
                return True
            qname = dns.name.from_text(msg.question[0].name.to_text().lower())
            return compare_sub(expected.question[0].name, qname)
        elif code == 'flags':
            return compare_val(dns.flags.to_text(expected.flags), dns.flags.to_text(msg.flags))
        elif code == 'rcode':
            return compare_val(dns.rcode.to_text(expected.rcode()), dns.rcode.to_text(msg.rcode()))
        elif code == 'question':
            return compare_rrs(expected.question, msg.question)
        elif code == 'answer' or code == 'ttl':
            return compare_rrs(expected.answer, msg.answer)
        elif code == 'authority':
            return compare_rrs(expected.authority, msg.authority)
        elif code == 'additional':
            return compare_rrs(expected.additional, msg.additional)
        elif code == 'edns':
            if msg.edns != expected.edns:
                raise Exception('expected EDNS %d, got %d' % (expected.edns, msg.edns))
            if msg.payload != expected.payload:
                raise Exception('expected EDNS bufsize %d, got %d' % (expected.payload, msg.payload))
        elif code == 'nsid':
            nsid_opt = None
            for opt in expected.options:
                if opt.otype == dns.edns.NSID:
                    nsid_opt = opt
                    break
            # Find matching NSID
            for opt in msg.options:
                if opt.otype == dns.edns.NSID:
                    if nsid_opt is None:
                        raise Exception('unexpected NSID value "%s"' % opt.data)
                    if opt == nsid_opt:
                        return True
                    else:
                        raise Exception('expected NSID "%s", got "%s"' % (nsid_opt.data, opt.data))
            if nsid_opt:
                raise Exception('expected NSID "%s"' % nsid_opt.data)
        else:
            raise Exception('unknown match request "%s"' % code)

    def match(self, msg):
        """
        Compare scripted reply to given message based on match criteria.

        'all' expands to flags, rcode and every section seen in the entry.
        Raises an Exception naming the entry line and the failing criterion.
        """
        # Expand on a copy: self.match_fields (possibly a list handed to
        # set_match by the caller) must not be rewritten by matching.
        match_fields = list(self.match_fields)
        if 'all' in match_fields:
            match_fields.remove('all')
            match_fields += ['flags'] + ['rcode'] + self.sections
        for code in match_fields:
            try:
                self.match_part(code, msg)
            except Exception as e:
                errstr = '%s in the response:\n%s' % (str(e), msg.to_text())
                raise Exception("line %d, \"%s\": %s" % (self.lineno, code, errstr))

    def cmp_raw(self, raw_value):
        """
        Compare *raw_value* (wire data or None) with this entry's raw data.

        Raises an Exception if this is not a raw-data entry or the data differ.
        """
        if self.is_raw_data_entry is False:
            raise Exception("entry.cmp_raw() misuse")
        expected = None
        if self.raw_data is not None:
            expected = binascii.hexlify(self.raw_data)
        got = None
        if raw_value is not None:
            got = binascii.hexlify(raw_value)
        if expected != got:
            # A single formatted argument prints the same on Python 2 (where
            # print(a, b, ...) would print a tuple) and Python 3.
            print("expected '%s', got '%s'" % (expected, got))
            raise Exception("comparsion failed")

    def set_match(self, fields):
        """ Set conditions for message comparison [all, opcode, qtype, qname, subdomain, flags, rcode, question, answer, ttl, authority, additional, edns, nsid] """
        self.match_fields = fields

    def adjust_reply(self, query):
        """ Copy scripted reply and adjust to received query. """
        answer = dns.message.from_wire(self.message.to_wire(),
                                       xfr=self.message.xfr,
                                       one_rr_per_rrset=True)
        answer.use_edns(query.edns, query.ednsflags, options=self.message.options)
        if 'copy_id' in self.adjust_fields:
            answer.id = query.id
            # Copy letter-case if the template has QD
            if len(answer.question) > 0:
                answer.question[0].name = query.question[0].name
        if 'copy_query' in self.adjust_fields:
            answer.question = query.question
        # Re-set, as the EDNS might have reset the ext-rcode
        answer.set_rcode(self.message.rcode())

        # sanity check: adjusted answer should be almost the same
        assert len(answer.answer) == len(self.message.answer)
        assert len(answer.authority) == len(self.message.authority)
        assert len(answer.additional) == len(self.message.additional)
        return answer

    def set_adjust(self, fields):
        """ Set reply adjustment fields [copy_id, copy_query] """
        self.adjust_fields = fields

    def set_reply(self, fields):
        """ Set reply flags and rcode; 'DO' sets the DNSSEC OK bit, unknown rcodes are taken as flags. """
        eflags = []
        flags = []
        rcode = dns.rcode.from_text(self.default_rc)
        for code in fields:
            if code == 'DO':
                eflags.append(code)
                continue
            try:
                rcode = dns.rcode.from_text(code)
            except Exception:
                flags.append(code)
        self.message.flags = dns.flags.from_text(' '.join(flags))
        self.message.want_dnssec('DO' in eflags)
        self.message.set_rcode(rcode)

    @staticmethod
    def _ecs_payload(value):
        """ Build the EDNS Client Subnet option payload for an ``addr[/prefix]`` string. """
        net = value.split('/')
        family = socket.AF_INET6 if ':' in net[0] else socket.AF_INET
        addr = socket.inet_pton(family, net[0])
        prefix = len(addr) * 8
        if len(net) > 1:
            prefix = int(net[1])
        # Keep only the bytes covered by the prefix ('//' is integer division
        # on Python 2 and 3 alike) and zero the host bits of the last byte.
        addr = bytearray(addr[0 : (prefix + 7) // 8])
        if prefix % 8 != 0:
            addr[-1] &= (0xFF << (8 - prefix % 8)) & 0xFF
        header = struct.pack("!HBB", 1 if family == socket.AF_INET else 2, prefix, 0)
        return header + bytes(addr)

    def set_edns(self, fields):
        """
        Set EDNS version, bufsize and options.

        Tokens: ``[VERSION] [BUFSIZE] [NSID[=value]] [SUBNET=addr[/prefix]]``.
        Leading numeric tokens are consumed from *fields*; a bufsize of 0
        disables EDNS.
        """
        version = 0
        bufsize = 4096
        if len(fields) > 0 and fields[0].isdigit():
            version = int(fields.pop(0))
        if len(fields) > 0 and fields[0].isdigit():
            bufsize = int(fields.pop(0))
        if bufsize == 0:
            self.message.use_edns(False)
            return
        opts = []
        for v in fields:
            k, v = tuple(v.split('=')) if '=' in v else (v, True)
            if k.lower() == 'nsid':
                opts.append(dns.edns.GenericOption(dns.edns.NSID, '' if v is True else v))
            if k.lower() == 'subnet':
                # EDNS option code 8 is Client Subnet
                opts.append(dns.edns.GenericOption(8, self._ecs_payload(v)))
        self.message.use_edns(edns = version, payload = bufsize, options = opts)

    def begin_raw(self):
        """ Set raw data pending flag. """
        self.raw_data_pending = True

    def begin_section(self, section):
        """ Begin packet section. """
        self.section = section
        self.sections.append(section.lower())

    def add_record(self, owner, args):
        """
        Add record to current packet section.

        After begin_raw(), the next record line is taken as raw wire data
        instead: *owner* is a hex string, or 'NULL' for no data.
        """
        if self.raw_data_pending is True:
            if self.raw_data is None:
                if owner == 'NULL':
                    self.raw_data = None
                else:
                    self.raw_data = binascii.unhexlify(owner)
            else:
                raise Exception('raw data already set in this entry')
            self.raw_data_pending = False
            self.is_raw_data_entry = True
        else:
            rr = create_rr(owner, args, ttl = self.default_ttl, rdclass = self.default_cls, origin = self.origin)
            if self.section == 'QUESTION':
                if rr.rdtype == dns.rdatatype.AXFR:
                    self.message.xfr = True
                self.__rr_add(self.message.question, rr)
            elif self.section == 'ANSWER':
                self.__rr_add(self.message.answer, rr)
            elif self.section == 'AUTHORITY':
                self.__rr_add(self.message.authority, rr)
            elif self.section == 'ADDITIONAL':
                self.__rr_add(self.message.additional, rr)
            else:
                raise Exception('bad section %s' % self.section)

    def use_tsig(self, fields):
        """ Sign the message with TSIG; *fields* holds the key name and its secret. """
        tsig_keyname = fields[0]
        tsig_secret = fields[1]
        keyring = dns.tsigkeyring.from_text({tsig_keyname: tsig_secret})
        self.message.use_tsig(keyring=keyring, keyname=tsig_keyname)

    def __rr_add(self, section, rr):
        """ Append the RR set to the given section as-is (no merging with existing RR sets). """
        section.append(rr)

    def set_mandatory(self):
        """ Mark this entry as mandatory. """
        self.mandatory = True

Marek Vavruša's avatar
Marek Vavruša committed
379 380 381 382 383 384 385 386 387
class Range:
    """
    Range represents a set of scripted queries valid for given step range.

    Stored entries may answer while the current step id lies in [a, b],
    optionally only for the server addresses listed in self.addresses.
    """

    def __init__(self, a, b):
        """ Initialize reply range covering step ids a..b inclusive. """
        self.a, self.b = a, b
        self.addresses = set()  # empty: the range serves every address
        self.stored = []        # scripted entries, tried in insertion order
        self.args = {}          # range options, e.g. 'LOSS' drop probability
        self.received = 0
        self.sent = 0

    def __del__(self):
        header = '[ RANGE %d-%d ] %s' % (self.a, self.b, self.addresses)
        dprint(header, 'received: %d sent: %d' % (self.received, self.sent))

    def add(self, entry):
        """ Append a scripted response to the range"""
        self.stored.append(entry)

    def eligible(self, id, address):
        """ Return True if this range may answer at step *id* for *address*. """
        if not (self.a <= id <= self.b):
            return False
        return (address is None
                or self.addresses == set()
                or address in self.addresses)

    def reply(self, query):
        """
        Get answer for given query (adjusted if needed).

        Returns:
            (DNS message object) or None if there is no candidate in this range
            or the answer was dropped by the simulated loss
        """
        self.received += 1
        for candidate in self.stored:
            try:
                candidate.match(query)
                resp = candidate.adjust_reply(query)
                # Probabilistic loss
                if 'LOSS' in self.args and random.random() < float(self.args['LOSS']):
                    return None
                self.sent += 1
                candidate.fired += 1
                return resp
            except Exception:
                # Candidate does not fit this query; try the next one.
                continue
        return None


class Step:
    """
    Step represents one scripted action in a given moment,
    each step has an order identifier, type and optionally data entry.
    """

    require_data = ['QUERY', 'CHECK_ANSWER', 'REPLY']

    def __init__(self, id, type, extra_args):
        """
        Initialize single scenario step.

        For CHECK_ANSWER steps, REPEAT=<int>, PAUSE=<float> and NEXT=<int>
        arguments set repeat_if_fail, pause_if_fail and next_if_fail.
        """
        self.id = int(id)
        self.type = type
        self.args = extra_args
        self.data = []
        self.queries = []
        self.has_data = self.type in Step.require_data
        self.answer = None
        self.raw_answer = None
        self.repeat_if_fail = 0
        self.pause_if_fail = 0
        self.next_if_fail = -1

        if type == 'CHECK_ANSWER':
            for arg in extra_args:
                param = arg.split('=')
                try:
                    if param[0] == 'REPEAT':
                        self.repeat_if_fail = int(param[1])
                    elif param[0] == 'PAUSE':
                        self.pause_if_fail = float(param[1])
                    elif param[0] == 'NEXT':
                        self.next_if_fail = int(param[1])
                except Exception as e:
                    raise Exception('step %d - wrong %s arg: %s' % (self.id, param[0], str(e)))

    def add(self, entry):
        """ Append a data entry to this step. """
        self.data.append(entry)

    def play(self, ctx):
        """ Play one step from a scenario. """
        dtag = '[ STEP %03d ] %s' % (self.id, self.type)
        if self.type == 'QUERY':
            dprint(dtag, self.data[0].message.to_text())
            # Parse QUERY-specific parameters
            choice, tcp, source = None, False, None
            for v in self.args:
                if '=' in v: # Key=Value
                    v = v.split('=')
                    if v[0].lower() == 'source':
                        source = v[1]
                elif v.lower() == 'tcp':
                    tcp = True
                else:
                    choice = v
            return self.__query(ctx, tcp = tcp, choice = choice, source = source)
        elif self.type == 'CHECK_OUT_QUERY':
            dprint(dtag, '')
            pass # Ignore
        elif self.type == 'CHECK_ANSWER' or self.type == 'ANSWER':
            dprint(dtag, '')
            return self.__check_answer(ctx)
        elif self.type == 'TIME_PASSES':
            dprint(dtag, '')
            return self.__time_passes(ctx)
        elif self.type == 'REPLY' or self.type == 'MOCK':
            dprint(dtag, '')
            pass
        elif self.type == 'LOG':
            if not ctx.log:
                raise Exception('scenario has no log interface')
            return ctx.log.match(self.args)
        elif self.type == 'REPLAY':
            self.__replay(ctx)
        elif self.type == 'ASSERT':
            self.__assert(ctx)
        else:
            raise Exception('step %03d type %s unsupported' % (self.id, self.type))

    def __check_answer(self, ctx):
        """ Compare answer from previously resolved query. """
        if len(self.data) == 0:
            raise Exception("response definition required")
        expected = self.data[0]
        if expected.is_raw_data_entry is True:
            # The raw answer is the wire data stored by __query (or None);
            # it has no to_text(), so print it hex-encoded like cmp_raw().
            raw = ctx.last_raw_answer
            dprint("", binascii.hexlify(raw) if raw is not None else 'None')
            expected.cmp_raw(ctx.last_raw_answer)
        else:
            if ctx.last_answer is None:
                raise Exception("no answer from preceding query")
            dprint("", ctx.last_answer.to_text())
            expected.match(ctx.last_answer)

    def __replay(self, ctx, chunksize = 8):
        """
        Replay this step's queries to the first client and report statistics.

        A leading numeric argument overrides the query count; the INTENSIFY
        environment variable multiplies it; PRINT[=tag] prints a summary.
        *chunksize* is accepted for compatibility but not used.
        """
        dtag = '[ STEP %03d ] %s' % (self.id, self.type)
        nqueries = len(self.queries)
        if len(self.args) > 0 and self.args[0].isdigit():
            nqueries = int(self.args.pop(0))
        destination = ctx.client[list(ctx.client.keys())[0]]
        dprint(dtag, 'replaying %d queries to %s@%d (%s)' % (nqueries, destination[0], destination[1], ' '.join(self.args)))
        if 'INTENSIFY' in os.environ:
            nqueries *= int(os.environ['INTENSIFY'])
        tstart = datetime.now()
        nsent, nrcvd = replay_rrs(self.queries, nqueries, destination, self.args)
        # Keep/print the statistics
        rtt = (datetime.now() - tstart).total_seconds() * 1000
        # Guard against a zero elapsed time (e.g. nothing to replay)
        pps = 1000 * nrcvd / rtt if rtt > 0 else 0
        dprint(dtag, 'sent: %d, received: %d (%d ms, %d p/s)' % (nsent, nrcvd, rtt, pps))
        tag = None
        for arg in self.args:
            if arg.upper().startswith('PRINT'):
                _, tag = tuple(arg.split('=')) if '=' in arg else (None, 'replay')
        if tag:
            print('  [ REPLAY ] test: %s pps: %5d time: %4d sent: %5d received: %5d' % (tag.ljust(11), pps, rtt, nsent, nrcvd))

    @staticmethod
    def _retry_on_enobufs(func, *args):
        """ Call func(*args), sleeping and retrying only while it fails with ENOBUFS. """
        while True:
            try:
                return func(*args)
            except OSError as e:
                # Any other error (including socket timeouts on Python 3)
                # must propagate instead of retrying forever.
                if e.errno != errno.ENOBUFS:
                    raise
                time.sleep(0.1)  # throttle sending

    def __query(self, ctx, tcp = False, choice = None, source = None):
        """
        Send query and wait for an answer (if the query is not RAW).

        The received answer is stored in self.answer and ctx.last_answer,
        its wire form in self.raw_answer and ctx.last_raw_answer.
        """
        global g_rtt, g_nqueries
        if len(self.data) == 0:
            raise Exception("query definition required")
        if self.data[0].is_raw_data_entry is True:
            data_to_wire = self.data[0].raw_data
        else:
            # Don't use a message copy as the EDNS data portion is not copied.
            data_to_wire = self.data[0].message.to_wire()
        if choice is None or len(choice) == 0:
            choice = list(ctx.client.keys())[0]
        if choice not in ctx.client:
            raise Exception('step %03d invalid QUERY target: %s' % (self.id, choice))
        # Create socket to test subject
        destination = ctx.client[choice]
        family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
        sock = socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if tcp:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
            sock.settimeout(3)
            if source:
                sock.bind((source, 0))
            sock.connect(destination)
            # Send query to client and wait for response
            tstart = datetime.now()
            log_packet(sock, data_to_wire, query = True)
            self._retry_on_enobufs(sendto_msg, sock, data_to_wire)
            # Wait for a response for a reasonable time
            answer = None
            if not self.data[0].is_raw_data_entry:
                answer, _ = self._retry_on_enobufs(recvfrom_msg, sock, True)
            # Track RTT
            rtt = (datetime.now() - tstart).total_seconds() * 1000
            g_nqueries += 1
            g_rtt += rtt
            # Remember last answer for checking later
            self.raw_answer = answer
            ctx.last_raw_answer = answer
            if self.raw_answer is not None:
                self.answer = dns.message.from_wire(self.raw_answer, one_rr_per_rrset=True)
                log_packet(sock, answer, query=False)
            else:
                self.answer = None
            ctx.last_answer = self.answer
        finally:
            sock.close()

    def __time_passes(self, ctx):
        """ Modify system time by adding self.args[1] seconds to the faketime timestamp file. """
        path = os.environ["FAKETIME_TIMESTAMP_FILE"]
        with open(path, 'r') as time_file:
            line = time_file.readline().strip()
        t = time.mktime(datetime.strptime(line, '@%Y-%m-%d %H:%M:%S').timetuple())
        t += int(self.args[1])
        with open(path, 'w') as time_file:
            time_file.write(datetime.fromtimestamp(t).strftime('@%Y-%m-%d %H:%M:%S') + "\n")
            time_file.flush()

    def __assert(self, ctx):
        """
        Assert that a passed expression evaluates to True.

        NOTE: the step arguments are eval()'d with SCENARIO and RANGE in
        scope, so scenario files must come from a trusted source.
        """
        result = eval(' '.join(self.args), {'SCENARIO': ctx, 'RANGE': ctx.ranges})
        # Evaluate subexpressions for clarity
        subexpr = []
        for expr in self.args:
            try:
                ee = eval(expr, {'SCENARIO': ctx, 'RANGE': ctx.ranges})
                subexpr.append(str(ee))
            except Exception:
                subexpr.append(expr)
        if result is not True:
            # Explicit raise instead of 'assert' so the check survives 'python -O'
            raise AssertionError('"%s" assertion fails (%s)' % (' '.join(self.args), ' '.join(subexpr)))

Marek Vavruša's avatar
Marek Vavruša committed
641
class Scenario:
    """ Parsed test scenario: answer ranges plus the ordered steps to play. """

    def __init__(self, info, filename = ''):
        """ Initialize scenario with description and source file name. """
        self.info = info
        self.file = filename
        self.ranges = []
        self.current_range = None
        self.steps = []
        self.current_step = None
        self.client = {}
        self.force_ipv6 = False

    def reply(self, query, address = None):
        """
        Generate answer packet for given query.

        The answer can be DNS message object or a binary blob.
        Ranges eligible for the current step id (and address) are tried
        first. Then the first entry of each REPLY step whose id is not lower
        than the current step id is tried; a matching non-raw entry is
        consumed (one-shot), a raw entry is returned without being removed.

        Returns:
            (answer, boolean "is the answer binary blob?")
            (None, True) when nothing matches.
        """
        current_step_id = 0
        if self.current_step is not None:
            current_step_id = self.current_step.id
        # Unknown address, select any match
        # TODO: workaround until the server supports stub zones
        all_addresses = set()
        for rng in self.ranges:
            all_addresses.update(rng.addresses)
        if address not in all_addresses:
            address = None
        # Find current valid query response range
        for rng in self.ranges:
            if rng.eligible(current_step_id, address):
                self.current_range = rng
                return (rng.reply(query), False)
        # Find any prescripted one-shot replies
        for step in self.steps:
            if step.id < current_step_id or step.type != 'REPLY':
                continue
            try:
                candidate = step.data[0]
                if candidate.is_raw_data_entry is False:
                    candidate.match(query)
                    step.data.remove(candidate)
                    answer = candidate.adjust_reply(query)
                    return (answer, False)
                else:
                    answer = candidate.raw_data
                    return (answer, True)
            except Exception:
                # Empty step or non-matching entry: try the next step.
                # (Was a bare except, which also swallowed KeyboardInterrupt.)
                pass
        return (None, True)

    def play(self, family, paddr):
        """
        Play given scenario.

        Steps run in order. When a step raises and its repeat_if_fail counter
        is positive, the counter is decremented, the player sleeps
        pause_if_fail seconds (if positive) and retries: either the same step
        or, when next_if_fail is not -1, the step whose id equals
        next_if_fail. Otherwise the error is re-raised prefixed with the
        scenario file name and step id. Finally, every MANDATORY entry stored
        in any range must have fired at least once.

        `family` is not used here; `paddr` is stored as self.client.
        """
        # Store test subject => address mapping
        self.client = paddr

        i = 0
        while i < len(self.steps):
            step = self.steps[i]
            self.current_step = step
            try:
                step.play(self)
            except Exception as e:
                if step.repeat_if_fail > 0:
                    dprint('[play]', "step %d: exception catched - '%s', retrying step %d (%d left)" % (step.id, e, step.next_if_fail, step.repeat_if_fail))
                    step.repeat_if_fail -= 1
                    if step.pause_if_fail > 0:
                        time.sleep(step.pause_if_fail)
                    if step.next_if_fail != -1:
                        next_steps = [j for j, s in enumerate(self.steps) if s.id == step.next_if_fail]
                        if not next_steps:
                            raise Exception('step %d: wrong NEXT value "%d"' % (step.id, step.next_if_fail))
                        # Indices come from enumerate(), so they are always valid.
                        i = next_steps[0]
                    continue
                raise Exception('%s step %d %s' % (self.file, step.id, str(e)))
            i += 1

        for rng in self.ranges:
            for entry in rng.stored:
                if entry.mandatory is True and entry.fired == 0:
                    raise Exception('Mandatory section at line %d is not fired' % entry.lineno)

731

732
def get_next(file_in, skip_empty = True):
    """
    Return the next token line from the input stream as (op, args).

    A ';' outside double quotes starts a comment that runs to the end of
    the line. A backslash escapes the following double quote, so \\" neither
    opens nor closes a quoted span. The remaining text is split on
    whitespace; the first token is `op`, the rest are `args`.

    Returns False at end of input. Blank or comment-only lines are skipped
    when skip_empty is true; otherwise they yield ('', []).
    """
    while True:
        line = file_in.readline()
        if not line:
            return False
        quoted, escaped = False, False
        for i, char in enumerate(line):
            if char == '\\':
                escaped = not escaped
            if not escaped and char == '"':
                quoted = not quoted
            if char == ';' and not quoted:
                line = line[:i]
                break
            if char != '\\':
                escaped = False
        tokens = line.split()
        if not tokens:
            if skip_empty:
                continue
            return '', []
        return tokens[0], tokens[1:]

758
def parse_entry(op, args, file_in, in_entry = False):
    """
    Parse one entry definition from file_in and return it as an Entry.

    Lines are read until ENTRY_END or end of input. While in_entry is false,
    a blank (or comment-only) line also ends the entry. Once ENTRY_BEGIN has
    been seen, or when the caller passes in_entry=True, blank lines are
    skipped instead. Keywords that are not recognised are added as records.
    The `op` and `args` parameters are not used.

    Raises Exception on a nested ENTRY_BEGIN.
    """
    out = Entry(file_in.lineno())
    while True:
        # in_entry doubles as get_next()'s skip_empty flag and may flip to
        # True below, so it must be re-read for every line.
        token = get_next(file_in, in_entry)
        if token is False:
            break
        keyword, params = token
        if keyword in ('ENTRY_END', ''):
            break
        if keyword == 'ENTRY_BEGIN':  # Optional, compatibility with Unbound tests
            if in_entry:
                raise Exception('nested ENTRY_BEGIN not supported')
            in_entry = True
        elif keyword == 'EDNS':
            out.set_edns(params)
        elif keyword in ('REPLY', 'FLAGS'):
            out.set_reply(params)
        elif keyword == 'MATCH':
            out.set_match(params)
        elif keyword == 'ADJUST':
            out.set_adjust(params)
        elif keyword == 'SECTION':
            out.begin_section(params[0])
        elif keyword == 'RAW':
            out.begin_raw()
        elif keyword == 'TSIG':
            out.use_tsig(params)
        elif keyword == 'MANDATORY':
            out.set_mandatory()
        else:
            out.add_record(keyword, params)
    return out

790 791 792 793 794 795 796 797 798
def parse_queries(out, file_in):
    """
    Read query records into out.queries until a blank line or end of input.

    Any previous contents of out.queries are discarded; each non-blank line
    is converted with create_rr(). Returns `out`.
    """
    out.queries = []
    while True:
        token = get_next(file_in, False)
        if token is False or token[0] == '':
            break
        owner, rr_args = token
        out.queries.append(create_rr(owner, rr_args))
    return out

799 800 801 802 803 804
# Id given to the next STEP line that omits its id; parse_step() advances it.
# NOTE(review): never reset in the code shown, so numbering continues across
# all scenarios parsed in one process — confirm that is intended.
auto_step = 0
def parse_step(op, args, file_in):
    """
    Parse step definition.

    args is ``[<id>] <type> [<params>...]``. When the id is omitted (fewer
    than two arguments, or the first one is not a number), the next
    auto-incremented id is used. The counter then continues from this step's
    id + 1. Steps that carry data (out.has_data) get an entry parsed from
    file_in, and REPLAY steps additionally read a list of queries.

    Raises Exception when args is empty.
    """
    global auto_step
    if len(args) == 0:
        raise Exception('expected at least STEP <type>')
    # Auto-increment when step ID isn't specified
    if len(args) < 2 or not args[0].isdigit():
        args = [str(auto_step)] + args
    auto_step = int(args[0]) + 1
    out = Step(args[0], args[1], args[2:])
    if out.has_data:
        out.add(parse_entry(op, args, file_in))
    # Special steps
    if args[1] == 'REPLAY':
        parse_queries(out, file_in)
    return out


def parse_range(op, args, file_in):
    """
    Parse range definition.

    args is ``<from> <to> [address] [param[=value]...]``. A parameter without
    '=' is stored as True. Otherwise everything after the first '=' is its
    value, so values may themselves contain '='. The body is read until
    RANGE_END or end of input: ADDRESS lines add addresses, ENTRY_BEGIN
    starts an entry, and other keywords are ignored.

    Raises Exception when fewer than two arguments are given.
    """
    if len(args) < 2:
        raise Exception('expected RANGE_BEGIN <from> <to> [address]')
    out = Range(int(args[0]), int(args[1]))
    # Shortcut for address
    if len(args) > 2:
        out.addresses.add(args[2])
    # Parameters
    if len(args) > 3:
        out.args = {}
        for param in args[3:]:
            # Split on the first '=' only; 'key=a=b' used to raise ValueError.
            key, sep, value = param.partition('=')
            out.args[key] = value if sep else True
    for keyword, params in iter(lambda: get_next(file_in), False):
        if keyword == 'ADDRESS':
            out.addresses.add(params[0])
        elif keyword == 'ENTRY_BEGIN':
            out.add(parse_entry(keyword, params, file_in, in_entry = True))
        elif keyword == 'RANGE_END':
            break
    return out


def parse_scenario(op, args, file_in):
    """
    Build a Scenario from the lines that follow SCENARIO_BEGIN.

    args[0] becomes the scenario description. RANGE_BEGIN and STEP blocks
    are parsed and collected until SCENARIO_END or end of input; any other
    keyword is ignored. The `op` parameter is not used.
    """
    scenario = Scenario(args[0], file_in.filename())
    while True:
        token = get_next(file_in)
        if token is False:
            break
        keyword, params = token
        if keyword == 'SCENARIO_END':
            break
        elif keyword == 'RANGE_BEGIN':
            scenario.ranges.append(parse_range(keyword, params, file_in))
        elif keyword == 'STEP':
            scenario.steps.append(parse_step(keyword, params, file_in))
    return scenario


def parse_file(file_in):
    """
    Parse scenario from a file.

    The file may start with a configuration preamble of ``key: value`` lines
    (lines starting with ';' are skipped, '#' starts a comment), terminated
    by CONFIG_END. The scenario itself starts at SCENARIO_BEGIN. A
    SCENARIO_BEGIN found while still reading the preamble is accepted too
    (zero-configuration files).

    file_in must provide readline(), filename() and lineno() — presumably a
    fileinput.FileInput; confirm against callers.

    Returns (scenario, config), where config is a list of [key, value] pairs
    in file order. Any error is re-raised as Exception('<file>#<line>: <msg>').
    A file without SCENARIO_BEGIN yields the message
    "IGNORE (missing scenario)".
    """
    try:
        config = []
        line = file_in.readline()
        while line:
            # Zero-configuration
            if line.startswith('SCENARIO_BEGIN'):
                # split() rather than split(' '): keeps the trailing newline and
                # repeated whitespace out of the scenario description, matching
                # how get_next() tokenizes the regular SCENARIO_BEGIN line.
                return parse_scenario(line, line.split()[1:], file_in), config
            if line.startswith('CONFIG_END'):
                break
            if not line.startswith(';'):
                # Drop '#' comments, then break to key-value pairs
                # e.g.: ['minimization', 'on']
                line = line.split('#', 1)[0]
                kv = [x.strip() for x in line.split(':', 1)]
                if len(kv) >= 2:
                    config.append(kv)
            line = file_in.readline()

        for op, args in iter(lambda: get_next(file_in), False):
            if op == 'SCENARIO_BEGIN':
                return parse_scenario(op, args, file_in), config
        raise Exception("IGNORE (missing scenario)")
    except Exception as e:
        raise Exception('%s#%d: %s' % (file_in.filename(), file_in.lineno(), str(e)))