From bf28ef7f25a08436e751085e2059929c20a1324e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marek=20Vavru=C5=A1a?= <marek.vavrusa@nic.cz>
Date: Fri, 10 Jul 2015 03:15:26 +0200
Subject: [PATCH] tests: can execute specific tests, fixed EDNS+RR merging

---
 tests/integration.mk        |  3 ++-
 tests/pydnstest/scenario.py | 21 +++++++++++++++------
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/tests/integration.mk b/tests/integration.mk
index 100de5937..f4a3b8962 100644
--- a/tests/integration.mk
+++ b/tests/integration.mk
@@ -1,6 +1,7 @@
 #
 # Integration tests
 #
+TESTS ?= tests/testdata
 
 # Mocked calls library
 libmock_calls_SOURCES := tests/mock_calls.c
@@ -22,6 +23,6 @@ else
 endif
 
 check-integration: $(libmock_calls) $(_test_integration)
-	$(call preload_LIBS) $(preload_syms) tests/test_integration.py tests/testdata
+	$(call preload_LIBS) $(preload_syms) tests/test_integration.py $(TESTS)
 
 .PHONY: check-integration
diff --git a/tests/pydnstest/scenario.py b/tests/pydnstest/scenario.py
index 4b2a760bc..d3f175be8 100644
--- a/tests/pydnstest/scenario.py
+++ b/tests/pydnstest/scenario.py
@@ -19,6 +19,7 @@ class Entry:
         self.adjust_fields = ['copy_id']
         self.origin = '.'
         self.message = dns.message.Message()
+        self.message.use_edns(edns = 0)
         self.sections = []
 
     def match_part(self, code, msg):
@@ -73,6 +74,7 @@ class Entry:
     def adjust_reply(self, query):
         """ Copy scripted reply and adjust to received query. """
         answer = dns.message.from_text(self.message.to_text())
+        answer.use_edns(query.edns, query.ednsflags)
         if 'copy_id' in self.adjust_fields:
             answer.id = query.id
             answer.question[0].name = query.question[0].name
@@ -98,7 +100,7 @@ class Entry:
             except:
                 flags.append(code)
         self.message.flags = dns.flags.from_text(' '.join(flags))
-        self.message.ednsflags = dns.flags.edns_from_text(' '.join(eflags))
+        self.message.want_dnssec('DO' in eflags)
         self.message.set_rcode(rcode)
 
     def begin_section(self, section):
@@ -110,16 +112,24 @@ class Entry:
         """ Add record to current packet section. """
         rr = self.__rr_from_str(owner, args)
         if self.section == 'QUESTION':
-            self.message.question.append(rr)
+            self.__rr_add(self.message.question, rr)
         elif self.section == 'ANSWER':
-            self.message.answer.append(rr)
+            self.__rr_add(self.message.answer, rr)
         elif self.section == 'AUTHORITY':
-            self.message.authority.append(rr)
+            self.__rr_add(self.message.authority, rr)
         elif self.section == 'ADDITIONAL':
-            self.message.additional.append(rr)
+            self.__rr_add(self.message.additional, rr)
         else:
             raise Exception('bad section %s' % self.section)
 
+    def __rr_add(self, section, rr):
+    	""" Merge record to existing RRSet, or append to given section. """
+    	for existing_rr in section:
+    		if existing_rr.match(rr.name, rr.rdclass, rr.rdtype, 0):
+    			existing_rr += rr
+    			return
+    	section.append(rr)
+
     def __rr_from_str(self, owner, args):
         """ Parse RR from tokenized string. """
         if not owner.endswith('.'):
@@ -251,7 +261,6 @@ class Step:
         if len(self.data) == 0:
             raise Exception("query definition required")
         msg = self.data[0].message
-        msg.use_edns(edns = 1)
         self.answer = ctx.resolve(msg.to_wire())
         if self.answer is not None:
             self.answer = dns.message.from_wire(self.answer)
-- 
GitLab