Commit 60a4be9a authored by Tomas Krizek's avatar Tomas Krizek

diffrepro: fix query shuffling

parent ff14bebd
......@@ -7,12 +7,14 @@ from multiprocessing import pool
import pickle
import random
import subprocess
from typing import Any, Iterable, Iterator, Mapping, Sequence, Tuple, TypeVar
from typing import ( # noqa
Any, AbstractSet, Iterable, Iterator, Mapping, Sequence, Tuple, TypeVar,
Union)
import cli
from dbhelper import LMDB, qid2key, key2qid, QKey
import diffsum
from dataformat import Diff, DiffReport, FieldLabel, ReproData, WireFormat
from dataformat import Diff, DiffReport, FieldLabel, ReproData, WireFormat, QID # noqa
import msgdiff
import sendrecv
from sendrecv import ResolverID, RepliesBlob
......@@ -49,10 +51,10 @@ def disagreement_query_stream(
skip_non_reproducible: bool = True,
shuffle: bool = True
) -> Iterator[Tuple[QKey, WireFormat]]:
qids = report.target_disagreements.keys()
qids = report.target_disagreements.keys() # type: Union[Sequence[QID], AbstractSet[QID]]
if shuffle:
# create a new, randomized list from disagreements
qids = set(random.sample(qids, len(qids)))
qids = random.sample(qids, len(qids))
queries = diffsum.get_query_iterator(lmdb, qids)
for qid, qwire in queries:
diff = report.target_disagreements[qid]
......
......@@ -4,7 +4,7 @@ import argparse
import collections
import logging
import sys
from typing import Any, Callable, Iterator, ItemsView, List, Set, Tuple, Union # noqa
from typing import Any, Callable, Iterable, Iterator, ItemsView, List, Set, Tuple, Union # noqa
import dns.message
import dns.rdatatype
......@@ -124,7 +124,7 @@ def print_mismatch_queries(
def get_query_iterator(
lmdb,
qids: Set[QID]
qids: Iterable[QID]
) -> Iterator[Tuple[QID, WireFormat]]:
qdb = lmdb.get_db(LMDB.QUERIES)
with lmdb.env.begin(qdb) as txn:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment