Commit 04429d99 authored by Petr Špaček's avatar Petr Špaček

Merge branch 'parallel-diffrepro' into 'master'

parallel diffrepro

See merge request knot/resolver-benchmarking!40
parents eab5a928 60a4be9a
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# mypy
.mypy_cache/
......@@ -6,6 +6,7 @@ CONFIG="response_differences/respdiff/respdiff.cfg"
response_differences/respdiff/qprep.py /tmp/respdiff.db < /tmp/queries.txt
time response_differences/respdiff/orchestrator.py /tmp/respdiff.db -c "${CONFIG}"
time response_differences/respdiff/msgdiff.py /tmp/respdiff.db -c "${CONFIG}"
response_differences/respdiff/diffrepro.py /tmp/respdiff.db -c "${CONFIG}"
response_differences/respdiff/diffsum.py /tmp/respdiff.db -c "${CONFIG}"
# it must not explode/raise an unhandled exception
......@@ -13,6 +13,7 @@ Respdiff v2 is conceptually chain of independent tools:
1. qprep: generate queries in wire format
2. orchestrator: send pre-generated wire format to servers and gather answers
3. msgdiff: compare DNS answers
4. diffrepro: (optional) attempt to reproduce the differences
5. diffsum: summarize differences into textual report
6. histogram: plot graph of answer latencies
......@@ -60,7 +61,23 @@ which reads configuration from config file section ``[diff]``.
The tool refers to one resolver as ``target`` and to the remaining servers
as ``others``. Msgdiff compares the specified fields and stores the result
in the LMDB.
in the LMDB and the JSON datafile.
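The comparison is driven by the config file's ``[diff]`` section. A minimal sketch
(the target name and the criteria listed here are only illustrative)::

    [diff]
    # symbolic name of the resolver under test
    target = kresd
    # message fields that msgdiff compares
    criteria = opcode, rcode, flags, answertypes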
Diffrepro
---------
Use of this tool is optional. It can be used to filter out "unstable" differences
which aren't reproducible. If the upstream answers differ (between resolvers or
over time), the query is flagged as unstable.
The tool can run queries in parallel (like orchestrator) or sequentially (slower,
but more predictable). Resolvers should be restarted (and their cache cleared)
between queries. The path to an executable restart script can be provided with
the ``restart_script`` value in each resolver's section of the config file (see
the example below).
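For illustration, a resolver section with a restart script might look like this
(the section name, address, port and script path are only examples)::

    [kresd]
    ip = 127.0.0.1
    port = 53053
    transport = udp
    restart_script = /usr/local/bin/restart-kresd.sh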
The output is written to the JSON datafile and other tools automatically use
this data if present.
Diffsum
......
from argparse import ArgumentParser, Namespace
import logging
import os
import sys
import cfg
......@@ -33,7 +34,13 @@ def add_arg_datafile(parser: ArgumentParser) -> None:
REPORT_FILENAME))
def get_datafile(args: Namespace) -> str:
if args.datafile is None:
return os.path.join(args.envdir, REPORT_FILENAME)
return args.datafile
def get_datafile(args: Namespace, check_exists=True) -> str:
datafile = args.datafile
if datafile is None:
datafile = os.path.join(args.envdir, REPORT_FILENAME)
if check_exists and not os.path.exists(datafile):
logging.error("JSON report (%s) doesn't exist!", datafile)
sys.exit(1)
return datafile
......@@ -9,7 +9,6 @@ from typing import ( # noqa
MismatchValue = Union[str, Sequence[str]]
ResolverID = str
QID = int
WireFormat = bytes
FieldLabel = str
......
from typing import Dict, Any, Tuple, Generator # NOQA: needed for type hint in comment
import os
import struct
from typing import Any, Dict, Iterator, Tuple # NOQA: needed for type hint in comment
import lmdb
from dataformat import QID
def qid2key(qid):
QKey = bytes
def qid2key(qid: QID) -> QKey:
"""Encode query ID to database key"""
return struct.pack('@I', qid) # native integer
def key2qid(key):
def key2qid(key: QKey) -> QID:
return struct.unpack('@I', key)[0]
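# Example of the round trip (assuming a little-endian machine, where the native
# '@I' unsigned int is stored as 4 bytes, least significant byte first):
#   qid2key(1) == b'\x01\x00\x00\x00'
#   key2qid(qid2key(42)) == 42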
......@@ -104,7 +109,7 @@ class LMDB:
except KeyError:
raise RuntimeError("Database {} isn't open!".format(dbname.decode('utf-8')))
def key_stream(self, dbname: bytes):
def key_stream(self, dbname: bytes) -> Iterator[bytes]:
"""yield all keys from given db"""
db = self.get_db(dbname)
with self.env.begin(db) as txn:
......@@ -112,7 +117,7 @@ class LMDB:
for key in cur.iternext(keys=True, values=False):
yield key
def key_value_stream(self, dbname: bytes):
def key_value_stream(self, dbname: bytes) -> Iterator[Tuple[bytes, bytes]]:
"""yield all (key, value) pairs from given db"""
db = self.get_db(dbname)
with self.env.begin(db) as txn:
......
#!/usr/bin/env python3
import argparse
from itertools import zip_longest
import logging
from multiprocessing import pool
import pickle
import random
import subprocess
from typing import Any, Mapping
import sys
from typing import ( # noqa
Any, AbstractSet, Iterable, Iterator, Mapping, Sequence, Tuple, TypeVar,
Union)
import cli
from dbhelper import LMDB
from dbhelper import LMDB, qid2key, key2qid, QKey
import diffsum
from dataformat import Diff, DiffReport, ReproData, ResolverID
from dataformat import Diff, DiffReport, FieldLabel, ReproData, WireFormat, QID # noqa
import msgdiff
from orchestrator import get_resolvers
import sendrecv
from sendrecv import ResolverID, RepliesBlob
T = TypeVar('T')
def restart_resolver(script_path: str) -> None:
......@@ -36,6 +44,61 @@ def get_restart_scripts(config: Mapping[str, Any]) -> Mapping[ResolverID, str]:
return restart_scripts
def disagreement_query_stream(
lmdb,
report: DiffReport,
skip_unstable: bool = True,
skip_non_reproducible: bool = True,
shuffle: bool = True
) -> Iterator[Tuple[QKey, WireFormat]]:
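"""
Yield (QKey, query wire format) pairs for the target disagreements in the report.
By default the order is randomized and queries are skipped if their upstream
answers were unstable or if the difference wasn't reproduced in every retry so far.
"""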
qids = report.target_disagreements.keys() # type: Union[Sequence[QID], AbstractSet[QID]]
if shuffle:
# create a new, randomized list from disagreements
qids = random.sample(qids, len(qids))
queries = diffsum.get_query_iterator(lmdb, qids)
for qid, qwire in queries:
diff = report.target_disagreements[qid]
reprocounter = report.reprodata[qid]
# verify if answers are stable
if skip_unstable and reprocounter.retries != reprocounter.upstream_stable:
logging.debug('Skipping QID %7d: unstable upstream', diff.qid)
continue
if skip_non_reproducible and reprocounter.retries != reprocounter.verified:
logging.debug('Skipping QID %7d: not 100 %% reproducible', diff.qid)
continue
yield qid2key(qid), qwire
def chunker(iterable: Iterable[T], size: int) -> Iterator[Iterable[T]]:
"""
Collect data into fixed-length chunks or blocks
chunker([x, y, z], 2) --> [x, y], [z, None]
"""
args = [iter(iterable)] * size
return zip_longest(*args)
def process_answers(
qkey: QKey,
replies: RepliesBlob,
report: DiffReport,
criteria: Sequence[FieldLabel],
target: ResolverID
) -> None:
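"""
Re-compare the reproduced answers and update the query's ReproData counters:
retries (total reproduction attempts), upstream_stable (the other resolvers
agreed with each other again) and verified (the original difference against
the target was reproduced exactly).
"""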
qid = key2qid(qkey)
reprocounter = report.reprodata[qid]
wire_dict = pickle.loads(replies)
answers = msgdiff.decode_wire_dict(wire_dict)
others_agree, mismatches = msgdiff.compare(answers, criteria, target)
reprocounter.retries += 1
if others_agree:
reprocounter.upstream_stable += 1
if Diff(qid, mismatches) == report.target_disagreements[qid]:
reprocounter.verified += 1
def main():
cli.setup_logging()
parser = argparse.ArgumentParser(
......@@ -43,49 +106,49 @@ def main():
cli.add_arg_envdir(parser)
cli.add_arg_config(parser)
cli.add_arg_datafile(parser)
parser.add_argument('-s', '--sequential', action='store_true', default=False,
help='send one query at a time (slower, but more reliable)')
args = parser.parse_args()
sendrecv.module_init(args)
datafile = cli.get_datafile(args)
report = DiffReport.from_json(datafile)
criteria = args.cfg['diff']['criteria']
timeout = args.cfg['sendrecv']['timeout']
selector, sockets = sendrecv.sock_init(get_resolvers(args.cfg))
restart_scripts = get_restart_scripts(args.cfg)
if len(sockets) < len(args.cfg['servers']['names']):
logging.critical("Couldn't create sockets for all resolvers.")
sys.exit(1)
if args.sequential:
nproc = 1
else:
nproc = args.cfg['sendrecv']['jobs']
if report.reprodata is None:
report.reprodata = ReproData()
with LMDB(args.envdir, readonly=True) as lmdb:
lmdb.open_db(LMDB.QUERIES)
queries = diffsum.get_query_iterator(lmdb, report.target_disagreements)
for qid, qwire in queries:
diff = report.target_disagreements[qid]
reprocounter = report.reprodata[qid]
# verify if answers are stable
if reprocounter.retries != reprocounter.upstream_stable:
logging.debug('Skipping QID %d: unstable upstream', diff.qid)
continue
for script in restart_scripts.values():
restart_resolver(script)
wire_blobs, _ = sendrecv.send_recv_parallel(qwire, selector, sockets, timeout)
answers = msgdiff.decode_wire_dict(wire_blobs)
others_agree, mismatches = msgdiff.compare(
answers, criteria, args.cfg['diff']['target'])
reprocounter.retries += 1
if others_agree:
reprocounter.upstream_stable += 1
if diff == Diff(diff.qid, mismatches):
reprocounter.verified += 1
report.export_json(datafile)
dstream = disagreement_query_stream(lmdb, report)
try:
with pool.Pool(processes=nproc) as p:
done = 0
for process_args in chunker(dstream, nproc):
# restart resolvers and clear their cache
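# chunks are sized to the number of worker processes, so each batch of queries
# is answered by freshly restarted resolvers with empty caches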
for script in restart_scripts.values():
restart_resolver(script)
process_args = [args for args in process_args if args is not None]
for qkey, replies in p.imap_unordered(
sendrecv.worker_perform_single_query,
process_args,
chunksize=1):
process_answers(qkey, replies, report,
args.cfg['diff']['criteria'],
args.cfg['diff']['target'])
done += len(process_args)
logging.info('Processed {:4d} queries'.format(done))
finally:
# make sure data is saved in case of interrupt
report.export_json(datafile)
if __name__ == '__main__':
......
......@@ -4,7 +4,7 @@ import argparse
import collections
import logging
import sys
from typing import Any, Callable, Iterator, ItemsView, List, Set, Tuple, Union # noqa
from typing import Any, Callable, Iterable, Iterator, ItemsView, List, Set, Tuple, Union # noqa
import dns.message
import dns.rdatatype
......@@ -124,7 +124,7 @@ def print_mismatch_queries(
def get_query_iterator(
lmdb,
qids: Set[QID]
qids: Iterable[QID]
) -> Iterator[Tuple[QID, WireFormat]]:
qdb = lmdb.get_db(LMDB.QUERIES)
with lmdb.env.begin(qdb) as txn:
......
......@@ -2,11 +2,8 @@
import argparse
from functools import partial
import logging
import multiprocessing.pool as pool
import os
import pickle
import sys
from typing import Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple # noqa
import dns.message
......@@ -15,8 +12,9 @@ import dns.exception
import cli
from dataformat import (
DataMismatch, DiffReport, Disagreements, DisagreementsCounter, FieldLabel, MismatchValue,
Reply, ResolverID, QID)
Reply, QID)
from dbhelper import LMDB, key2qid
from sendrecv import ResolverID
lmdb = None # type: Optional[Any]
......@@ -268,11 +266,6 @@ def main():
criteria = args.cfg['diff']['criteria']
target = args.cfg['diff']['target']
# JSON report has to be created by orchestrator
if not os.path.exists(datafile):
logging.error("JSON report (%s) doesn't exist!", datafile)
sys.exit(1)
with LMDB(args.envdir, fast=True) as lmdb_:
lmdb = lmdb_
lmdb.open_db(LMDB.ANSWERS)
......
......@@ -4,89 +4,15 @@ import argparse
import logging
import multiprocessing.pool as pool
import os
import pickle
import random
import threading
import time
from typing import List, Tuple, Dict, Any, Mapping, Sequence # noqa: type hints
import sys
import cli
from dataformat import DiffReport, ResolverID
from dataformat import DiffReport
from dbhelper import LMDB
import sendrecv
worker_state = threading.local()
resolvers = [] # type: List[Tuple[str, str, str, int]]
ignore_timeout = False
max_timeouts = 10 # crash when N consecutive timeouts are received from a single resolver
timeout = None
time_delay_min = 0
time_delay_max = 0
def worker_init():
"""
make sure it works with distinct processes as well as threads
"""
worker_state.timeouts = {}
worker_reinit()
def worker_reinit():
selector, sockets = sendrecv.sock_init(resolvers)
worker_state.selector = selector
worker_state.sockets = sockets
def worker_deinit(selector, sockets):
"""
Close all sockets and selector.
"""
selector.close()
for _, sck, _ in sockets:
sck.close()
def worker_query_lmdb_wrapper(args):
qid, qwire = args
selector = worker_state.selector
sockets = worker_state.sockets
# optional artificial delay for testing
if time_delay_max > 0:
time.sleep(random.uniform(time_delay_min, time_delay_max))
replies, reinit = sendrecv.send_recv_parallel(qwire, selector, sockets, timeout)
if not ignore_timeout:
check_timeout(replies)
if reinit: # a connection is broken or something
# TODO: log this?
worker_deinit(selector, sockets)
worker_reinit()
blob = pickle.dumps(replies)
return (qid, blob)
def check_timeout(replies):
for resolver, reply in replies.items():
timeouts = worker_state.timeouts
if reply.wire is not None:
timeouts[resolver] = 0
else:
timeouts[resolver] = timeouts.get(resolver, 0) + 1
if timeouts[resolver] >= max_timeouts:
raise RuntimeError(
"Resolver '{}' timed-out {:d} times in a row. "
"Use '--ignore-timeout' to supress this error.".format(
resolver, max_timeouts))
def export_statistics(lmdb, datafile, start_time):
qdb = lmdb.get_db(LMDB.QUERIES)
adb = lmdb.get_db(LMDB.ANSWERS)
......@@ -109,22 +35,7 @@ def export_statistics(lmdb, datafile, start_time):
report.export_json(datafile)
def get_resolvers(config: Mapping[str, Any]) -> Sequence[Tuple[ResolverID, str, str, int]]:
resolvers_ = []
for resname in config['servers']['names']:
rescfg = config[resname]
resolvers_.append((resname, rescfg['ip'], rescfg['transport'], rescfg['port']))
return resolvers_
def main():
global ignore_timeout
global max_timeouts
global resolvers
global timeout
global time_delay_min
global time_delay_max
cli.setup_logging()
parser = argparse.ArgumentParser(
description='read queries from LMDB, send them in parallel to servers '
......@@ -136,16 +47,8 @@ def main():
help='continue despite consecutive timeouts from resolvers')
args = parser.parse_args()
datafile = cli.get_datafile(args)
resolvers = get_resolvers(args.cfg)
ignore_timeout = args.ignore_timeout
timeout = args.cfg['sendrecv']['timeout']
time_delay_min = args.cfg['sendrecv']['time_delay_min']
time_delay_max = args.cfg['sendrecv']['time_delay_max']
try:
max_timeouts = args.cfg['sendrecv']['max_timeouts']
except KeyError:
pass
sendrecv.module_init(args)
datafile = cli.get_datafile(args, check_exists=False)
start_time = int(time.time())
with LMDB(args.envdir) as lmdb:
......@@ -158,14 +61,14 @@ def main():
# process queries in parallel
with pool.Pool(
processes=args.cfg['sendrecv']['jobs'],
initializer=worker_init) as p:
initializer=sendrecv.worker_init) as p:
i = 0
for qid, blob in p.imap(worker_query_lmdb_wrapper, qstream,
chunksize=100):
for qkey, blob in p.imap(sendrecv.worker_perform_query, qstream,
chunksize=100):
i += 1
if i % 10000 == 0:
logging.info('Received {:d} answers'.format(i))
txn.put(qid, blob)
txn.put(qkey, blob)
except RuntimeError as err:
logging.error(err)
sys.exit(1)
......
"""
sendrecv module
===============
This module is used by orchestrator and diffrepro to perform DNS queries in parallel.
The entire module keeps global state, which makes it easy to use with both
threads and processes. Make sure not to break this compatibility.
"""
from argparse import Namespace
import pickle
import random
import selectors
import socket
import ssl
import struct
import time
from typing import Dict, Mapping, Tuple # noqa
import threading
from typing import Any, Dict, List, Mapping, Sequence, Tuple # noqa: type hints
import dns.inet
import dns.message
from dataformat import Reply, ResolverID
from dataformat import Reply, WireFormat
from dbhelper import QKey
ResolverID = str
IP = str
Protocol = str
Port = int
RepliesBlob = bytes
IsStreamFlag = bool # Is message preceded by RFC 1035 section 4.2.2 length?
ReinitFlag = bool
Selector = selectors.BaseSelector
Socket = socket.socket
ResolverSockets = Sequence[Tuple[ResolverID, Socket, IsStreamFlag]]
# module-wide state variables
__resolvers = [] # type: Sequence[Tuple[ResolverID, IP, Protocol, Port]]
__worker_state = threading.local()
__max_timeouts = 10 # crash when N consecutive timeouts are received from a single resolver
__ignore_timeout = False
__timeout = None
__time_delay_min = 0
__time_delay_max = 0
__timeout_replies = {} # type: Dict[float, Reply]
def module_init(args: Namespace) -> None:
global __resolvers
global __max_timeouts
global __ignore_timeout
global __timeout
global __time_delay_min
global __time_delay_max
__resolvers = get_resolvers(args.cfg)
__timeout = args.cfg['sendrecv']['timeout']
__time_delay_min = args.cfg['sendrecv']['time_delay_min']
__time_delay_max = args.cfg['sendrecv']['time_delay_max']
try:
__max_timeouts = args.cfg['sendrecv']['max_timeouts']
except KeyError:
pass
try:
__ignore_timeout = args.ignore_timeout
except AttributeError:
pass
def worker_init() -> None:
__worker_state.timeouts = {}
worker_reinit()
def worker_reinit() -> None:
selector, sockets = sock_init() # type: Tuple[Selector, ResolverSockets]
__worker_state.selector = selector
__worker_state.sockets = sockets
def worker_deinit() -> None:
selector = __worker_state.selector
sockets = __worker_state.sockets
selector.close()
for _, sck, _ in sockets: # type: ignore # python/mypy#465
sck.close()
TIMEOUT_REPLIES = {} # type: Dict[float, Reply]
def worker_perform_query(args: Tuple[QKey, WireFormat]) -> Tuple[QKey, RepliesBlob]:
"""DNS query performed by orchestrator"""
qkey, qwire = args
selector = __worker_state.selector
sockets = __worker_state.sockets
# optional artificial delay for testing
if __time_delay_max > 0:
time.sleep(random.uniform(__time_delay_min, __time_delay_max))
replies, reinit = send_recv_parallel(qwire, selector, sockets, __timeout)
if not __ignore_timeout:
_check_timeout(replies)
if reinit: # a connection is broken or something
worker_deinit()
worker_reinit()
blob = pickle.dumps(replies)
return qkey, blob
def worker_perform_single_query(args: Tuple[QKey, WireFormat]) -> Tuple[QKey, RepliesBlob]:
"""Perform a single DNS query with setup and teardown of sockets. Used by diffrepro."""
qkey, qwire = args
worker_reinit()
selector = __worker_state.selector
sockets = __worker_state.sockets
replies, _ = send_recv_parallel(qwire, selector, sockets, __timeout)
worker_deinit()
blob = pickle.dumps(replies)
return qkey, blob
def get_resolvers(
config: Mapping[str, Any]
) -> Sequence[Tuple[ResolverID, IP, Protocol, Port]]:
resolvers = []
for resname in config['servers']['names']:
rescfg = config[resname]
resolvers.append((resname, rescfg['ip'], rescfg['transport'], rescfg['port']))
return resolvers
def _check_timeout(replies: Mapping[ResolverID, Reply]) -> None:
for resolver, reply in replies.items():
timeouts = __worker_state.timeouts
if reply.wire is not None:
timeouts[resolver] = 0
else:
timeouts[resolver] = timeouts.get(resolver, 0) + 1
if timeouts[resolver] >= __max_timeouts:
raise RuntimeError(
"Resolver '{}' timed-out {:d} times in a row. "
"Use '--ignore-timeout' to supress this error.".format(
resolver, __max_timeouts))
def sock_init(resolvers):
"""
resolvers: [(name, ipaddr, transport, port)]
returns (selector, [(name, socket, isstream)])
"""
def sock_init() -> Tuple[Selector, Sequence[Tuple[ResolverID, Socket, IsStreamFlag]]]:
sockets = []
selector = selectors.DefaultSelector()
for name, ipaddr, transport, port in resolvers:
for name, ipaddr, transport, port in __resolvers:
af = dns.inet.af_for_address(ipaddr)
if af == dns.inet.AF_INET:
destination = (ipaddr, port)
destination = (ipaddr, port) # type: Any
elif af == dns.inet.AF_INET6:
destination = (ipaddr, port, 0, 0)
else:
......@@ -46,16 +182,11 @@ def sock_init(resolvers):
sockets.append((name, sock, isstream))
selector.register(sock, selectors.EVENT_READ, (name, isstream))
# selector.close() ? # TODO
return selector, sockets
def _recv_msg(sock, isstream):
"""
receive DNS message from socket
isstream: Is message preceded by RFC 1035 section 4.2.2 length?
returns: wire format without preamble or ConnectionError
"""
def _recv_msg(sock: Socket, isstream: IsStreamFlag) -> WireFormat:
"""Receive DNS message from socket and remove preambule (if present)."""
if isstream: # parse preambule
blength = sock.recv(2) # TODO: does not work with TLS: , socket.MSG_WAITALL)
if not blength: # stream closed
......@@ -67,18 +198,15 @@ def _recv_msg(sock, isstream):
def send_recv_parallel(
dgram,
selector,
sockets,
dgram: WireFormat, # DNS message suitable for UDP transport
selector: Selector,
sockets: ResolverSockets,
timeout: float
) -> Tuple[Mapping[ResolverID, Reply], bool]:
"""