Skip to content
Snippets Groups Projects
Commit 5de0f9a1 authored by Daniel Salzman's avatar Daniel Salzman
Browse files

func-test: improve dig, add CHAOS/TXT tests

parent 22791635
No related branches found
No related tags found
No related merge requests found
__pycache__
#!/usr/bin/env python3
'''Test for server identification over CH/TXT'''
import dnstest
import socket
t = dnstest.DnsTest()
name = "Knot DNS server name"
server1 = t.server("knot", ident=name)
server2 = t.server("knot", ident=True)
server3 = t.server("knot", ident=False)
server4 = t.server("knot")
t.start()
# 1a) Custom identification string.
resp = server1.dig("id.server", "TXT", "CH")
resp.check('"' + name + '"')
# 1b) Bind old version of above.
resp = server1.dig("hostname.bind", "TXT", "CH")
resp.check('"' + name + '"')
# 2) FQDN hostname.
resp = server2.dig("id.server", "TXT", "CH")
resp.check(socket.getfqdn())
# 3) Explicitly disabled.
resp = server3.dig("id.server", "TXT", "CH")
resp.check(rcode="REFUSED")
# 4) Disabled.
resp = server4.dig("id.server", "TXT", "CH")
resp.check(rcode="REFUSED")
t.stop()
#!/usr/bin/env python3
'''Test for server version over CH/TXT'''
import dnstest
t = dnstest.DnsTest()
ver = "ver. 1.3.1-p3"
server1 = t.server("knot", version=ver)
server2 = t.server("knot", version=True)
server3 = t.server("knot", version=False)
server4 = t.server("knot")
t.start()
# 1a) Custom version string.
resp = server1.dig("version.server", "TXT", "CH")
resp.check('"' + ver + '"')
# 1b) Bind old version of above.
resp = server1.dig("version.bind", "TXT", "CH")
resp.check('"' + ver + '"')
# 2) Automatic version string (can't be tested).
resp = server2.dig("version.server", "TXT", "CH")
resp.check(rcode="NOERROR")
# 3) Explicitly disabled.
resp = server3.dig("version.server", "TXT", "CH")
resp.check(rcode="REFUSED")
# 4) Disabled.
resp = server4.dig("version.server", "TXT", "CH")
resp.check(rcode="REFUSED")
t.stop()
......@@ -37,6 +37,11 @@ def log(text):
def err(text):
log(" ERR> %s" % text)
def compare(value, expected, name):
if (value != expected):
err("%s is (" % name + str(value) + ") != (" + str(expected) + ")")
params.err = True
class Tsig(object):
'''TSIG key generator'''
......@@ -177,6 +182,65 @@ class Zone(object):
# ddns: True - ddns, False - ixfrFromDiff, None - empty
self.ddns = ddns
class Response(object):
'''Dig output context'''
def __init__(self, response, rname, rtype, rclass, serial):
self.resp = response
self.rname = dns.name.from_text(rname)
self.serial = serial
if type(rtype) is str:
self.rtype = dns.rdatatype.from_text(rtype)
else:
self.rtype = rtype
if type(rclass) is str:
self.rclass = dns.rdataclass.from_text(rclass)
else:
self.rclass = rclass
def _check_query(self):
question = self.resp.question.pop()
compare(question.name, self.rname, "question.name")
compare(question.rdclass, self.rclass, "question.class")
compare(question.rdtype, self.rtype, "question.type")
def check(self, rdata=None, ttl=None, rcode="NOERROR"):
# Check question section.
self._check_query()
# Check rcode.
if type(rcode) is str:
rc = dns.rcode.from_text(rcode)
else:
rc = rcode
compare(self.resp.rcode(), rc, "rcode")
# Check rdata only if NOERROR.
if rc != 0 or rdata == None:
return
# We work with just one rdata with TTL=0 (this TTL is not used).
ref = list(dns.rdataset.from_text(self.rclass, self.rtype, 0, rdata))[0]
# Check answer section if contains reference rdata.
for data in self.resp.answer:
for rdata in data.to_rdataset():
# Compare Rdataset instances.
if rdata == ref:
# Check CLASS.
compare(data.rdclass, self.rclass, "CLASS")
# Check TYPE.
compare(data.rdtype, self.rtype, "TYPE")
# Check TTL if specified.
if ttl != None:
compare(data.ttl, int(ttl), "TTL")
return
else:
err("RDATA (" + str(rdata) + ") not in answer section")
params.err = True
class Update(object):
'''DNS update context'''
......@@ -191,10 +255,13 @@ class Update(object):
self.upd.delete(owner, **args)
def send(self, rcode="NOERROR"):
if type(rcode) is str:
rc = dns.rcode.from_text(rcode)
else:
rc = rcode
resp = dns.query.tcp(self.upd, self.server.addr, port=self.server.port)
if resp.rcode() != dns.rcode.from_text(rcode):
raise Exception("Update rcode is %s (expected %s)" \
% (dns.rcode.to_text(resp.rcode()), rcode))
compare(resp.rcode(), rc, "update rcode")
class DnsServer(object):
'''Specification of DNS server'''
......@@ -386,50 +453,49 @@ class DnsServer(object):
f.close
def dig(self, rname, rtype, rclass="IN", use_udp=None, serial=None, \
timeout=DIG_TIMEOUT):
timeout=DIG_TIMEOUT, tries=3):
key_params = self.tsig.key_params if self.tsig else dict()
try:
if rtype.upper() == "AXFR":
# Always use TCP.
resp = dns.query.xfr(self.addr, rname, rtype, rclass, \
port=self.port, lifetime=timeout, \
use_udp=False, **key_params)
elif rtype.upper() == "IXFR":
# Use TCP if not specified.
use_udp = use_udp if use_udp != None else False
resp = dns.query.xfr(self.addr, rname, rtype, rclass, \
port=self.port, lifetime=timeout, \
use_udp=use_udp, serial=serial, \
**key_params)
else:
# Use TCP or UDP at random if not specified.
use_udp = use_udp if use_udp != None else \
random.choice([True, False])
for t in range(tries):
try:
if rtype.upper() == "AXFR":
# Use TCP if not specified. UDP is for testing.
use_udp = use_udp if use_udp != None else False
resp = dns.query.xfr(self.addr, rname, rtype, rclass, \
port=self.port, lifetime=timeout, \
use_udp=use_udp, **key_params)
elif rtype.upper() == "IXFR":
# Use TCP if not specified.
use_udp = use_udp if use_udp != None else False
resp = dns.query.xfr(self.addr, rname, rtype, rclass, \
port=self.port, lifetime=timeout, \
use_udp=use_udp, serial=serial, \
**key_params)
else:
# Use TCP or UDP at random if not specified.
use_udp = use_udp if use_udp != None else \
random.choice([True, False])
query = dns.message.make_query(rname, rtype, rclass)
query = dns.message.make_query(rname, rtype, rclass)
if use_udp:
resp = dns.query.udp(query, self.addr, port=self.port, \
timeout=timeout)
else:
resp = dns.query.tcp(query, self.addr, port=self.port, \
timeout=timeout)
return resp
except:
return None
if use_udp:
resp = dns.query.udp(query, self.addr, port=self.port, \
timeout=timeout)
else:
resp = dns.query.tcp(query, self.addr, port=self.port, \
timeout=timeout)
return Response(resp, rname, rtype, rclass, serial)
except:
time.sleep(timeout)
raise Exception("Can't query %s for %s %s %s." % \
(self.name, rname, rclass, rtype))
def zones_wait(self, zones):
for zone in zones:
for attempt in range(10):
resp = self.dig(zone, "SOA", use_udp=True)
if resp and resp.answer:
break
time.sleep(DnsServer.DIG_TIMEOUT)
else:
raise Exception("Can't get %s SOA from %s." % \
(zone, self.name))
self.dig(zone, "SOA", use_udp=True)
def update(self, zone):
if len(zone) != 1:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment