From f8bdcd8b44ce4fa7b12bf23f55bbb817e6fd42c3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marek=20Vavru=C5=A1a?= <marek.vavrusa@nic.cz>
Date: Sat, 17 Jan 2015 21:44:37 +0100
Subject: [PATCH] tests/integration: Python code to support QUERY/CHECK_ANSWER
 steps

---
 .gitignore                       |   1 +
 tests/.gitignore                 |   2 +
 tests/pydnstest/__init__.py      |   0
 tests/pydnstest/requirements.txt |   1 +
 tests/pydnstest/scenario.py      | 186 ++++++++++++++++++++++++++-----
 tests/test_integration.py        | 111 ++++++++++++------
 6 files changed, 241 insertions(+), 60 deletions(-)
 create mode 100644 tests/.gitignore
 create mode 100644 tests/pydnstest/__init__.py
 create mode 100644 tests/pydnstest/requirements.txt

diff --git a/.gitignore b/.gitignore
index eb1c16e6c..159b3e03d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@
 *.in
 *.Plo
 *.swp
+.dirstamp
 .libs
 .deps
 autom4te.cache/*
diff --git a/tests/.gitignore b/tests/.gitignore
new file mode 100644
index 000000000..5aacb1141
--- /dev/null
+++ b/tests/.gitignore
@@ -0,0 +1,2 @@
+*.pyc
+tmp*
diff --git a/tests/pydnstest/__init__.py b/tests/pydnstest/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/pydnstest/requirements.txt b/tests/pydnstest/requirements.txt
new file mode 100644
index 000000000..2f7359679
--- /dev/null
+++ b/tests/pydnstest/requirements.txt
@@ -0,0 +1 @@
+dnspython
diff --git a/tests/pydnstest/scenario.py b/tests/pydnstest/scenario.py
index aa9e827b7..366dc173b 100644
--- a/tests/pydnstest/scenario.py
+++ b/tests/pydnstest/scenario.py
@@ -1,46 +1,180 @@
-class Query:
+import traceback
+import dns.message
+import dns.rrset
+import dns.rcode
 
-    match_fields = []
+
+class Entry:
+    default_ttl = 3600
+    default_cls = 'IN'
+    default_rc = 'NOERROR'
 
     def __init__(self):
-        pass
+        self.match_fields = None
+        self.adjust_fields = None
+        self.message = dns.message.Message()
+
+    def match_part(self, code, msg):
+        if code not in self.match_fields and 'all' not in self.match_fields:
+            return True
+        expected = self.message
+        if code == 'opcode':
+            return self.__compare_val(expected.opcode(), msg.opcode())
+        elif code == 'qtype':
+            return self.__compare_val(expected.question[0].rdtype, msg.question[0].rdtype)
+        elif code == 'qname':
+            return self.__compare_val(expected.question[0].name, msg.question[0].name)
+        elif code == 'flags':
+            return self.__compare_val(dns.flags.to_text(expected.flags), dns.flags.to_text(msg.flags))
+        elif code == 'question':
+            return self.__compare_rrs(expected.question, msg.question)
+        elif code == 'answer':
+            return self.__compare_rrs(expected.answer, msg.answer)
+        elif code == 'authority':
+            return self.__compare_rrs(expected.authority, msg.authority)
+        elif code == 'additional':
+            return self.__compare_rrs(expected.additional, msg.additional)
+        else:
+            raise Exception('unknown match request "%s"' % code)
+
+    def match(self, msg):
+        match_fields = self.match_fields
+        if 'all' in match_fields:
+            match_fields = ('flags', 'question', 'answer', 'authority', 'additional')
+        for code in match_fields:
+            try:
+                self.match_part(code, msg)
+            except Exception as e:
+                raise Exception("when matching %s: %s" % (code, str(e)))
 
-    def match(self, fields):
+    def set_match(self, fields):
         self.match_fields = fields
 
-    def parse(self, text):
-        pass
+    def set_adjust(self, fields):
+        self.adjust_fields = fields
 
-class Range:
+    def set_reply(self, fields):
+        flags = []
+        rcode = dns.rcode.from_text(self.default_rc)
+        for code in fields:
+            try:
+                rcode = dns.rcode.from_text(code)
+            except:
+                flags.append(code)
+        self.message.flags = dns.flags.from_text(' '.join(flags))
+        self.message.rcode = rcode
+
+    def begin_section(self, section):
+        self.section = section
+
+    def add_record(self, owner, args):
+        rr = self.__rr_from_str(owner, args)
+        if self.section == 'QUESTION':
+            self.message.question.append(rr)
+        elif self.section == 'ANSWER':
+            self.message.answer.append(rr)
+        elif self.section == 'AUTHORITY':
+            self.message.authority.append(rr)
+        elif self.section == 'ADDITIONAL':
+            self.message.additional.append(rr)
+        else:
+            raise Exception('attempted to add record in section %s' % self.section)
+
+
+    def __rr_from_str(self, owner, args):
+        ttl = self.default_ttl
+        rdclass = self.default_cls
+        try:
+            dns.ttl.from_text(args[0])
+            ttl = args.pop(0)
+        except:
+            pass  # optional
+        try:
+            dns.rdataclass.from_text(args[0])
+            rdclass = args.pop(0)
+        except:
+            pass  # optional
+        rdtype = args.pop(0)
+        if len(args) > 0:
+            return dns.rrset.from_text(owner, ttl, rdclass, rdtype, ' '.join(args))
+        else:
+            return dns.rrset.from_text(owner, ttl, rdclass, rdtype)
+
+    def __compare_rrs(self, name, expected, got):
+        for rr in expected:
+            if rr not in got:
+                raise Exception("expected record '%s'" % rr.to_text())
+        for rr in got:
+            if rr not in expected:
+                raise Exception("unexpected record '%s'" % rr.to_text())
+        return True
 
-    a = 0
-    b = 0
-    queries = []
+    def __compare_val(self, expected, got):
+        if expected != got:
+            raise Exception("expected '%s', got '%s'" % (expected, got))
+        return True
 
+
+class Range:
     def __init__(self, a, b):
         self.a = a
         self.b = b
+        self.queries = []
 
-    def add_query(self, query):
-        self.queries.append(query)
+    def add(self, entry):
+        self.queries.append(entry)
 
 
-class Scenario:
+class Step:
+    def __init__(self, id, type):
+        self.id = int(id)
+        self.type = type
+        self.data = []
 
-    name = ''
-    ranges = []
-    steps = []
+    def add(self, entry):
+        self.data.append(entry)
 
-    def __init__(self):
-        pass
+    def play(self, ctx):
+        if self.type == 'QUERY':
+            return self.__query(ctx)
+        elif self.type == 'CHECK_ANSWER':
+            return self.__check_answer(ctx)
+        else:
+            print '%d %s (%d entries) => NOOP' % (self.id, self.type, len(self.data))
+            return None
+
+    def __check_answer(self, ctx):
+        if len(self.data) == 0:
+            raise Exception("response definition required")
+        if ctx.last_answer is None:
+            raise Exception("no answer from preceding query")
+        expected = self.data[0]
+        expected.match(ctx.last_answer)
+
+    def __query(self, ctx):
+        if len(self.data) == 0:
+            raise Exception("query definition required")
+        msg = self.data[0].message
+        self.answer = ctx.resolve(msg.to_wire())
+        if self.answer is not None:
+            self.answer = dns.message.from_wire(self.answer)
+            ctx.last_answer = self.answer
+
+
+class Scenario:
+    def __init__(self, info):
+        print '# %s' % info
+        self.ranges = []
+        self.steps = []
 
-    def begin(self, explanation):
-        print '# %s' % explanation
+    def play(self, ctx):
+        step = None
+        if len(self.steps) == 0:
+            raise ('no steps in this scenario')
+        try:
+            for step in self.steps:
+                step.play(ctx)
+        except Exception as e:
+            raise Exception('on step #%d "%s": %s\n%s' % (step.id, step.type, str(e), traceback.format_exc()))
 
-    def range(self, a, b):
-        range_new = Range(a, b)
-        self.ranges.append(range_new)
-        return range_new
 
-    def step(self, n, step_type):
-        pass
diff --git a/tests/test_integration.py b/tests/test_integration.py
index e413a183f..508133d77 100755
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -1,68 +1,111 @@
 #!/usr/bin/env python
 import sys, os, fileinput
-import _test_integration
+from pydnstest import scenario
+import _test_integration as mock_ctx
 
+def get_next(file_in):
+    while True:
+        line = file_in.readline()
+        if len(line) == 0:
+            return False
+        tokens = ' '.join(line.strip().split()).split()
+        if len(tokens) == 0:
+            continue # Skip empty lines
+        op = tokens.pop(0)
+        if op.startswith(';') or op.startswith('#'):
+            continue # Skip comments
+        return op, tokens
 
-def parse_entry(line, file_in):
+def parse_entry(op, args, file_in):
     """ Parse entry definition. """
-    print line.split(' ')
-    for line in iter(lambda: file_in.readline(), ''):
-        if line.startswith('ENTRY_END'):
+    out = scenario.Entry()
+    for op, args in iter(lambda: get_next(file_in), False):
+        if op == 'ENTRY_END':
             break
+        elif op == 'REPLY':
+            out.set_reply(args)
+        elif op == 'MATCH':
+            out.set_match(args)
+        elif op == 'ADJUST':
+            out.set_adjust(args)
+        elif op == 'SECTION':
+            out.begin_section(args[0])
+        else:
+            out.add_record(op, args)
+    return out
 
 
-def parse_step(line, file_in):
+def parse_step(op, args, file_in):
     """ Parse range definition. """
-    print line.split(' ')
+    if len(args) < 2:
+        raise Exception('expected STEP <id> <type>')
+    out = scenario.Step(args[0], args[1])
+    op, args = get_next(file_in)
+    # Optional data
+    if op == 'ENTRY_BEGIN':
+        out.add(parse_entry(op, args, file_in))
+    else:
+        raise Exception('expected "ENTRY_BEGIN"')
+    return out
 
 
-def parse_range(line, file_in):
+def parse_range(op, args, file_in):
     """ Parse range definition. """
-    print line.split(' ')
-    for line in iter(lambda: file_in.readline(), ''):
-        if line.startswith('ENTRY_BEGIN'):
-            parse_entry(line, file_in)
-        if line.startswith('RANGE_END'):
+    if len(args) < 2:
+        raise Exception('expected RANGE_BEGIN <from> <to>')
+    out = scenario.Range(int(args[0]), int(args[1]))
+    for op, args in iter(lambda: get_next(file_in), False):
+        if op == 'ADDRESS':
+            out.address = args[0]
+        elif op == 'ENTRY_BEGIN':
+            out.add(parse_entry(op, args, file_in))
+        elif op == 'RANGE_END':
             break
+    return out
 
 
-def parse_scenario(line, file_in):
+def parse_scenario(op, args, file_in):
     """ Parse scenario definition. """
-    print line.split(' ')
-    for line in iter(lambda: file_in.readline(), ''):
-        if line.startswith('SCENARIO_END'):
+    out = scenario.Scenario(args[0])
+    for op, args in iter(lambda: get_next(file_in), False):
+        if op == 'SCENARIO_END':
             break
-        if line.startswith('RANGE_BEGIN'):
-            parse_range(line, file_in)
-        if line.startswith('STEP'):
-            parse_step(line, file_in)
-
+        if op == 'RANGE_BEGIN':
+            out.ranges.append(parse_range(op, args, file_in))
+        if op == 'STEP':
+            out.steps.append(parse_step(op, args, file_in))
+    return out
 
 def parse_file(file_in):
-    """ Parse and play scenario from a file. """
+    """ Parse scenario from a file. """
     try:
-        for line in iter(lambda: file_in.readline(), ''):
-            if line.startswith('SCENARIO_BEGIN'):
-                return parse_scenario(line, file_in)
+        for op, args in iter(lambda: get_next(file_in), False):
+            if op == 'SCENARIO_BEGIN':
+                return parse_scenario(op, args, file_in)
         raise Exception("IGNORE (missing scenario)")
     except Exception as e:
         raise Exception('line %d: %s' % (file_in.lineno(), str(e)))
 
-
 def parse_object(path):
     """ Recursively scan file/directory for scenarios. """
     if os.path.isdir(path):
         for e in os.listdir(path):
             parse_object(os.path.join(path, e))
     elif os.path.isfile(path):
-        file_in = fileinput.input(path)
-        try:
-            parse_file(file_in)
-            print('%s OK' % os.path.basename(path))
-        except Exception as e:
-            print('%s %s' % (os.path.basename(path), str(e)))
-        file_in.close()
+        play_object(path)
 
+def play_object(path):
+    """ Play scenario from a file object. """
+    file_in = fileinput.input(path)
+    mock_ctx.init()
+    try:
+        scenario = parse_file(file_in)
+        scenario.play(mock_ctx)
+        print('%s OK' % os.path.basename(path))
+    except Exception as e:
+        print('%s %s' % (os.path.basename(path), str(e)))
+    mock_ctx.deinit()
+    file_in.close()
 
 if __name__ == '__main__':
     for arg in sys.argv[1:]:
-- 
GitLab