Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Showing with 9494 additions and 373 deletions
-- vim:syntax=lua
setLocal('{{SELF_ADDR}}')
setVerboseHealthChecks(true)
setServerPolicy(firstAvailable)
local server = newServer({
address="{{PROGRAMS['kresd']['address']}}",
useProxyProtocol=true,
checkName="example.cz."
})
server:setUp()
-- SPDX-License-Identifier: GPL-3.0-or-later
{% raw %}
modules.load('view < policy')
view:addr("127.127.0.0", policy.suffix(policy.DENY_MSG("addr 127.127.0.0 matched com"),{"\3com\0"}))
-- make sure DNSSEC is turned off for tests
trust_anchors.remove('.')
-- Disable RFC5011 TA update
if ta_update then
modules.unload('ta_update')
end
-- Disable RFC8145 signaling, scenario doesn't provide expected answers
if ta_signal_query then
modules.unload('ta_signal_query')
end
-- Disable RFC8109 priming, scenario doesn't provide expected answers
if priming then
modules.unload('priming')
end
-- Disable this module because it makes one priming query
if detect_time_skew then
modules.unload('detect_time_skew')
end
_hint_root_file('hints')
cache.size = 2*MB
log_level('debug')
{% endraw %}
-- Allow PROXYv2 from dnsdist's address
--net.proxy_allowed("{{PROGRAMS['dnsdist']['address']}}")
net.proxy_allowed("127.127.0.0/16")
net = { '{{SELF_ADDR}}' }
{% if QMIN == "false" %}
option('NO_MINIMIZE', true)
{% else %}
option('NO_MINIMIZE', false)
{% endif %}
-- Self-checks on globals
assert(help() ~= nil)
assert(worker.id ~= nil)
-- Self-checks on facilities
assert(cache.count() == 0)
assert(cache.stats() ~= nil)
assert(cache.backends() ~= nil)
assert(worker.stats() ~= nil)
assert(net.interfaces() ~= nil)
-- Self-checks on loaded stuff
assert(net.list()[1].transport.ip == '{{SELF_ADDR}}')
assert(#modules.list() > 0)
-- Self-check timers
ev = event.recurrent(1 * sec, function (ev) return 1 end)
event.cancel(ev)
ev = event.after(0, function (ev) return 1 end)
; SPDX-License-Identifier: GPL-3.0-or-later
; config options
server:
harden-referral-path: no
target-fetch-policy: "0 0 0 0 0"
stub-zone:
name: "."
stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET.
stub-addr: 1.2.3.4
query-minimization: off
CONFIG_END
SCENARIO_BEGIN Disable EDNS0 and fancy stuff when the server replies with FORMERR.
SCENARIO_BEGIN proxyv2:valid test
RANGE_BEGIN 0 110
ADDRESS 1.2.3.4
STEP 10 QUERY
ENTRY_BEGIN
REPLY RD
MATCH opcode qtype qname
ADJUST copy_id
REPLY QR RD RA NOERROR
SECTION QUESTION
cz. IN A
example.cz. IN A
SECTION ANSWER
example.cz. IN A 5.6.7.8
ENTRY_END
; root prime
STEP 30 REPLY
ENTRY_BEGIN
MATCH opcode qtype qname
ADJUST copy_id
REPLY QR AA NOERROR
REPLY QR RD RA NOERROR
SECTION QUESTION
. IN NS
k.root-servers.net. IN AAAA
SECTION ANSWER
. IN NS K.ROOT-SERVERS.NET.
SECTION ADDITIONAL
K.ROOT-SERVERS.NET. IN A 193.0.14.129
k.root-servers.net. IN AAAA ::1
ENTRY_END
; query sent to root server
STEP 50 REPLY
RANGE_END
; query with PROXYv2 header - not blocked
STEP 10 QUERY
ENTRY_BEGIN
MATCH opcode qtype qname
ADJUST copy_id
REPLY QR NOERROR
ADJUST raw_id
REPLY RD
SECTION QUESTION
cz. IN A
SECTION AUTHORITY
cz. IN NS ns1.cz.
SECTION ADDITIONAL
ns1.cz. IN A 168.192.2.2
example.cz. IN A
ENTRY_END
; this is the formerr answer
STEP 60 REPLY
STEP 20 CHECK_ANSWER
ENTRY_BEGIN
MATCH opcode qtype qname
ADJUST copy_id
REPLY QR AA FORMERR
MATCH flags rcode question answer
REPLY QR RD RA NOERROR
SECTION QUESTION
cz. IN A
example.cz. IN A
SECTION ANSWER
example.cz. IN A 5.6.7.8
ENTRY_END
; this is the correct answer
STEP 60 REPLY
; query with PROXYv2 header - blocked by view:addr
; NXDOMAIN expected
STEP 30 QUERY
ENTRY_BEGIN
MATCH opcode qtype qname
ADJUST copy_id
REPLY QR AA NOERROR
REPLY RD
SECTION QUESTION
cz. IN A
SECTION ANSWER
cz. IN A 10.20.30.40
SECTION AUTHORITY
cz. IN NS ns1.cz.
SECTION ADDITIONAL
ns1.cz. IN A 168.192.2.2
example.com. IN A
ENTRY_END
; is the final answer correct?
STEP 100 CHECK_ANSWER
STEP 31 CHECK_ANSWER
ENTRY_BEGIN
MATCH all
REPLY QR RD RA
MATCH opcode question rcode additional
REPLY QR RD RA AA NXDOMAIN
SECTION QUESTION
cz. IN A
SECTION ANSWER
cz. IN A 10.20.30.40
example.com. IN A
SECTION ADDITIONAL
explanation.invalid. 10800 IN TXT "addr 127.127.0.0 matched com"
ENTRY_END
SCENARIO_END
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include <stdatomic.h>
#include "daemon/ratelimiting.h"
#include "lib/kru.h"
#include "lib/mmapped.h"
#include "lib/utils.h"
#include "lib/resolve.h"
#define V4_PREFIXES (uint8_t[]) { 18, 20, 24, 32 }
#define V4_RATE_MULT (kru_price_t[]) { 768, 256, 32, 1 }
#define V6_PREFIXES (uint8_t[]) { 32, 48, 56, 64, 128 }
#define V6_RATE_MULT (kru_price_t[]) { 64, 4, 3, 2, 1 }
#define V4_PREFIXES_CNT (sizeof(V4_PREFIXES) / sizeof(*V4_PREFIXES))
#define V6_PREFIXES_CNT (sizeof(V6_PREFIXES) / sizeof(*V6_PREFIXES))
#define MAX_PREFIXES_CNT ((V4_PREFIXES_CNT > V6_PREFIXES_CNT) ? V4_PREFIXES_CNT : V6_PREFIXES_CNT)
struct ratelimiting {
size_t capacity;
uint32_t instant_limit;
uint32_t rate_limit;
uint32_t log_period;
uint16_t slip;
bool dry_run;
bool using_avx2;
_Atomic uint32_t log_time;
kru_price_t v4_prices[V4_PREFIXES_CNT];
kru_price_t v6_prices[V6_PREFIXES_CNT];
_Alignas(64) uint8_t kru[];
};
struct ratelimiting *ratelimiting = NULL;
struct mmapped ratelimiting_mmapped = {0};
/// return whether we're using optimized variant right now
static bool using_avx2(void)
{
bool result = (KRU.initialize == KRU_AVX2.initialize);
kr_require(result || KRU.initialize == KRU_GENERIC.initialize);
return result;
}
int ratelimiting_init(const char *mmap_file, size_t capacity, uint32_t instant_limit,
uint32_t rate_limit, uint16_t slip, uint32_t log_period, bool dry_run)
{
size_t capacity_log = 0;
for (size_t c = capacity - 1; c > 0; c >>= 1) capacity_log++;
size_t size = offsetof(struct ratelimiting, kru) + KRU.get_size(capacity_log);
struct ratelimiting header = {
.capacity = capacity,
.instant_limit = instant_limit,
.rate_limit = rate_limit,
.log_period = log_period,
.slip = slip,
.dry_run = dry_run,
.using_avx2 = using_avx2()
};
size_t header_size = offsetof(struct ratelimiting, using_avx2) + sizeof(header.using_avx2);
static_assert( // no padding up to .using_avx2
offsetof(struct ratelimiting, using_avx2) ==
sizeof(header.capacity) +
sizeof(header.instant_limit) +
sizeof(header.rate_limit) +
sizeof(header.log_period) +
sizeof(header.slip) +
sizeof(header.dry_run),
"detected padding with undefined data inside mmapped header");
int ret = mmapped_init(&ratelimiting_mmapped, mmap_file, size, &header, header_size);
if (ret == MMAPPED_WAS_FIRST) {
kr_log_info(SYSTEM, "Initializing rate-limiting...\n");
ratelimiting = ratelimiting_mmapped.mem;
const kru_price_t base_price = KRU_LIMIT / instant_limit;
const kru_price_t max_decay = rate_limit > 1000ll * instant_limit ? base_price :
(uint64_t) base_price * rate_limit / 1000;
bool succ = KRU.initialize((struct kru *)ratelimiting->kru, capacity_log, max_decay);
if (!succ) {
ratelimiting = NULL;
ret = kr_error(EINVAL);
goto fail;
}
ratelimiting->log_time = kr_now() - log_period;
for (size_t i = 0; i < V4_PREFIXES_CNT; i++) {
ratelimiting->v4_prices[i] = base_price / V4_RATE_MULT[i];
}
for (size_t i = 0; i < V6_PREFIXES_CNT; i++) {
ratelimiting->v6_prices[i] = base_price / V6_RATE_MULT[i];
}
ret = mmapped_init_continue(&ratelimiting_mmapped);
if (ret != 0) goto fail;
kr_log_info(SYSTEM, "Rate-limiting initialized (%s).\n", (ratelimiting->using_avx2 ? "AVX2" : "generic"));
return 0;
} else if (ret == 0) {
ratelimiting = ratelimiting_mmapped.mem;
kr_log_info(SYSTEM, "Using existing rate-limiting data (%s).\n", (ratelimiting->using_avx2 ? "AVX2" : "generic"));
return 0;
} // else fail
fail:
kr_log_crit(SYSTEM, "Initialization of shared rate-limiting data failed.\n");
return ret;
}
void ratelimiting_deinit(void)
{
mmapped_deinit(&ratelimiting_mmapped);
ratelimiting = NULL;
}
bool ratelimiting_request_begin(struct kr_request *req)
{
if (!ratelimiting) return false;
if (!req->qsource.addr)
return false; // don't consider internal requests
if (req->qsource.price_factor16 == 0)
return false; // whitelisted
// We only do this on pure UDP. (also TODO if cookies get implemented)
const bool ip_validated = req->qsource.flags.tcp || req->qsource.flags.tls;
if (ip_validated) return false;
const uint32_t time_now = kr_now();
// classify
_Alignas(16) uint8_t key[16] = {0, };
uint8_t limited_prefix;
if (req->qsource.addr->sa_family == AF_INET6) {
struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)req->qsource.addr;
memcpy(key, &ipv6->sin6_addr, 16);
// compute adjusted prices, using standard rounding
kru_price_t prices[V6_PREFIXES_CNT];
for (int i = 0; i < V6_PREFIXES_CNT; ++i) {
prices[i] = (req->qsource.price_factor16
* (uint64_t)ratelimiting->v6_prices[i] + (1<<15)) >> 16;
}
limited_prefix = KRU.limited_multi_prefix_or((struct kru *)ratelimiting->kru, time_now,
1, key, V6_PREFIXES, prices, V6_PREFIXES_CNT, NULL);
} else {
struct sockaddr_in *ipv4 = (struct sockaddr_in *)req->qsource.addr;
memcpy(key, &ipv4->sin_addr, 4); // TODO append port?
// compute adjusted prices, using standard rounding
kru_price_t prices[V4_PREFIXES_CNT];
for (int i = 0; i < V4_PREFIXES_CNT; ++i) {
prices[i] = (req->qsource.price_factor16
* (uint64_t)ratelimiting->v4_prices[i] + (1<<15)) >> 16;
}
limited_prefix = KRU.limited_multi_prefix_or((struct kru *)ratelimiting->kru, time_now,
0, key, V4_PREFIXES, prices, V4_PREFIXES_CNT, NULL);
}
if (!limited_prefix) return false; // not limited
// slip: truncating vs dropping
bool tc =
(ratelimiting->slip > 1) ?
((kr_rand_bytes(1) % ratelimiting->slip == 0) ? true : false) :
((ratelimiting->slip == 1) ? true : false);
// logging
uint32_t log_time_orig = atomic_load_explicit(&ratelimiting->log_time, memory_order_relaxed);
if (ratelimiting->log_period) {
while (time_now - log_time_orig + 1024 >= ratelimiting->log_period + 1024) {
if (atomic_compare_exchange_weak_explicit(&ratelimiting->log_time, &log_time_orig, time_now,
memory_order_relaxed, memory_order_relaxed)) {
kr_log_notice(SYSTEM, "address %s rate-limited on /%d (%s%s)\n",
kr_straddr(req->qsource.addr), limited_prefix,
ratelimiting->dry_run ? "dry-run, " : "",
tc ? "truncated" : "dropped");
break;
}
}
}
req->ratelimited = true; // we set this even on dry_run
if (ratelimiting->dry_run) return false;
// perform limiting
if (tc) { // TC=1: return truncated reply to force source IP validation
knot_pkt_t *answer = kr_request_ensure_answer(req);
if (!answer) { // something bad; TODO: perhaps improve recovery from this
kr_assert(false);
return true;
}
// at this point the packet should be pretty clear
// The TC=1 answer is not perfect, as the right RCODE might differ
// in some cases, but @vcunat thinks that NOERROR isn't really risky here.
knot_wire_set_tc(answer->wire);
knot_wire_clear_ad(answer->wire);
req->state = KR_STATE_DONE;
} else {
// no answer
req->options.NO_ANSWER = true;
req->state = KR_STATE_FAIL;
}
return true;
}
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include <stdbool.h>
#include "lib/defines.h"
#include "lib/utils.h"
struct kr_request;
/** Initialize rate-limiting with shared mmapped memory.
* The existing data are used if another instance is already using the file
* and it was initialized with the same parameters; it fails on mismatch. */
KR_EXPORT
int ratelimiting_init(const char *mmap_file, size_t capacity, uint32_t instant_limit,
uint32_t rate_limit, uint16_t slip, uint32_t log_period, bool dry_run);
/** Do rate-limiting, during knot_layer_api::begin. */
KR_EXPORT
bool ratelimiting_request_begin(struct kr_request *req);
/** Remove mmapped file data if not used by other processes. */
KR_EXPORT
void ratelimiting_deinit(void);
/* Copyright (C) 2024 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
static void the_tests(void **state);
#include "./tests.inc.c" // NOLINT(bugprone-suspicious-include)
#define THREADS 4
#define BATCH_QUERIES_LOG 3 // threads acquire queries in batches of 8
#define HOSTS_LOG 3 // at most 6 attackers + 2 wildcard addresses for normal users
#define TICK_QUERIES_LOG 13 // at most 1024 queries per host per tick
// Expected range of limits for parallel test.
#define RANGE_INST(Vx, prefix) INST(Vx, prefix) - 1, INST(Vx, prefix) + THREADS - 1
#define RANGE_RATEM(Vx, prefix) RATEM(Vx, prefix) - 1, RATEM(Vx, prefix)
#define RANGE_UNLIM(queries) queries, queries
struct host {
uint32_t queries_per_tick;
int addr_family;
char *addr_format;
uint32_t min_passed, max_passed;
_Atomic uint32_t passed;
};
struct stage {
uint32_t first_tick, last_tick;
struct host hosts[1 << HOSTS_LOG];
};
struct runnable_data {
int prime;
_Atomic uint32_t *queries_acquired, *queries_done;
struct stage *stages;
};
static void *runnable(void *arg)
{
struct runnable_data *d = (struct runnable_data *)arg;
size_t si = 0;
char addr_str[40];
struct sockaddr_storage addr;
uint8_t wire[KNOT_WIRE_MIN_PKTSIZE] = { 0 };
knot_pkt_t answer = { .wire = wire };
struct kr_request req = {
.qsource.addr = (struct sockaddr *) &addr,
.qsource.price_factor16 = 1 << 16,
.answer = &answer
};
while (true) {
uint32_t qi1 = atomic_fetch_add(d->queries_acquired, 1 << BATCH_QUERIES_LOG);
/* increment time if needed; sync on incrementing using spinlock */
uint32_t tick = qi1 >> TICK_QUERIES_LOG;
for (size_t i = 1; tick != fakeclock_tick; i++) {
if ((*d->queries_done >> TICK_QUERIES_LOG) >= tick) {
fakeclock_tick = tick;
}
if (i % (1<<14) == 0) sched_yield();
__sync_synchronize();
}
/* increment stage if needed */
while (tick > d->stages[si].last_tick) {
++si;
if (!d->stages[si].first_tick) return NULL;
}
if (tick >= d->stages[si].first_tick) {
uint32_t qi2 = 0;
do {
uint32_t qi = qi1 + qi2;
/* perform query qi */
uint32_t hi = qi % (1 << HOSTS_LOG);
if (!d->stages[si].hosts[hi].queries_per_tick) continue;
uint32_t hqi = (qi % (1 << TICK_QUERIES_LOG)) >> HOSTS_LOG; // host query index within tick
if (hqi >= d->stages[si].hosts[hi].queries_per_tick) continue;
hqi += (qi >> TICK_QUERIES_LOG) * d->stages[si].hosts[hi].queries_per_tick; // across ticks
(void)snprintf(addr_str, sizeof(addr_str), d->stages[si].hosts[hi].addr_format,
hqi % 0xff, (hqi >> 8) % 0xff, (hqi >> 16) % 0xff);
kr_straddr_socket_set((struct sockaddr *)&addr, addr_str, 0);
if (!ratelimiting_request_begin(&req)) {
atomic_fetch_add(&d->stages[si].hosts[hi].passed, 1);
}
} while ((qi2 = (qi2 + d->prime) % (1 << BATCH_QUERIES_LOG)));
}
atomic_fetch_add(d->queries_done, 1 << BATCH_QUERIES_LOG);
}
}
static void the_tests(void **state)
{
/* parallel tests */
struct stage stages[] = {
/* first tick, last tick, hosts */
{32, 32, {
/* queries per tick, family, address, min passed, max passed */
{1024, AF_INET, "%d.%d.%d.1", RANGE_UNLIM ( 1024 )},
{1024, AF_INET, "3.3.3.3", RANGE_INST ( V4, 32 )},
{ 512, AF_INET, "4.4.4.4", RANGE_INST ( V4, 32 )},
{1024, AF_INET6, "%x%x:%x00::1", RANGE_UNLIM ( 1024 )},
{1024, AF_INET6, "3333::3333", RANGE_INST ( V6, 128 )},
{ 512, AF_INET6, "4444::4444", RANGE_INST ( V6, 128 )}
}},
{33, 255, {
{1024, AF_INET, "%d.%d.%d.1", RANGE_UNLIM ( 1024 )},
{1024, AF_INET, "3.3.3.3", RANGE_RATEM ( V4, 32 )},
{ 512, AF_INET, "4.4.4.4", RANGE_RATEM ( V4, 32 )},
{1024, AF_INET6, "%x%x:%x00::1", RANGE_UNLIM ( 1024 )},
{1024, AF_INET6, "3333::3333", RANGE_RATEM ( V6, 128 )},
{ 512, AF_INET6, "4444::4444", RANGE_RATEM ( V6, 128 )},
}},
{256, 511, {
{1024, AF_INET, "3.3.3.3", RANGE_RATEM ( V4, 32 )},
{1024, AF_INET6, "3333::3333", RANGE_RATEM ( V6, 128 )}
}},
{512, 512, {
{1024, AF_INET, "%d.%d.%d.1", RANGE_UNLIM ( 1024 )},
{1024, AF_INET, "3.3.3.3", RANGE_RATEM ( V4, 32 )},
{ 512, AF_INET, "4.4.4.4", RANGE_INST ( V4, 32 )},
{1024, AF_INET6, "%x%x:%x00::1", RANGE_UNLIM ( 1024 )},
{1024, AF_INET6, "3333::3333", RANGE_RATEM ( V6, 128 )},
{ 512, AF_INET6, "4444::4444", RANGE_INST ( V6, 128 )}
}},
{0}
};
pthread_t thr[THREADS];
struct runnable_data rd[THREADS];
_Atomic uint32_t queries_acquired = 0, queries_done = 0;
int primes[] = {3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61};
assert(sizeof(primes)/sizeof(*primes) >= THREADS);
for (unsigned i = 0; i < THREADS; ++i) {
rd[i].queries_acquired = &queries_acquired;
rd[i].queries_done = &queries_done;
rd[i].prime = primes[i];
rd[i].stages = stages;
pthread_create(thr + i, NULL, &runnable, rd + i);
}
for (unsigned i = 0; i < THREADS; ++i) {
pthread_join(thr[i], NULL);
}
unsigned si = 0;
do {
struct host * const h = stages[si].hosts;
uint32_t ticks = stages[si].last_tick - stages[si].first_tick + 1;
for (size_t i = 0; h[i].queries_per_tick; i++) {
assert_int_between(h[i].passed, ticks * h[i].min_passed, ticks * h[i].max_passed,
"parallel stage %d, addr %-25s", si, h[i].addr_format);
}
} while (stages[++si].first_tick);
}
/* Copyright (C) 2024 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
static void the_tests(void **state);
#include "./tests.inc.c" // NOLINT(bugprone-suspicious-include)
// defining count_test as macro to let it print usable line number on failure
#define count_test(DESC, EXPECTED_PASSING, MARGIN_FRACT, ...) { \
int _max_diff = (EXPECTED_PASSING) * (MARGIN_FRACT); \
int cnt = _count_test(EXPECTED_PASSING, __VA_ARGS__); \
assert_int_between(cnt, (EXPECTED_PASSING) - _max_diff, (EXPECTED_PASSING) + _max_diff, DESC); }
uint32_t _count_test(int expected_passing, int addr_family, char *addr_format, uint32_t min_value, uint32_t max_value)
{
uint32_t max_queries = expected_passing > 0 ? 2 * expected_passing : -expected_passing;
struct sockaddr_storage addr;
uint8_t wire[KNOT_WIRE_MIN_PKTSIZE] = { 0 };
knot_pkt_t answer = { .wire = wire };
struct kr_request req = {
.qsource.addr = (struct sockaddr *) &addr,
.qsource.price_factor16 = 1 << 16,
.answer = &answer
};
char addr_str[40];
int cnt = -1;
for (size_t i = 0; i < max_queries; i++) {
(void)snprintf(addr_str, sizeof(addr_str), addr_format,
i % (max_value - min_value + 1) + min_value,
i / (max_value - min_value + 1) % 256);
kr_straddr_socket_set((struct sockaddr *) &addr, addr_str, 0);
if (ratelimiting_request_begin(&req)) {
cnt = i;
break;
}
}
return cnt;
}
static void the_tests(void **state)
{
/* IPv4 multi-prefix tests */
static_assert(V4_PREFIXES_CNT == 4,
"There are no more IPv4 limited prefixes (/32, /24, /20, /18 will be tested).");
count_test("IPv4 instant limit /32", INST(V4, 32), 0,
AF_INET, "128.0.0.0", 0, 0);
count_test("IPv4 instant limit /32 not applied on /31", -1, 0,
AF_INET, "128.0.0.1", 0, 0);
count_test("IPv4 instant limit /24", INST(V4, 24) - INST(V4, 32) - 1, 0,
AF_INET, "128.0.0.%d", 2, 255);
count_test("IPv4 instant limit /24 not applied on /23", -1, 0,
AF_INET, "128.0.1.0", 0, 0);
count_test("IPv4 instant limit /20", INST(V4, 20) - INST(V4, 24) - 1, 0.001,
AF_INET, "128.0.%d.%d", 2, 15);
count_test("IPv4 instant limit /20 not applied on /19", -1, 0,
AF_INET, "128.0.16.0", 0, 0);
count_test("IPv4 instant limit /18", INST(V4, 18) - INST(V4, 20) - 1, 0.01,
AF_INET, "128.0.%d.%d", 17, 63);
count_test("IPv4 instant limit /18 not applied on /17", -1, 0,
AF_INET, "128.0.64.0", 0, 0);
/* IPv6 multi-prefix tests */
static_assert(V6_PREFIXES_CNT == 5,
"There are no more IPv6 limited prefixes (/128, /64, /56, /48, /32 will be tested).");
count_test("IPv6 instant limit /128, independent to IPv4", INST(V6, 128), 0,
AF_INET6, "8000::", 0, 0);
count_test("IPv6 instant limit /128 not applied on /127", -1, 0,
AF_INET6, "8000::1", 0, 0);
count_test("IPv6 instant limit /64", INST(V6, 64) - INST(V6, 128) - 1, 0,
AF_INET6, "8000:0:0:0:%02x%02x::", 0x01, 0xff);
count_test("IPv6 instant limit /64 not applied on /63", -1, 0,
AF_INET6, "8000:0:0:1::", 0, 0);
count_test("IPv6 instant limit /56", INST(V6, 56) - INST(V6, 64) - 1, 0,
AF_INET6, "8000:0:0:00%02x:%02x00::", 0x02, 0xff);
count_test("IPv6 instant limit /56 not applied on /55", -1, 0,
AF_INET6, "8000:0:0:0100::", 0, 0);
count_test("IPv6 instant limit /48", INST(V6, 48) - INST(V6, 56) - 1, 0.01,
AF_INET6, "8000:0:0:%02x%02x::", 0x02, 0xff);
count_test("IPv6 instant limit /48 not applied on /47", -1, 0,
AF_INET6, "8000:0:1::", 0, 0);
count_test("IPv6 instant limit /32", INST(V6, 32) - INST(V6, 48) - 1, 0.001,
AF_INET6, "8000:0:%02x%02x::", 0x02, 0xff);
count_test("IPv6 instant limit /32 not applied on /31", -1, 0,
AF_INET6, "8000:1::", 0, 0);
/* limit after 1 msec */
fakeclock_tick++;
count_test("IPv4 rate limit /32 after 1 msec", RATEM(V4, 32), 0,
AF_INET, "128.0.0.0", 0, 0);
count_test("IPv6 rate limit /128 after 1 msec", RATEM(V6, 128), 0,
AF_INET6, "8000::", 0, 0);
}
/* Copyright (C) 2024 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#include <pthread.h>
#include <sched.h>
#include <stdio.h>
#include <stdatomic.h>
#include "tests/unit/test.h"
#include "libdnssec/crypto.h"
#include "libdnssec/random.h"
#include "libknot/libknot.h"
#include "contrib/openbsd/siphash.h"
#include "lib/resolve.h"
#include "lib/utils.h"
uint64_t fakeclock_now(void);
#define kr_now fakeclock_now
#include "daemon/ratelimiting.c"
#undef kr_now
#define RRL_TABLE_SIZE (1 << 20)
#define RRL_INSTANT_LIMIT (1 << 8)
#define RRL_RATE_LIMIT (1 << 17)
#define RRL_BASE_PRICE (KRU_LIMIT / RRL_INSTANT_LIMIT)
// Accessing RRL configuration of INSTANT/RATE limits for V4/V6 and specific prefix.
#define LIMIT(type, Vx, prefix) (RRL_MULT(Vx, prefix) * RRL_ ## type ## _LIMIT)
#define RRL_CONFIG(Vx, name) Vx ## _ ## name
#define RRL_MULT(Vx, prefix) get_mult(RRL_CONFIG(Vx, PREFIXES), RRL_CONFIG(Vx, RATE_MULT), RRL_CONFIG(Vx, PREFIXES_CNT), prefix)
static inline kru_price_t get_mult(uint8_t prefixes[], kru_price_t mults[], size_t cnt, uint8_t wanted_prefix) {
for (size_t i = 0; i < cnt; i++)
if (prefixes[i] == wanted_prefix)
return mults[i];
assert(0);
return 0;
}
// Instant limits and rate limits per msec.
#define INST(Vx, prefix) LIMIT(INSTANT, Vx, prefix)
#define RATEM(Vx, prefix) (LIMIT(RATE, Vx, prefix) / 1000)
/* Fix seed for randomness in RRL module. Change if improbable collisions arise. (one byte) */
#define RRL_SEED_GENERIC 1
#define RRL_SEED_AVX2 1
#define assert_int_between(VAL, MIN, MAX, ...) \
if (((MIN) > (VAL)) || ((VAL) > (MAX))) { \
fprintf(stderr, __VA_ARGS__); fprintf(stderr, ": %d <= %d <= %d, ", MIN, VAL, MAX); \
assert_true(false); }
struct kru_generic {
SIPHASH_KEY hash_key;
// ...
};
struct kru_avx2 {
_Alignas(32) char hash_key[48];
// ...
};
/* Override time. */
uint64_t fakeclock_tick = 0;
uint64_t fakeclock_start = 0;
void fakeclock_init(void)
{
fakeclock_start = kr_now();
fakeclock_tick = 0;
}
uint64_t fakeclock_now(void)
{
return fakeclock_start + fakeclock_tick;
}
static void test_rrl(void **state) {
dnssec_crypto_init();
fakeclock_init();
/* create rrl table */
const char *tmpdir = test_tmpdir_create();
char mmap_file[64];
stpcpy(stpcpy(mmap_file, tmpdir), "/ratelimiting");
ratelimiting_init(mmap_file, RRL_TABLE_SIZE, RRL_INSTANT_LIMIT, RRL_RATE_LIMIT, 0, 0, false);
if (KRU.initialize == KRU_GENERIC.initialize) {
struct kru_generic *kru = (struct kru_generic *) ratelimiting->kru;
memset(&kru->hash_key, RRL_SEED_GENERIC, sizeof(kru->hash_key));
} else if (KRU.initialize == KRU_AVX2.initialize) {
struct kru_avx2 *kru = (struct kru_avx2 *) ratelimiting->kru;
memset(&kru->hash_key, RRL_SEED_AVX2, sizeof(kru->hash_key));
} else {
assert(0);
}
the_tests(state);
ratelimiting_deinit();
test_tmpdir_remove(tmpdir);
dnssec_crypto_cleanup();
}
static void test_rrl_generic(void **state) {
KRU = KRU_GENERIC;
test_rrl(state);
}
static void test_rrl_avx2(void **state) {
KRU = KRU_AVX2;
test_rrl(state);
}
int main(int argc, char *argv[])
{
assert(KRU_GENERIC.initialize != KRU_AVX2.initialize);
if (KRU.initialize == KRU_AVX2.initialize) {
const UnitTest tests[] = {
unit_test(test_rrl_generic),
unit_test(test_rrl_avx2)
};
return run_tests(tests);
} else {
const UnitTest tests[] = {
unit_test(test_rrl_generic)
};
return run_tests(tests);
}
}
.. SPDX-License-Identifier: GPL-3.0-or-later
.. _runtime-cfg:
Run-time reconfiguration
========================
Knot Resolver offers several ways to modify its configuration at run-time:
- Using a control socket driven by an external system
- Using a Lua program embedded in the Resolver's configuration file
Both approaches can also be combined: for example, the configuration file can define
a small Lua function which gathers statistics and returns them as a JSON string.
An external system can then use the control socket to call this user-defined
function and retrieve its results.
.. _control-sockets:
Control sockets
---------------
The control socket acts like "an interactive configuration file": every action
available in the configuration file can also be executed interactively over the
control socket. One possible use-case is reconfiguring the resolver instances from
another program, e.g. a maintenance script.
.. note:: Each instance of Knot Resolver exposes its own control socket. Take
that into account when scripting deployments with
:ref:`systemd-multiple-instances`.
When Knot Resolver is started using systemd (see section
`Startup <../gettingstarted-startup.html>`_), it creates a control socket at
``/run/knot-resolver/control/$ID``. A connection to the socket can be made from the
command line using e.g. ``socat``:
.. code-block:: bash
$ socat - UNIX-CONNECT:/run/knot-resolver/control/1
When successfully connected to a socket, the command-line prompt changes to
something like ``>``. You can then interact with kresd to inspect its configuration
or set a new one. Here are some basic commands to start with:
.. code-block:: lua
> help() -- shows help
> net.interfaces() -- lists available interfaces
> net.list() -- lists running network services
The *direct output* of commands sent over the socket is captured and sent back,
which gives you an immediate response on the outcome of your command.
The commands and their output are also logged in the ``contrl`` group,
on ``debug`` level if successful or ``warning`` level if failed
(see around :func:`log_level`).
Control sockets are also a way to enumerate and test running instances: the
list of sockets corresponds to the list of processes, and you can test a
process for liveness by connecting to its UNIX socket.
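For example (a sketch assuming the default systemd socket directory and two running
instances):
.. code-block:: bash
$ ls /run/knot-resolver/control/
1  2
$ echo 'worker.id' | socat - UNIX-CONNECT:/run/knot-resolver/control/1
'1'
A connection that succeeds and returns a value means the instance is alive.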
.. function:: map(lua_snippet)
Executes the provided string as lua code on every running resolver instance
and returns the results as a table.
Key ``n`` is always present in the returned table and specifies the total
number of instances the command was executed on. The table also contains
results from each instance accessible through keys ``1`` to ``n``
(inclusive). If any instance returns ``nil``, it is not explicitly part of
the table, but you can detect it by iterating through ``1`` to ``n``.
.. code-block:: lua
> map('worker.id') -- return an ID of every active instance
{
'2',
'1',
['n'] = 2,
}
> map('worker.id == "1" or nil') -- example of `nil` return value
{
[2] = true,
['n'] = 2,
}
The order of instances isn't guaranteed or stable. When you need to identify
the instances, you may use the ``kluautil.kr_table_pack()`` function to return multiple
values as a table. It uses the same semantics with ``n`` as described above
to allow ``nil`` values.
.. code-block:: lua
> map('require("kluautil").kr_table_pack(worker.id, stats.get("answer.total"))')
{
{
'2',
42,
['n'] = 2,
},
{
'1',
69,
['n'] = 2,
},
['n'] = 2,
}
If the command fails on any instance, an error is returned and the execution
is in an undefined state (the command might not have been executed on all
instances). When using the ``map()`` function to execute any code that might
fail, your code should be wrapped in `pcall()
<https://www.lua.org/manual/5.1/manual.html#pdf-pcall>`_ to avoid this
issue.
.. code-block:: lua
> map('require("kluautil").kr_table_pack(pcall(net.tls, "cert.pem", "key.pem"))')
{
{
true, -- function succeeded
true, -- function return value(s)
['n'] = 2,
},
{
false, -- function failed
'error occurred...', -- the returned error message
['n'] = 2,
},
['n'] = 2,
}
Lua scripts
-----------
As it was mentioned in section :ref:`config-lua-syntax`, Resolver's configuration
file contains program in Lua programming language. This allows you to write
dynamic rules and helps you to avoid repetitive templating that is unavoidable
with static configuration. For example parts of configuration can depend on
:func:`hostname` of the machine:
.. code-block:: lua
if hostname() == 'hidden' then
net.listen(net.eth0, 5353)
else
net.listen('127.0.0.1')
net.listen(net.eth1.addr[1])
end
Another example shows how to bind to all interfaces, using iteration.
.. code-block:: lua
for name, addr_list in pairs(net.interfaces()) do
net.listen(addr_list)
end
.. tip:: Some users observed a considerable performance gain (close to 100%) in
Docker containers when they bound the daemon to a single interface/IP address
pair. The example above can be expanded to browse the available addresses as:
.. code-block:: lua
addrpref = env.EXPECTED_ADDR_PREFIX
for k, v in pairs(addr_list["addr"]) do
if string.sub(v,1,string.len(addrpref)) == addrpref then
net.listen(v)
...
You can also use third-party Lua libraries (available for example through
LuaRocks_), as in this example which downloads the cache from a parent resolver
to avoid a cold-cache start.
.. code-block:: lua
local http = require('socket.http')
local ltn12 = require('ltn12')
local cache_size = 100*MB
local cache_path = '/var/cache/knot-resolver'
cache.open(cache_size, 'lmdb://' .. cache_path)
if cache.count() == 0 then
cache.close()
-- download cache from parent
http.request {
url = 'http://parent/data.mdb',
sink = ltn12.sink.file(io.open(cache_path .. '/data.mdb', 'w'))
}
-- reopen cache with 100M limit
cache.open(cache_size, 'lmdb://' .. cache_path)
end
Helper functions
^^^^^^^^^^^^^^^^
The following built-in functions are useful for scripting:
.. envvar:: env (table)
Retrieve environment variables.
Example:
.. code-block:: lua
env.USER -- equivalent to $USER in shell
.. function:: fromjson(JSONstring)
:return: Lua representation of data in JSON string.
Example:
.. code-block:: lua
> fromjson('{"key1": "value1", "key2": {"subkey1": 1, "subkey2": 2}}')
[key1] => value1
[key2] => {
[subkey1] => 1
[subkey2] => 2
}
.. function:: hostname([fqdn])
:return: Machine hostname.
If called with a parameter, it will set kresd's internal
hostname. If called without a parameter, it will return kresd's
internal hostname, or the system's POSIX hostname (see
gethostname(2)) if kresd's internal hostname is unset.
This also affects ephemeral (self-signed) certificates generated by kresd
for DNS over TLS.
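Example session (a sketch; the returned name depends on your system):
.. code-block:: lua
> hostname()           -- read the current hostname
'host.example.net'
> hostname('hidden')   -- override kresd's internal hostname
'hidden'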
.. function:: package_version()
:return: Current package version as string.
Example:
.. code-block:: lua
> package_version()
2.1.1
.. function:: resolve(name, type[, class = kres.class.IN, options = {}, finish = nil, init = nil])
:param string name: Query name (e.g. 'com.')
:param number type: Query type (e.g. ``kres.type.NS``)
:param number class: Query class *(optional)* (e.g. ``kres.class.IN``)
:param strings options: Resolution options (see :c:type:`kr_qflags`)
:param function finish: Callback to be executed when resolution completes (e.g. `function cb (pkt, req) end`). The callback gets a packet containing the final answer and doesn't have to return anything.
:param function init: Callback to be executed with the :c:type:`kr_request` before resolution starts.
:return: boolean, ``true`` if resolution was started
The function can also be executed with a table of arguments instead. This is
useful if you'd like to skip some arguments, for example:
.. code-block:: lua
resolve {
name = 'example.com',
type = kres.type.AAAA,
init = function (req)
end,
}
Example:
.. code-block:: lua
-- Send query for root DNSKEY, ignore cache
resolve('.', kres.type.DNSKEY, kres.class.IN, 'NO_CACHE')
-- Query for AAAA record
resolve('example.com', kres.type.AAAA, kres.class.IN, 0,
function (pkt, req)
-- Check answer RCODE
if pkt:rcode() == kres.rcode.NOERROR then
-- Print matching records
local records = pkt:section(kres.section.ANSWER)
for i = 1, #records do
local rr = records[i]
if rr.type == kres.type.AAAA then
print ('record:', kres.rr2str(rr))
end
end
else
print ('rcode: ', pkt:rcode())
end
end)
.. function:: tojson(object)
:return: JSON text representation of `object`.
Example:
.. code-block:: lua
> testtable = { key1 = "value1", key2 = { subkey1 = 1, subkey2 = 2 } }
> tojson(testtable)
{"key1":"value1","key2":{"subkey1":1,"subkey2":2}}
.. _async-events:
Asynchronous events
-------------------
The Lua language used in the configuration file allows you to script actions upon
various events, for example publishing statistics each minute. The following example
uses the built-in function :func:`event.recurrent()`, which calls a user-supplied
anonymous function:
.. code-block:: lua
local ffi = require('ffi')
modules.load('stats')
-- log statistics every second
local stat_id = event.recurrent(1 * second, function(evid)
log_info(ffi.C.LOG_GRP_STATISTICS, table_print(stats.list()))
end)
-- stop printing statistics after first minute
event.after(1 * minute, function(evid)
event.cancel(stat_id)
end)
Note that each scheduled event is identified by a number valid for the duration
of the event; you may use it to cancel the event at any time.
To persist state between two invocations of a function, Lua uses a concept called
closures_. In the following example, the function ``speed_monitor()`` is a closure
which provides a persistent variable called ``previous``.
.. code-block:: lua
local ffi = require('ffi')
modules.load('stats')
-- make a closure, encapsulating counter
function speed_monitor()
local previous = stats.list()
-- monitoring function
return function(evid)
local now = stats.list()
local total_increment = now['answer.total'] - previous['answer.total']
local slow_increment = now['answer.slow'] - previous['answer.slow']
if slow_increment / total_increment > 0.05 then
log_warn(ffi.C.LOG_GRP_STATISTICS, 'WARNING! More than 5 %% of queries were slow!')
end
previous = now -- store current value in closure
end
end
-- monitor every minute
local monitor_id = event.recurrent(1 * minute, speed_monitor())
Another type of actionable event is activity on a file descriptor. This allows
you to embed other event loops or monitor open files and then fire a callback
when activity is detected, which lets you build persistent services
like monitoring probes that cooperate well with the daemon's internal operations.
See :func:`event.socket()`.
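A minimal sketch (the callback arguments are omitted here; see :func:`event.socket()`
for the full signature):
.. code-block:: lua
-- fire a callback whenever there is activity on standard input (fd 0)
local sock_ev = event.socket(0, function ()
print('activity detected on stdin')
end)
-- stop watching when no longer needed
event.cancel(sock_ev)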
Filesystem watchers are possible with :func:`worker.coroutine()` and cqueues_;
see the cqueues documentation for more information. Here is a simple example:
.. code-block:: lua
local notify = require('cqueues.notify')
local watcher = notify.opendir('/etc')
watcher:add('hosts')
-- Watch changes to /etc/hosts
worker.coroutine(function ()
for flags, name in watcher:changes() do
for flag in notify.flags(flags) do
-- print information about the modified file
print(name, notify[flag])
end
end
end)
.. include:: ../../daemon/bindings/event.rst
.. include:: ../../modules/etcd/README.rst
.. _closures: https://www.lua.org/pil/6.1.html
.. _cqueues: https://25thandclement.com/~william/projects/cqueues.html
.. _LuaRocks: https://luarocks.org/
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include "kresconfig.h"
#include <ucw/lib.h>
#include <sys/socket.h>
#if ENABLE_XDP
#include <libknot/xdp/xdp.h>
#endif
#include "lib/log.h"
#include "lib/utils.h"
#include "daemon/io.h"
#include "daemon/udp_queue.h"
#include "daemon/worker.h"
#include "daemon/defer.h"
#include "daemon/proxyv2.h"
#include "daemon/session2.h"
#define VERBOSE_LOG(session, fmt, ...) do {\
if (kr_log_is_debug(PROTOLAYER, NULL)) {\
const char *sess_dir = (session)->outgoing ? "out" : "in";\
kr_log_debug(PROTOLAYER, "[%08X] (%s) " fmt, \
(session)->log_id, sess_dir, __VA_ARGS__);\
}\
} while (0);\
static uint32_t next_log_id = 1;
struct protolayer_globals protolayer_globals[PROTOLAYER_TYPE_COUNT] = {{0}};
static const enum protolayer_type protolayer_grp_udp53[] = {
PROTOLAYER_TYPE_UDP,
PROTOLAYER_TYPE_PROXYV2_DGRAM,
PROTOLAYER_TYPE_DEFER,
PROTOLAYER_TYPE_DNS_DGRAM,
};
static const enum protolayer_type protolayer_grp_tcp53[] = {
PROTOLAYER_TYPE_TCP,
PROTOLAYER_TYPE_PROXYV2_STREAM,
PROTOLAYER_TYPE_DEFER,
PROTOLAYER_TYPE_DNS_MULTI_STREAM,
};
static const enum protolayer_type protolayer_grp_dot[] = {
PROTOLAYER_TYPE_TCP,
PROTOLAYER_TYPE_PROXYV2_STREAM,
PROTOLAYER_TYPE_DEFER,
PROTOLAYER_TYPE_TLS,
PROTOLAYER_TYPE_DNS_MULTI_STREAM,
};
static const enum protolayer_type protolayer_grp_doh[] = {
PROTOLAYER_TYPE_TCP,
PROTOLAYER_TYPE_PROXYV2_STREAM,
PROTOLAYER_TYPE_DEFER,
PROTOLAYER_TYPE_TLS,
PROTOLAYER_TYPE_HTTP,
PROTOLAYER_TYPE_DNS_UNSIZED_STREAM,
};
static const enum protolayer_type protolayer_grp_doq[] = {
// not yet used
PROTOLAYER_TYPE_NULL,
};
struct protolayer_grp {
const enum protolayer_type *layers;
size_t num_layers;
};
#define PROTOLAYER_GRP(p_array) { \
.layers = (p_array), \
.num_layers = sizeof((p_array)) / sizeof((p_array)[0]), \
}
/** Sequences of layers, or groups, mapped by `enum kr_proto`.
*
* Each group represents a sequence of layers in the unwrap direction (wrap
* direction being the opposite). The sequence dictates the order in which
* individual layers are processed. This macro is used to generate global data
* about groups.
*
* To define a new group, add a new entry in the `KR_PROTO_MAP()` macro and
* create a new static `protolayer_grp_*` array above, similarly to the already
* existing ones. Each array must end with `PROTOLAYER_TYPE_NULL`, to
* indicate the end of the list of protocol layers. The array name's suffix must
* be the one defined as *Variable name* (2nd parameter) in the
* `KR_PROTO_MAP` macro. */
static const struct protolayer_grp protolayer_grps[KR_PROTO_COUNT] = {
#define XX(cid, vid, name) [KR_PROTO_##cid] = PROTOLAYER_GRP(protolayer_grp_##vid),
KR_PROTO_MAP(XX)
#undef XX
};
const char *protolayer_layer_name(enum protolayer_type p)
{
switch (p) {
case PROTOLAYER_TYPE_NULL:
return "(null)";
#define XX(cid) case PROTOLAYER_TYPE_ ## cid: \
return #cid;
PROTOLAYER_TYPE_MAP(XX)
#undef XX
default:
return "(invalid)";
}
}
const char *protolayer_event_name(enum protolayer_event_type e)
{
switch (e) {
case PROTOLAYER_EVENT_NULL:
return "(null)";
#define XX(cid) case PROTOLAYER_EVENT_ ## cid: \
return #cid;
PROTOLAYER_EVENT_MAP(XX)
#undef XX
default:
return "(invalid)";
}
}
const char *protolayer_payload_name(enum protolayer_payload_type p)
{
switch (p) {
case PROTOLAYER_PAYLOAD_NULL:
return "(null)";
#define XX(cid, name) case PROTOLAYER_PAYLOAD_ ## cid: \
return (name);
PROTOLAYER_PAYLOAD_MAP(XX)
#undef XX
default:
return "(invalid)";
}
}
/* Forward decls. */
static int session2_transport_pushv(struct session2 *s,
struct iovec *iov, int iovcnt,
bool iov_short_lived,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton);
static inline int session2_transport_push(struct session2 *s,
char *buf, size_t buf_len,
bool buf_short_lived,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton);
static int session2_transport_event(struct session2 *s,
enum protolayer_event_type event,
void *baton);
static size_t iovecs_size(const struct iovec *iov, int cnt)
{
size_t sum = 0;
for (int i = 0; i < cnt; i++) {
sum += iov[i].iov_len;
}
return sum;
}
static size_t iovecs_copy(void *dest, const struct iovec *iov, int cnt,
size_t max_len)
{
const size_t pld_size = iovecs_size(iov, cnt);
const size_t copy_size = MIN(max_len, pld_size);
char *cur = dest;
size_t remaining = copy_size;
for (int i = 0; i < cnt && remaining; i++) {
size_t l = iov[i].iov_len;
size_t to_copy = MIN(l, remaining);
memcpy(cur, iov[i].iov_base, to_copy);
remaining -= l;
cur += l;
}
kr_assert(remaining == 0 && (cur - (char *)dest) == copy_size);
return copy_size;
}
size_t protolayer_payload_size(const struct protolayer_payload *payload)
{
switch (payload->type) {
case PROTOLAYER_PAYLOAD_BUFFER:
return payload->buffer.len;
case PROTOLAYER_PAYLOAD_IOVEC:
return iovecs_size(payload->iovec.iov, payload->iovec.cnt);
case PROTOLAYER_PAYLOAD_WIRE_BUF:
return wire_buf_data_length(payload->wire_buf);
case PROTOLAYER_PAYLOAD_NULL:
return 0;
default:
kr_assert(false && "Invalid payload type");
return 0;
}
}
size_t protolayer_payload_copy(void *dest,
const struct protolayer_payload *payload,
size_t max_len)
{
const size_t pld_size = protolayer_payload_size(payload);
const size_t copy_size = MIN(max_len, pld_size);
if (payload->type == PROTOLAYER_PAYLOAD_BUFFER) {
memcpy(dest, payload->buffer.buf, copy_size);
return copy_size;
} else if (payload->type == PROTOLAYER_PAYLOAD_IOVEC) {
char *cur = dest;
size_t remaining = copy_size;
for (int i = 0; i < payload->iovec.cnt && remaining; i++) {
size_t l = payload->iovec.iov[i].iov_len;
size_t to_copy = MIN(l, remaining);
memcpy(cur, payload->iovec.iov[i].iov_base, to_copy);
remaining -= l;
cur += l;
}
kr_assert(remaining == 0 && (cur - (char *)dest) == copy_size);
return copy_size;
} else if (payload->type == PROTOLAYER_PAYLOAD_WIRE_BUF) {
memcpy(dest, wire_buf_data(payload->wire_buf), copy_size);
return copy_size;
} else if(!payload->type) {
return 0;
} else {
kr_assert(false && "Invalid payload type");
return 0;
}
}
struct protolayer_payload protolayer_payload_as_buffer(
const struct protolayer_payload *payload)
{
if (payload->type == PROTOLAYER_PAYLOAD_BUFFER)
return *payload;
if (payload->type == PROTOLAYER_PAYLOAD_WIRE_BUF) {
struct protolayer_payload new_payload = {
.type = PROTOLAYER_PAYLOAD_BUFFER,
.short_lived = payload->short_lived,
.ttl = payload->ttl,
.buffer = {
.buf = wire_buf_data(payload->wire_buf),
.len = wire_buf_data_length(payload->wire_buf)
}
};
wire_buf_reset(payload->wire_buf);
return new_payload;
}
kr_assert(false && "Unsupported payload type.");
return (struct protolayer_payload){
.type = PROTOLAYER_PAYLOAD_NULL
};
}
size_t protolayer_queue_count_payload(const protolayer_iter_ctx_queue_t *queue)
{
if (!queue || queue_len(*queue) == 0)
return 0;
size_t sum = 0;
/* We're only reading from the queue, but we need to discard the
* `const` so that `queue_it_begin()` accepts it. As long as
* `queue_it_` operations do not write into the queue (which they do
* not, checked at the time of writing), we should be safely in the
* defined behavior territory. */
queue_it_t(struct protolayer_iter_ctx *) it =
queue_it_begin(*(protolayer_iter_ctx_queue_t *)queue);
for (; !queue_it_finished(it); queue_it_next(it)) {
struct protolayer_iter_ctx *ctx = queue_it_val(it);
sum += protolayer_payload_size(&ctx->payload);
}
return sum;
}
bool protolayer_queue_has_payload(const protolayer_iter_ctx_queue_t *queue)
{
if (!queue || queue_len(*queue) == 0)
return false;
/* We're only reading from the queue, but we need to discard the
* `const` so that `queue_it_begin()` accepts it. As long as
* `queue_it_` operations do not write into the queue (which they do
* not, checked at the time of writing), we should be safely in the
* defined behavior territory. */
queue_it_t(struct protolayer_iter_ctx *) it =
queue_it_begin(*(protolayer_iter_ctx_queue_t *)queue);
for (; !queue_it_finished(it); queue_it_next(it)) {
struct protolayer_iter_ctx *ctx = queue_it_val(it);
if (protolayer_payload_size(&ctx->payload))
return true;
}
return false;
}
static inline ssize_t session2_get_protocol(
struct session2 *s, enum protolayer_type protocol)
{
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
for (ssize_t i = 0; i < grp->num_layers; i++) {
enum protolayer_type found = grp->layers[i];
if (protocol == found)
return i;
}
return -1;
}
/** Gets layer-specific session data for the layer with the specified index
* from the manager. */
static inline struct protolayer_data *protolayer_sess_data_get(
struct session2 *s, size_t layer_ix)
{
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
if (kr_fails_assert(layer_ix < grp->num_layers))
return NULL;
/* See doc comment of `struct session2::layer_data` */
const ssize_t *offsets = (ssize_t *)s->layer_data;
char *pl_data_beg = &s->layer_data[2 * grp->num_layers * sizeof(*offsets)];
ssize_t offset = offsets[layer_ix];
if (offset < 0) /* No session data for this layer */
return NULL;
return (struct protolayer_data *)(pl_data_beg + offset);
}
void *protolayer_sess_data_get_current(struct protolayer_iter_ctx *ctx)
{
return protolayer_sess_data_get(ctx->session, ctx->layer_ix);
}
void *protolayer_sess_data_get_proto(struct session2 *s, enum protolayer_type protocol) {
ssize_t layer_ix = session2_get_protocol(s, protocol);
if (layer_ix < 0)
return NULL;
return protolayer_sess_data_get(s, layer_ix);
}
/** Gets layer-specific iteration data for the layer with the specified index
* from the context. */
static inline struct protolayer_data *protolayer_iter_data_get(
struct protolayer_iter_ctx *ctx, size_t layer_ix)
{
struct session2 *s = ctx->session;
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
if (kr_fails_assert(layer_ix < grp->num_layers))
return NULL;
/* See doc comment of `struct session2::layer_data` */
const ssize_t *offsets = (ssize_t *)&s->layer_data[grp->num_layers * sizeof(*offsets)];
ssize_t offset = offsets[layer_ix];
if (offset < 0) /* No iteration data for this layer */
return NULL;
return (struct protolayer_data *)(ctx->data + offset);
}
void *protolayer_iter_data_get_current(struct protolayer_iter_ctx *ctx)
{
return protolayer_iter_data_get(ctx, ctx->layer_ix);
}
size_t protolayer_sess_size_est(struct session2 *s)
{
return s->session_size + s->wire_buf.size;
}
size_t protolayer_iter_size_est(struct protolayer_iter_ctx *ctx, bool incl_payload)
{
size_t size = ctx->session->iter_ctx_size;
if (incl_payload)
size += protolayer_payload_size(&ctx->payload);
return size;
}
static inline bool protolayer_iter_ctx_is_last(struct protolayer_iter_ctx *ctx)
{
unsigned int last_ix = (ctx->direction == PROTOLAYER_UNWRAP)
? protolayer_grps[ctx->session->proto].num_layers - 1
: 0;
return ctx->layer_ix == last_ix;
}
static inline void protolayer_iter_ctx_next(struct protolayer_iter_ctx *ctx)
{
if (ctx->direction == PROTOLAYER_UNWRAP)
ctx->layer_ix++;
else
ctx->layer_ix--;
}
static inline const char *layer_name(enum kr_proto grp, ssize_t layer_ix)
{
if (grp >= KR_PROTO_COUNT)
return "(invalid)";
enum protolayer_type p = protolayer_grps[grp].layers[layer_ix];
return protolayer_layer_name(p);
}
static inline const char *layer_name_ctx(struct protolayer_iter_ctx *ctx)
{
return layer_name(ctx->session->proto, ctx->layer_ix);
}
static int protolayer_iter_ctx_finish(struct protolayer_iter_ctx *ctx, int ret)
{
struct session2 *s = ctx->session;
const struct protolayer_globals *globals = &protolayer_globals[s->proto];
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
for (size_t i = 0; i < grp->num_layers; i++) {
struct protolayer_data *d = protolayer_iter_data_get(ctx, i);
if (globals->iter_deinit)
globals->iter_deinit(ctx, d);
}
if (ret) {
VERBOSE_LOG(s, "layer context of group '%s' (on %u: %s) ended with return code %d\n",
kr_proto_name(s->proto),
ctx->layer_ix, layer_name_ctx(ctx), ret);
}
if (ctx->status) {
VERBOSE_LOG(s, "iteration of group '%s' (on %u: %s) ended with status '%s (%d)'\n",
kr_proto_name(s->proto),
ctx->layer_ix, layer_name_ctx(ctx),
kr_strerror(ctx->status), ctx->status);
}
if (ctx->finished_cb)
ctx->finished_cb(ret, s, ctx->comm, ctx->finished_cb_baton);
mm_ctx_delete(&ctx->pool);
free(ctx);
session2_unhandle(s);
return ret;
}
static void protolayer_push_finished(int status, struct session2 *s, const struct comm_info *comm, void *baton)
{
struct protolayer_iter_ctx *ctx = baton;
ctx->status = status;
protolayer_iter_ctx_finish(ctx, PROTOLAYER_RET_NORMAL);
}
/** Pushes the specified protocol layer's payload to the session's transport. */
static int protolayer_push(struct protolayer_iter_ctx *ctx)
{
struct session2 *session = ctx->session;
if (ctx->payload.type == PROTOLAYER_PAYLOAD_WIRE_BUF) {
ctx->payload = protolayer_payload_as_buffer(&ctx->payload);
}
if (kr_log_is_debug(PROTOLAYER, NULL)) {
VERBOSE_LOG(session, "Pushing %s\n",
protolayer_payload_name(ctx->payload.type));
}
if (ctx->payload.type == PROTOLAYER_PAYLOAD_BUFFER) {
session2_transport_push(session,
ctx->payload.buffer.buf, ctx->payload.buffer.len,
ctx->payload.short_lived,
ctx->comm, protolayer_push_finished, ctx);
} else if (ctx->payload.type == PROTOLAYER_PAYLOAD_IOVEC) {
session2_transport_pushv(session,
ctx->payload.iovec.iov, ctx->payload.iovec.cnt,
ctx->payload.short_lived,
ctx->comm, protolayer_push_finished, ctx);
} else {
kr_assert(false && "Invalid payload type");
return kr_error(EINVAL);
}
return PROTOLAYER_RET_ASYNC;
}
static void protolayer_payload_ensure_long_lived(struct protolayer_iter_ctx *ctx)
{
if (!ctx->payload.short_lived)
return;
size_t buf_len = protolayer_payload_size(&ctx->payload);
if (kr_fails_assert(buf_len))
return;
void *buf = mm_alloc(&ctx->pool, buf_len);
kr_require(buf);
protolayer_payload_copy(buf, &ctx->payload, buf_len);
ctx->payload = protolayer_payload_buffer(buf, buf_len, false);
}
/** Processes as many layers as possible synchronously, returning when either
* a layer has gone asynchronous, or when the whole sequence has finished.
*
* May be called multiple times on the same `ctx` to continue processing after
* an asynchronous operation - user code will do this via *layer sequence return
* functions*. */
static int protolayer_step(struct protolayer_iter_ctx *ctx)
{
while (true) {
if (kr_fails_assert(ctx->session->proto < KR_PROTO_COUNT))
return kr_error(EFAULT);
enum protolayer_type protocol = protolayer_grps[ctx->session->proto].layers[ctx->layer_ix];
struct protolayer_globals *globals = &protolayer_globals[protocol];
bool was_async = ctx->async_mode;
ctx->async_mode = false;
/* Basically if we went asynchronous, we want to "resume" from
* underneath this `if` block. */
if (!was_async) {
ctx->status = 0;
ctx->action = PROTOLAYER_ITER_ACTION_NULL;
protolayer_iter_cb cb = (ctx->direction == PROTOLAYER_UNWRAP)
? globals->unwrap : globals->wrap;
if (ctx->session->closing) {
return protolayer_iter_ctx_finish(
ctx, kr_error(ECANCELED));
}
if (cb) {
struct protolayer_data *sess_data = protolayer_sess_data_get(
ctx->session, ctx->layer_ix);
struct protolayer_data *iter_data = protolayer_iter_data_get(
ctx, ctx->layer_ix);
enum protolayer_iter_cb_result result = cb(sess_data, iter_data, ctx);
if (kr_fails_assert(result == PROTOLAYER_ITER_CB_RESULT_MAGIC)) {
/* Callback did not use a *layer
* sequence return function* (see
* glossary). */
return protolayer_iter_ctx_finish(ctx, kr_error(EINVAL));
}
} else {
ctx->action = PROTOLAYER_ITER_ACTION_CONTINUE;
}
if (!ctx->action) {
/* We're going asynchronous - the next step is
* probably going to be from some sort of a
* callback and we will "resume" from underneath
* this `if` block. */
ctx->async_mode = true;
protolayer_payload_ensure_long_lived(ctx);
return PROTOLAYER_RET_ASYNC;
}
}
if (kr_fails_assert(ctx->action)) {
return protolayer_iter_ctx_finish(ctx, kr_error(EINVAL));
}
if (ctx->action == PROTOLAYER_ITER_ACTION_BREAK) {
return protolayer_iter_ctx_finish(
ctx, PROTOLAYER_RET_NORMAL);
}
if (kr_fails_assert(ctx->status == 0)) {
/* Status should be zero without a BREAK. */
return protolayer_iter_ctx_finish(ctx, kr_error(EINVAL));
}
if (ctx->action == PROTOLAYER_ITER_ACTION_CONTINUE) {
if (protolayer_iter_ctx_is_last(ctx)) {
if (ctx->direction == PROTOLAYER_WRAP)
return protolayer_push(ctx);
return protolayer_iter_ctx_finish(
ctx, PROTOLAYER_RET_NORMAL);
}
protolayer_iter_ctx_next(ctx);
continue;
}
/* Should never get here */
kr_assert(false && "Invalid layer callback action");
return protolayer_iter_ctx_finish(ctx, kr_error(EINVAL));
}
}
/** Submits the specified buffer to the sequence of layers represented by the
* specified protolayer manager. The sequence will be processed in the
* specified `direction`, starting by the layer specified by `layer_ix`.
*
* Returns PROTOLAYER_RET_NORMAL when all layers have finished,
* PROTOLAYER_RET_ASYNC when some layers are asynchronous and waiting for
* continuation, or a negative number for errors (kr_error). */
static int session2_submit(
struct session2 *session,
enum protolayer_direction direction, size_t layer_ix,
struct protolayer_payload payload, const struct comm_info *comm,
protolayer_finished_cb cb, void *baton)
{
if (session->closing)
return kr_error(ECANCELED);
if (session->ref_count >= INT_MAX - 1)
return kr_error(ETOOMANYREFS);
if (kr_fails_assert(session->proto < KR_PROTO_COUNT))
return kr_error(EFAULT);
bool had_comm_param = (comm != NULL);
if (!had_comm_param)
comm = &session->comm_storage;
// DEFER: at this point we might start doing nontrivial work,
// but we may not know the client's IP yet.
// Note two cases: incoming session (new request)
// vs. outgoing session (resuming work on some request)
if ((direction == PROTOLAYER_UNWRAP) && (layer_ix == 0))
defer_sample_start(NULL);
struct protolayer_iter_ctx *ctx = malloc(session->iter_ctx_size);
kr_require(ctx);
VERBOSE_LOG(session,
"%s submitted to grp '%s' in %s direction (%zu: %s)\n",
protolayer_payload_name(payload.type),
kr_proto_name(session->proto),
(direction == PROTOLAYER_UNWRAP) ? "unwrap" : "wrap",
layer_ix, layer_name(session->proto, layer_ix));
*ctx = (struct protolayer_iter_ctx) {
.payload = payload,
.direction = direction,
.layer_ix = layer_ix,
.session = session,
.finished_cb = cb,
.finished_cb_baton = baton
};
session->ref_count++;
if (had_comm_param) {
struct comm_addr_storage *addrst = &ctx->comm_addr_storage;
if (comm->src_addr) {
int len = kr_sockaddr_len(comm->src_addr);
kr_require(len > 0 && len <= sizeof(union kr_sockaddr));
memcpy(&addrst->src_addr, comm->src_addr, len);
ctx->comm_storage.src_addr = &addrst->src_addr.ip;
}
if (comm->comm_addr) {
int len = kr_sockaddr_len(comm->comm_addr);
kr_require(len > 0 && len <= sizeof(union kr_sockaddr));
memcpy(&addrst->comm_addr, comm->comm_addr, len);
ctx->comm_storage.comm_addr = &addrst->comm_addr.ip;
}
if (comm->dst_addr) {
int len = kr_sockaddr_len(comm->dst_addr);
kr_require(len > 0 && len <= sizeof(union kr_sockaddr));
memcpy(&addrst->dst_addr, comm->dst_addr, len);
ctx->comm_storage.dst_addr = &addrst->dst_addr.ip;
}
ctx->comm = &ctx->comm_storage;
} else {
ctx->comm = &session->comm_storage;
}
mm_ctx_mempool(&ctx->pool, CPU_PAGE_SIZE);
const struct protolayer_grp *grp = &protolayer_grps[session->proto];
for (size_t i = 0; i < grp->num_layers; i++) {
struct protolayer_globals *globals = &protolayer_globals[grp->layers[i]];
struct protolayer_data *iter_data = protolayer_iter_data_get(ctx, i);
if (iter_data) {
memset(iter_data, 0, globals->iter_size);
iter_data->session = session;
}
if (globals->iter_init)
globals->iter_init(ctx, iter_data);
}
int ret = protolayer_step(ctx);
if ((direction == PROTOLAYER_UNWRAP) && (layer_ix == 0))
defer_sample_stop(NULL, false);
return ret;
}
static void *get_init_param(enum protolayer_type p,
struct protolayer_data_param *layer_param,
size_t layer_param_count)
{
if (!layer_param || !layer_param_count)
return NULL;
for (size_t i = 0; i < layer_param_count; i++) {
if (layer_param[i].protocol == p)
return layer_param[i].param;
}
return NULL;
}
/** Called by *layer sequence return functions* to proceed with protolayer
 * processing. If the current iteration is in asynchronous mode, this resumes
 * stepping through the layers; in synchronous mode it is a no-op, as the loop
 * in `protolayer_step()` advances on its own. */
static inline void maybe_async_do_step(struct protolayer_iter_ctx *ctx)
{
if (ctx->async_mode)
protolayer_step(ctx);
}
enum protolayer_iter_cb_result protolayer_continue(struct protolayer_iter_ctx *ctx)
{
ctx->action = PROTOLAYER_ITER_ACTION_CONTINUE;
maybe_async_do_step(ctx);
return PROTOLAYER_ITER_CB_RESULT_MAGIC;
}
enum protolayer_iter_cb_result protolayer_break(struct protolayer_iter_ctx *ctx, int status)
{
ctx->status = status;
ctx->action = PROTOLAYER_ITER_ACTION_BREAK;
maybe_async_do_step(ctx);
return PROTOLAYER_ITER_CB_RESULT_MAGIC;
}
int wire_buf_init(struct wire_buf *wb, size_t initial_size)
{
char *buf = malloc(initial_size);
kr_require(buf);
*wb = (struct wire_buf){
.buf = buf,
.size = initial_size
};
return kr_ok();
}
void wire_buf_deinit(struct wire_buf *wb)
{
free(wb->buf);
}
int wire_buf_reserve(struct wire_buf *wb, size_t size)
{
if (wb->buf && wb->size >= size)
return kr_ok();
char *newbuf = realloc(wb->buf, size);
kr_require(newbuf);
wb->buf = newbuf;
wb->size = size;
return kr_ok();
}
int wire_buf_consume(struct wire_buf *wb, size_t length)
{
size_t ne = wb->end + length;
if (kr_fails_assert(wb->buf && ne <= wb->size))
return kr_error(EINVAL);
wb->end = ne;
return kr_ok();
}
int wire_buf_trim(struct wire_buf *wb, size_t length)
{
size_t ns = wb->start + length;
if (kr_fails_assert(ns <= wb->end))
return kr_error(EINVAL);
wb->start = ns;
return kr_ok();
}
int wire_buf_movestart(struct wire_buf *wb)
{
if (kr_fails_assert(wb->buf))
return kr_error(EINVAL);
if (wb->start == 0)
return kr_ok();
size_t len = wire_buf_data_length(wb);
if (len) {
if (wb->start < len)
memmove(wb->buf, wire_buf_data(wb), len);
else
memcpy(wb->buf, wire_buf_data(wb), len);
}
wb->start = 0;
wb->end = len;
return kr_ok();
}
int wire_buf_reset(struct wire_buf *wb)
{
wb->start = 0;
wb->end = 0;
return kr_ok();
}
struct session2 *session2_new(enum session2_transport_type transport_type,
enum kr_proto proto,
struct protolayer_data_param *layer_param,
size_t layer_param_count,
bool outgoing)
{
kr_require(transport_type && proto);
size_t session_size = sizeof(struct session2);
size_t iter_ctx_size = sizeof(struct protolayer_iter_ctx);
const struct protolayer_grp *grp = &protolayer_grps[proto];
if (kr_fails_assert(grp->num_layers))
return NULL;
size_t wire_buf_length = 0;
ssize_t offsets[2 * grp->num_layers];
session_size += sizeof(offsets);
ssize_t *sess_offsets = offsets;
ssize_t *iter_offsets = &offsets[grp->num_layers];
/* Space for layer-specific data, guaranteeing alignment */
size_t total_sess_data_size = 0;
size_t total_iter_data_size = 0;
for (size_t i = 0; i < grp->num_layers; i++) {
const struct protolayer_globals *g = &protolayer_globals[grp->layers[i]];
sess_offsets[i] = g->sess_size ? total_sess_data_size : -1;
total_sess_data_size += ALIGN_TO(g->sess_size, CPU_STRUCT_ALIGN);
iter_offsets[i] = g->iter_size ? total_iter_data_size : -1;
total_iter_data_size += ALIGN_TO(g->iter_size, CPU_STRUCT_ALIGN);
size_t wire_buf_overhead = (g->wire_buf_overhead_cb)
? g->wire_buf_overhead_cb(outgoing)
: g->wire_buf_overhead;
wire_buf_length += wire_buf_overhead;
}
session_size += total_sess_data_size;
iter_ctx_size += total_iter_data_size;
struct session2 *s = malloc(session_size);
kr_require(s);
*s = (struct session2) {
.transport = {
.type = transport_type,
},
.log_id = next_log_id++,
.outgoing = outgoing,
.tasks = trie_create(NULL),
.proto = proto,
.iter_ctx_size = iter_ctx_size,
.session_size = session_size,
};
memcpy(&s->layer_data, offsets, sizeof(offsets));
queue_init(s->waiting);
int ret = wire_buf_init(&s->wire_buf, wire_buf_length);
kr_require(!ret);
ret = uv_timer_init(uv_default_loop(), &s->timer);
kr_require(!ret);
s->timer.data = s;
s->ref_count++; /* Session owns the timer */
/* Initialize the layer's session data */
for (size_t i = 0; i < grp->num_layers; i++) {
struct protolayer_globals *globals = &protolayer_globals[grp->layers[i]];
struct protolayer_data *sess_data = protolayer_sess_data_get(s, i);
if (sess_data) {
memset(sess_data, 0, globals->sess_size);
sess_data->session = s;
}
void *param = get_init_param(grp->layers[i], layer_param, layer_param_count);
if (globals->sess_init)
globals->sess_init(s, sess_data, param);
}
session2_touch(s);
return s;
}
/** De-allocates the session. Must only be called once the underlying IO handle
* and timer are already closed, otherwise may leak resources. */
static void session2_free(struct session2 *s)
{
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
for (size_t i = 0; i < grp->num_layers; i++) {
struct protolayer_globals *globals = &protolayer_globals[grp->layers[i]];
if (globals->sess_deinit) {
struct protolayer_data *sess_data = protolayer_sess_data_get(s, i);
globals->sess_deinit(s, sess_data);
}
}
wire_buf_deinit(&s->wire_buf);
trie_free(s->tasks);
queue_deinit(s->waiting);
free(s);
}
void session2_unhandle(struct session2 *s)
{
if (kr_fails_assert(s->ref_count > 0)) {
session2_free(s);
return;
}
s->ref_count--;
if (s->ref_count <= 0)
session2_free(s);
}
int session2_start_read(struct session2 *session)
{
if (session->transport.type == SESSION2_TRANSPORT_IO)
return io_start_read(session->transport.io.handle);
/* TODO - probably just some event for this */
kr_assert(false && "Parent start_read unsupported");
return kr_error(EINVAL);
}
int session2_stop_read(struct session2 *session)
{
if (session->transport.type == SESSION2_TRANSPORT_IO)
return io_stop_read(session->transport.io.handle);
/* TODO - probably just some event for this */
kr_assert(false && "Parent stop_read unsupported");
return kr_error(EINVAL);
}
struct sockaddr *session2_get_peer(struct session2 *s)
{
while (s && s->transport.type == SESSION2_TRANSPORT_PARENT)
s = s->transport.parent;
return (s && s->transport.type == SESSION2_TRANSPORT_IO)
? &s->transport.io.peer.ip
: NULL;
}
struct sockaddr *session2_get_sockname(struct session2 *s)
{
while (s && s->transport.type == SESSION2_TRANSPORT_PARENT)
s = s->transport.parent;
return (s && s->transport.type == SESSION2_TRANSPORT_IO)
? &s->transport.io.sockname.ip
: NULL;
}
uv_handle_t *session2_get_handle(struct session2 *s)
{
while (s && s->transport.type == SESSION2_TRANSPORT_PARENT)
s = s->transport.parent;
return (s && s->transport.type == SESSION2_TRANSPORT_IO)
? s->transport.io.handle
: NULL;
}
static void session2_on_timeout(uv_timer_t *timer)
{
struct session2 *s = timer->data;
session2_event(s, s->timer_event, NULL);
}
int session2_timer_start(struct session2 *s, enum protolayer_event_type event, uint64_t timeout, uint64_t repeat)
{
s->timer_event = event;
return uv_timer_start(&s->timer, session2_on_timeout, timeout, repeat);
}
int session2_timer_restart(struct session2 *s)
{
return uv_timer_again(&s->timer);
}
int session2_timer_stop(struct session2 *s)
{
return uv_timer_stop(&s->timer);
}
int session2_tasklist_add(struct session2 *session, struct qr_task *task)
{
trie_t *t = session->tasks;
uint16_t task_msg_id = 0;
const char *key = NULL;
size_t key_len = 0;
if (session->outgoing) {
knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
task_msg_id = knot_wire_get_id(pktbuf->wire);
key = (const char *)&task_msg_id;
key_len = sizeof(task_msg_id);
} else {
key = (const char *)&task;
key_len = sizeof(char *);
}
trie_val_t *v = trie_get_ins(t, key, key_len);
if (kr_fails_assert(v))
return kr_error(ENOMEM);
if (*v == NULL) {
*v = task;
worker_task_ref(task);
} else if (kr_fails_assert(*v == task)) {
return kr_error(EINVAL);
}
return kr_ok();
}
int session2_tasklist_del(struct session2 *session, struct qr_task *task)
{
trie_t *t = session->tasks;
uint16_t task_msg_id = 0;
const char *key = NULL;
size_t key_len = 0;
trie_val_t val;
if (session->outgoing) {
knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
task_msg_id = knot_wire_get_id(pktbuf->wire);
key = (const char *)&task_msg_id;
key_len = sizeof(task_msg_id);
} else {
key = (const char *)&task;
key_len = sizeof(char *);
}
int ret = trie_del(t, key, key_len, &val);
if (ret == KNOT_EOK) {
kr_require(val == task);
worker_task_unref(val);
}
return ret;
}
struct qr_task *session2_tasklist_get_first(struct session2 *session)
{
trie_val_t *val = trie_get_first(session->tasks, NULL, NULL);
return val ? (struct qr_task *) *val : NULL;
}
struct qr_task *session2_tasklist_del_first(struct session2 *session, bool deref)
{
trie_val_t val = NULL;
int res = trie_del_first(session->tasks, NULL, NULL, &val);
if (res != KNOT_EOK) {
val = NULL;
} else if (deref) {
worker_task_unref(val);
}
return (struct qr_task *)val;
}
struct qr_task *session2_tasklist_find_msgid(const struct session2 *session, uint16_t msg_id)
{
if (kr_fails_assert(session->outgoing))
return NULL;
trie_t *t = session->tasks;
struct qr_task *ret = NULL;
trie_val_t *val = trie_get_try(t, (char *)&msg_id, sizeof(msg_id));
if (val) {
ret = *val;
}
return ret;
}
struct qr_task *session2_tasklist_del_msgid(const struct session2 *session, uint16_t msg_id)
{
if (kr_fails_assert(session->outgoing))
return NULL;
trie_t *t = session->tasks;
struct qr_task *ret = NULL;
const char *key = (const char *)&msg_id;
size_t key_len = sizeof(msg_id);
trie_val_t val;
int res = trie_del(t, key, key_len, &val);
if (res == KNOT_EOK) {
if (worker_task_numrefs(val) > 1) {
ret = val;
}
worker_task_unref(val);
}
return ret;
}
void session2_tasklist_finalize(struct session2 *session, int status)
{
if (session2_tasklist_get_len(session) > 0) {
defer_sample_state_t defer_prev_sample_state;
defer_sample_start(&defer_prev_sample_state);
do {
struct qr_task *t = session2_tasklist_del_first(session, false);
kr_require(worker_task_numrefs(t) > 0);
worker_task_finalize(t, status);
worker_task_unref(t);
defer_sample_restart();
} while (session2_tasklist_get_len(session) > 0);
defer_sample_stop(&defer_prev_sample_state, true);
}
}
int session2_tasklist_finalize_expired(struct session2 *session)
{
int ret = 0;
queue_t(struct qr_task *) q;
uint64_t now = kr_now();
trie_t *t = session->tasks;
trie_it_t *it;
queue_init(q);
for (it = trie_it_begin(t); !trie_it_finished(it); trie_it_next(it)) {
trie_val_t *v = trie_it_val(it);
struct qr_task *task = (struct qr_task *)*v;
if ((now - worker_task_creation_time(task)) >= KR_RESOLVE_TIME_LIMIT) {
struct kr_request *req = worker_task_request(task);
if (!kr_fails_assert(req))
kr_query_inform_timeout(req, req->current_query);
queue_push(q, task);
worker_task_ref(task);
}
}
trie_it_free(it);
struct qr_task *task = NULL;
uint16_t msg_id = 0;
char *key = (char *)&task;
int32_t keylen = sizeof(struct qr_task *);
if (session->outgoing) {
key = (char *)&msg_id;
keylen = sizeof(msg_id);
}
if (queue_len(q) > 0) {
defer_sample_state_t defer_prev_sample_state;
defer_sample_start(&defer_prev_sample_state);
do {
task = queue_head(q);
if (session->outgoing) {
knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
msg_id = knot_wire_get_id(pktbuf->wire);
}
int res = trie_del(t, key, keylen, NULL);
if (!worker_task_finished(task)) {
				/* task->pending_count must be zero,
				 * but there can still be followers,
				 * so run worker_task_subreq_finalize() to ensure
				 * retrying for all the followers. */
worker_task_subreq_finalize(task);
worker_task_finalize(task, KR_STATE_FAIL);
}
if (res == KNOT_EOK) {
worker_task_unref(task);
}
queue_pop(q);
worker_task_unref(task);
++ret;
defer_sample_restart();
} while (queue_len(q) > 0);
defer_sample_stop(&defer_prev_sample_state, true);
}
queue_deinit(q);
return ret;
}
int session2_waitinglist_push(struct session2 *session, struct qr_task *task)
{
queue_push(session->waiting, task);
worker_task_ref(task);
return kr_ok();
}
struct qr_task *session2_waitinglist_get(const struct session2 *session)
{
return (queue_len(session->waiting) > 0) ? (queue_head(session->waiting)) : NULL;
}
struct qr_task *session2_waitinglist_pop(struct session2 *session, bool deref)
{
struct qr_task *t = session2_waitinglist_get(session);
queue_pop(session->waiting);
if (deref) {
worker_task_unref(t);
}
return t;
}
void session2_waitinglist_retry(struct session2 *session, bool increase_timeout_cnt)
{
if (!session2_waitinglist_is_empty(session)) {
defer_sample_state_t defer_prev_sample_state;
defer_sample_start(&defer_prev_sample_state);
do {
struct qr_task *task = session2_waitinglist_pop(session, false);
if (increase_timeout_cnt) {
worker_task_timeout_inc(task);
}
worker_task_step(task, session2_get_peer(session), NULL);
worker_task_unref(task);
defer_sample_restart();
} while (!session2_waitinglist_is_empty(session));
defer_sample_stop(&defer_prev_sample_state, true);
}
}
void session2_waitinglist_finalize(struct session2 *session, int status)
{
if (!session2_waitinglist_is_empty(session)) {
defer_sample_state_t defer_prev_sample_state;
defer_sample_start(&defer_prev_sample_state);
do {
struct qr_task *t = session2_waitinglist_pop(session, false);
worker_task_finalize(t, status);
worker_task_unref(t);
defer_sample_restart();
} while (!session2_waitinglist_is_empty(session));
defer_sample_stop(&defer_prev_sample_state, true);
}
}
void session2_penalize(struct session2 *session)
{
if (session->was_useful || !session->outgoing)
return;
/* We want to penalize the IP address, if a task is asking a query.
* It might not be the right task, but that doesn't matter so much
* for attributing the useless session to the IP address. */
struct qr_task *t = session2_tasklist_get_first(session);
struct kr_query *qry = NULL;
if (t) {
struct kr_request *req = worker_task_request(t);
qry = array_tail(req->rplan.pending);
}
if (qry) /* We reuse the error for connection, as it's quite similar. */
qry->server_selection.error(qry, worker_task_get_transport(t),
KR_SELECTION_TCP_CONNECT_FAILED);
}
int session2_unwrap(struct session2 *s, struct protolayer_payload payload,
const struct comm_info *comm, protolayer_finished_cb cb,
void *baton)
{
return session2_submit(s, PROTOLAYER_UNWRAP,
0, payload, comm, cb, baton);
}
int session2_unwrap_after(struct session2 *s, enum protolayer_type protocol,
struct protolayer_payload payload,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton)
{
ssize_t layer_ix = session2_get_protocol(s, protocol);
bool ok = layer_ix >= 0 && layer_ix + 1 < protolayer_grps[s->proto].num_layers;
if (kr_fails_assert(ok)) // not found or "last layer"
return kr_error(EINVAL);
return session2_submit(s, PROTOLAYER_UNWRAP,
layer_ix + 1, payload, comm, cb, baton);
}
int session2_wrap(struct session2 *s, struct protolayer_payload payload,
const struct comm_info *comm, protolayer_finished_cb cb,
void *baton)
{
return session2_submit(s, PROTOLAYER_WRAP,
protolayer_grps[s->proto].num_layers - 1,
payload, comm, cb, baton);
}
int session2_wrap_after(struct session2 *s, enum protolayer_type protocol,
struct protolayer_payload payload,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton)
{
ssize_t layer_ix = session2_get_protocol(s, protocol);
if (kr_fails_assert(layer_ix > 0)) // not found or "last layer"
return kr_error(EINVAL);
return session2_submit(s, PROTOLAYER_WRAP, layer_ix - 1,
payload, comm, cb, baton);
}
static void session2_event_wrap(struct session2 *s, enum protolayer_event_type event, void *baton)
{
bool cont;
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
for (ssize_t i = grp->num_layers - 1; i >= 0; i--) {
struct protolayer_globals *globals = &protolayer_globals[grp->layers[i]];
if (globals->event_wrap) {
struct protolayer_data *sess_data = protolayer_sess_data_get(s, i);
cont = globals->event_wrap(event, &baton, s, sess_data);
} else {
cont = true;
}
if (!cont)
return;
}
session2_transport_event(s, event, baton);
}
static void session2_event_unwrap(struct session2 *s, ssize_t start_ix, enum protolayer_event_type event, void *baton)
{
bool cont;
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
for (ssize_t i = start_ix; i < grp->num_layers; i++) {
struct protolayer_globals *globals = &protolayer_globals[grp->layers[i]];
if (globals->event_unwrap) {
struct protolayer_data *sess_data = protolayer_sess_data_get(s, i);
cont = globals->event_unwrap(event, &baton, s, sess_data);
} else {
cont = true;
}
if (!cont)
return;
}
/* Immediately bounce back in the `wrap` direction.
*
* TODO: This might be undesirable for cases with sub-sessions - the
* current idea is for the layers managing sub-sessions to just return
* `PROTOLAYER_EVENT_CONSUME` on `event_unwrap`, but a more "automatic"
* mechanism may be added when this is relevant, to make it less
* error-prone. */
session2_event_wrap(s, event, baton);
}
void session2_event(struct session2 *s, enum protolayer_event_type event, void *baton)
{
/* Events may be sent from inside or outside of already measured code.
* From inside: close by us, statistics, ...
* From outside: timeout, EOF, close by external reasons, ... */
bool defer_accounting_here = false;
if (!defer_sample_is_accounting() && s->stream && !s->outgoing) {
defer_sample_start(NULL);
defer_accounting_here = true;
}
session2_event_unwrap(s, 0, event, baton);
if (defer_accounting_here)
defer_sample_stop(NULL, false);
}
void session2_event_after(struct session2 *s, enum protolayer_type protocol,
enum protolayer_event_type event, void *baton)
{
ssize_t start_ix = session2_get_protocol(s, protocol);
if (kr_fails_assert(start_ix >= 0))
return;
session2_event_unwrap(s, start_ix + 1, event, baton);
}
void session2_init_request(struct session2 *s, struct kr_request *req)
{
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
for (ssize_t i = 0; i < grp->num_layers; i++) {
struct protolayer_globals *globals = &protolayer_globals[grp->layers[i]];
if (globals->request_init) {
struct protolayer_data *sess_data = protolayer_sess_data_get(s, i);
globals->request_init(s, req, sess_data);
}
}
}
struct session2_pushv_ctx {
struct session2 *session;
protolayer_finished_cb cb;
const struct comm_info *comm;
void *baton;
char *async_buf;
};
static void session2_transport_parent_pushv_finished(int status,
struct session2 *session,
const struct comm_info *comm,
void *baton)
{
struct session2_pushv_ctx *ctx = baton;
if (ctx->cb)
ctx->cb(status, ctx->session, comm, ctx->baton);
free(ctx->async_buf);
free(ctx);
}
static void session2_transport_pushv_finished(int status, struct session2_pushv_ctx *ctx)
{
if (ctx->cb)
ctx->cb(status, ctx->session, ctx->comm, ctx->baton);
free(ctx->async_buf);
free(ctx);
}
static void session2_transport_udp_queue_pushv_finished(int status, void *baton)
{
session2_transport_pushv_finished(status, baton);
}
static void session2_transport_udp_pushv_finished(uv_udp_send_t *req, int status)
{
session2_transport_pushv_finished(status, req->data);
free(req);
}
static void session2_transport_stream_pushv_finished(uv_write_t *req, int status)
{
session2_transport_pushv_finished(status, req->data);
free(req);
}
#if ENABLE_XDP
static void xdp_tx_waker(uv_idle_t *handle)
{
xdp_handle_data_t *xhd = handle->data;
int ret = knot_xdp_send_finish(xhd->socket);
if (ret != KNOT_EAGAIN && ret != KNOT_EOK)
kr_log_error(XDP, "check: ret = %d, %s\n", ret, knot_strerror(ret));
/* Apparently some drivers need many explicit wake-up calls
* even if we push no additional packets (in case they accumulated a lot) */
if (ret != KNOT_EAGAIN)
uv_idle_stop(handle);
knot_xdp_send_prepare(xhd->socket);
/* LATER(opt.): it _might_ be better for performance to do these two steps
* at different points in time */
while (queue_len(xhd->tx_waker_queue)) {
struct session2_pushv_ctx *ctx = queue_head(xhd->tx_waker_queue);
if (ctx->cb)
ctx->cb(kr_ok(), ctx->session, ctx->comm, ctx->baton);
free(ctx);
queue_pop(xhd->tx_waker_queue);
}
}
#endif
static void session2_transport_pushv_ensure_long_lived(
struct iovec **iov, int *iovcnt, bool iov_short_lived,
struct iovec *out_iovecmem, struct session2_pushv_ctx *ctx)
{
if (!iov_short_lived)
return;
size_t iovsize = iovecs_size(*iov, *iovcnt);
if (kr_fails_assert(iovsize))
return;
void *buf = malloc(iovsize);
kr_require(buf);
iovecs_copy(buf, *iov, *iovcnt, iovsize);
ctx->async_buf = buf;
out_iovecmem->iov_base = buf;
out_iovecmem->iov_len = iovsize;
*iov = out_iovecmem;
*iovcnt = 1;
}
/// Count the total size of an iovec[] in bytes.
static inline size_t iovec_sum(const struct iovec iov[], const int iovcnt)
{
size_t result = 0;
for (int i = 0; i < iovcnt; ++i)
result += iov[i].iov_len;
return result;
}
static int session2_transport_pushv(struct session2 *s,
struct iovec *iov, int iovcnt,
bool iov_short_lived,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton)
{
struct iovec iovecmem;
if (kr_fails_assert(s))
return kr_error(EINVAL);
struct session2_pushv_ctx *ctx = malloc(sizeof(*ctx));
kr_require(ctx);
*ctx = (struct session2_pushv_ctx){
.session = s,
.cb = cb,
.baton = baton,
.comm = comm
};
int err_ret = kr_ok();
switch (s->transport.type) {
case SESSION2_TRANSPORT_IO:;
uv_handle_t *handle = s->transport.io.handle;
if (kr_fails_assert(handle)) {
err_ret = kr_error(EINVAL);
goto exit_err;
}
if (handle->type == UV_UDP) {
if (ENABLE_SENDMMSG && !s->outgoing) {
int fd;
int ret = uv_fileno(handle, &fd);
if (kr_fails_assert(!ret)) {
err_ret = kr_error(EIO);
goto exit_err;
}
/* TODO: support multiple iovecs properly? */
if (kr_fails_assert(iovcnt == 1)) {
err_ret = kr_error(EINVAL);
goto exit_err;
}
session2_transport_pushv_ensure_long_lived(
&iov, &iovcnt, iov_short_lived,
&iovecmem, ctx);
udp_queue_push(fd, comm->comm_addr, iov->iov_base, iov->iov_len,
session2_transport_udp_queue_pushv_finished,
ctx);
return kr_ok();
} else {
int ret = uv_udp_try_send((uv_udp_t*)handle, (uv_buf_t *)iov, iovcnt,
the_network->enable_connect_udp ? NULL : comm->comm_addr);
if (ret > 0) // equals buffer size, only confuses us
ret = 0;
if (ret == UV_EAGAIN) {
ret = kr_error(ENOBUFS);
session2_event(s, PROTOLAYER_EVENT_OS_BUFFER_FULL, NULL);
}
if (false && ret == UV_EAGAIN) { // XXX: see uv_try_write() below
uv_udp_send_t *req = malloc(sizeof(*req));
req->data = ctx;
session2_transport_pushv_ensure_long_lived(
&iov, &iovcnt, iov_short_lived,
&iovecmem, ctx);
ret = uv_udp_send(req, (uv_udp_t *)handle,
(uv_buf_t *)iov, iovcnt, comm->comm_addr,
session2_transport_udp_pushv_finished);
if (ret)
session2_transport_udp_pushv_finished(req, ret);
return ret;
}
session2_transport_pushv_finished(ret, ctx);
return ret;
}
} else if (handle->type == UV_TCP) {
int ret = uv_try_write((uv_stream_t *)handle, (uv_buf_t *)iov, iovcnt);
// XXX: queueing disabled for now if the OS can't accept the data.
// Typically that happens when OS buffers are full.
// We were missing any handling of partial write success, too.
if (ret == UV_EAGAIN || (ret >= 0 && ret != iovec_sum(iov, iovcnt))) {
ret = kr_error(ENOBUFS);
session2_event(s, PROTOLAYER_EVENT_OS_BUFFER_FULL, NULL);
}
else if (ret > 0) // iovec_sum was checked, let's not get confused anymore
ret = 0;
if (false && ret == UV_EAGAIN) {
uv_write_t *req = malloc(sizeof(*req));
req->data = ctx;
session2_transport_pushv_ensure_long_lived(
&iov, &iovcnt, iov_short_lived,
&iovecmem, ctx);
ret = uv_write(req, (uv_stream_t *)handle, (uv_buf_t *)iov, iovcnt,
session2_transport_stream_pushv_finished);
if (ret)
session2_transport_stream_pushv_finished(req, ret);
return ret;
}
session2_transport_pushv_finished(ret, ctx);
return ret;
#if ENABLE_XDP
} else if (handle->type == UV_POLL) {
xdp_handle_data_t *xhd = handle->data;
if (kr_fails_assert(xhd && xhd->socket)) {
err_ret = kr_error(EIO);
goto exit_err;
}
/* TODO: support multiple iovecs properly? */
if (kr_fails_assert(iovcnt == 1)) {
err_ret = kr_error(EINVAL);
goto exit_err;
}
session2_transport_pushv_ensure_long_lived(
&iov, &iovcnt, iov_short_lived,
&iovecmem, ctx);
knot_xdp_msg_t msg;
/* We don't have a nice way of preserving the _msg_t from frame allocation,
* so we manually redo all other parts of knot_xdp_send_alloc() */
memset(&msg, 0, sizeof(msg));
bool ipv6 = comm->comm_addr->sa_family == AF_INET6;
msg.flags = ipv6 ? KNOT_XDP_MSG_IPV6 : 0;
memcpy(msg.eth_from, comm->eth_from, sizeof(comm->eth_from));
memcpy(msg.eth_to, comm->eth_to, sizeof(comm->eth_to));
const struct sockaddr *ip_from = comm->dst_addr;
const struct sockaddr *ip_to = comm->comm_addr;
memcpy(&msg.ip_from, ip_from, kr_sockaddr_len(ip_from));
memcpy(&msg.ip_to, ip_to, kr_sockaddr_len(ip_to));
msg.payload = *iov;
uint32_t sent;
int ret = knot_xdp_send(xhd->socket, &msg, 1, &sent);
queue_push(xhd->tx_waker_queue, ctx);
uv_idle_start(&xhd->tx_waker, xdp_tx_waker);
kr_log_debug(XDP, "pushed a packet, ret = %d\n", ret);
return kr_ok();
#endif
} else {
kr_assert(false && "Unsupported handle");
err_ret = kr_error(EINVAL);
goto exit_err;
}
case SESSION2_TRANSPORT_PARENT:;
struct session2 *parent = s->transport.parent;
if (kr_fails_assert(parent)) {
err_ret = kr_error(EINVAL);
goto exit_err;
}
int ret = session2_wrap(parent,
protolayer_payload_iovec(iov, iovcnt, iov_short_lived),
comm, session2_transport_parent_pushv_finished,
ctx);
return (ret < 0) ? ret : kr_ok();
default:
kr_assert(false && "Invalid transport");
err_ret = kr_error(EINVAL);
goto exit_err;
}
exit_err:
session2_transport_pushv_finished(err_ret, ctx);
return err_ret;
}
struct push_ctx {
struct iovec iov;
protolayer_finished_cb cb;
void *baton;
};
static void session2_transport_single_push_finished(int status,
struct session2 *s,
const struct comm_info *comm,
void *baton)
{
struct push_ctx *ctx = baton;
if (ctx->cb)
ctx->cb(status, s, comm, ctx->baton);
free(ctx);
}
static inline int session2_transport_push(struct session2 *s,
char *buf, size_t buf_len,
bool buf_short_lived,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton)
{
struct push_ctx *ctx = malloc(sizeof(*ctx));
kr_require(ctx);
*ctx = (struct push_ctx){
.iov = {
.iov_base = buf,
.iov_len = buf_len
},
.cb = cb,
.baton = baton
};
return session2_transport_pushv(s, &ctx->iov, 1, buf_short_lived, comm,
session2_transport_single_push_finished, ctx);
}
static void on_session2_handle_close(uv_handle_t *handle)
{
struct session2 *session = handle->data;
kr_require(session->transport.type == SESSION2_TRANSPORT_IO &&
session->transport.io.handle == handle);
io_free(handle);
}
static void on_session2_timer_close(uv_handle_t *handle)
{
session2_unhandle(handle->data);
}
static int session2_handle_close(struct session2 *s)
{
if (kr_fails_assert(s->transport.type == SESSION2_TRANSPORT_IO))
return kr_error(EINVAL);
uv_handle_t *handle = s->transport.io.handle;
if (!handle->loop) {
/* This happens when kresd is stopping and the libUV loop has
* been ended. We do not `uv_close` the handles, we just free
* up the memory. */
session2_unhandle(s); /* For timer handle */
io_free(handle); /* This will unhandle the transport handle */
return kr_ok();
}
io_stop_read(handle);
uv_close((uv_handle_t *)&s->timer, on_session2_timer_close);
uv_close(handle, on_session2_handle_close);
return kr_ok();
}
static int session2_transport_event(struct session2 *s,
enum protolayer_event_type event,
void *baton)
{
if (s->closing)
return kr_ok();
if (event == PROTOLAYER_EVENT_EOF) {
// no layer wanted to retain TCP half-closed state
session2_force_close(s);
return kr_ok();
}
bool is_close_event = (event == PROTOLAYER_EVENT_CLOSE ||
event == PROTOLAYER_EVENT_FORCE_CLOSE);
if (is_close_event) {
kr_require(session2_is_empty(s));
session2_timer_stop(s);
s->closing = true;
}
switch (s->transport.type) {
case SESSION2_TRANSPORT_IO:;
if (kr_fails_assert(s->transport.io.handle)) {
return kr_error(EINVAL);
}
if (is_close_event)
return session2_handle_close(s);
return kr_ok();
case SESSION2_TRANSPORT_PARENT:;
session2_event_wrap(s, event, baton);
return kr_ok();
default:
kr_assert(false && "Invalid transport");
return kr_error(EINVAL);
}
}
void session2_kill_ioreq(struct session2 *session, struct qr_task *task)
{
if (!session || session->closing)
return;
if (kr_fails_assert(session->outgoing
&& session->transport.type == SESSION2_TRANSPORT_IO
&& session->transport.io.handle))
return;
session2_tasklist_del(session, task);
if (session->transport.io.handle->type == UV_UDP)
session2_close(session);
}
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
/* High-level explanation of layered protocols: ./layered-protocols.rst */
/* HINT: If you are looking to implement support for a new transport protocol,
* start with the doc comment of the `PROTOLAYER_TYPE_MAP` macro and
* continue from there. */
/* GLOSSARY:
*
* Event:
* - An Event may be processed by the protocol layer sequence much like a
* Payload, but with a special callback. Events may be used to notify layers
* that e.g. a connection has been established; a timeout has occurred; a
* malformed packet has been received, etc. Events are generally not sent
* through the transport - they may, however, trigger a new payload to be
* sent, e.g. a HTTP error status response.
*
* Iteration:
* - The processing of Payload data or an event using a particular sequence
* of Protocol layers, either in Wrap or Unwrap direction. For payload
* processing, it is also the lifetime of `struct protolayer_iter_ctx` and
* layer-specific data contained therein.
*
* Layer sequence return function:
* - One of `protolayer_break()`, `protolayer_continue()`, or
* `protolayer_async()` - a function that a protolayer's `_wrap` or `_unwrap`
* callback should call to get its return value. They may either be called
* synchronously directly in the callback to end/pause the processing, or, if
* the processing went asynchronous, called to resume the iteration of layers.
*
* Payload:
* - Data processed by protocol layers in a particular sequence. In the wrap
* direction, this data generally starts as a DNS packet, which is then
* wrapped in protocol ceremony data by each layer. In the unwrap direction,
* the opposite takes place - ceremony data is removed until a raw DNS packet
* is retrieved.
*
* Protocol layer:
* - Not to be confused with `struct kr_layer_api`. An implementation of a
* particular protocol. A protocol layer transforms payloads to conform to a
* particular protocol, e.g. UDP, TCP, TLS, HTTP, QUIC, etc. While
* transforming a payload, a layer may also modify metadata - e.g. the UDP and
* TCP layers in the Unwrap direction implement the PROXYv2 protocol, using
* which they retrieve the IP address of the actual originating client and
* store it in the appropriate struct.
*
* Protolayer:
* - Short for 'protocol layer'.
*
* Unwrap:
* - The direction of data transformation, which starts with the transport
* (e.g. bytes that came from the network) and ends with an internal subsystem
* (e.g. DNS query resolution).
*
* Wrap:
* - The direction of data transformation, which starts with an internal
* subsystem (e.g. an answer to a resolved DNS query) and ends with the
* transport (e.g. bytes that are going to be sent to the client). */
#pragma once
#include <stdalign.h>
#include <stdint.h>
#include <stdlib.h>
#include <uv.h>
#include "contrib/mempattern.h"
#include "lib/generic/queue.h"
#include "lib/generic/trie.h"
#include "lib/proto.h"
#include "lib/utils.h"
/* Forward declarations */
struct session2;
struct protolayer_iter_ctx;
/** Type of MAC addresses. */
typedef uint8_t ethaddr_t[6];
/** Information about the transport - addresses and proxy. */
struct comm_info {
/** The original address the data came from. May be that of a proxied
* client, if they came through a proxy. May be `NULL` if
* the communication did not come from network. */
const struct sockaddr *src_addr;
/** The actual address the resolver is communicating with. May be
* the address of a proxy if the communication came through one,
* otherwise it will be the same as `src_addr`. May be `NULL` if
* the communication did not come from network. */
const struct sockaddr *comm_addr;
/** The original destination address. May be the resolver's address, or
* the address of a proxy if the communication came through one. May be
* `NULL` if the communication did not come from network. */
const struct sockaddr *dst_addr;
/** Data parsed from a PROXY header. May be `NULL` if the communication
* did not come through a proxy, or if the PROXYv2 protocol was not
* used. */
const struct proxy_result *proxy;
/** Pointer to protolayer-specific data, e.g. a key to decide, which
* sub-session to use. */
void *target;
/* XDP data */
ethaddr_t eth_from;
ethaddr_t eth_to;
bool xdp:1;
};
/** A simple struct holding up to three IPv4/IPv6 addresses, providing stable
 * storage for the addresses pointed to by `struct comm_info`. */
struct comm_addr_storage {
union kr_sockaddr src_addr;
union kr_sockaddr comm_addr;
union kr_sockaddr dst_addr;
};
/** A buffer control struct, with indices marking a chunk containing received
* but as of yet unprocessed data - the data in this chunk is called "valid
* data". The struct may be manipulated using `wire_buf_` functions, which
* contain bounds checks to ensure correct behaviour.
*
* The struct may be used to retrieve data piecewise (e.g. from a stream-based
* transport like TCP) by writing data to the buffer's free space, then
* "consuming" that space with `wire_buf_consume`. It can also be handy for
* processing message headers, then trimming the beginning of the buffer (using
* `wire_buf_trim`) so that the next part of the data may be processed by
* another part of a pipeline.
*
* May be initialized in two possible ways:
* - via `wire_buf_init`
* - to zero, then reserved via `wire_buf_reserve`. */
struct wire_buf {
char *buf; /**< Buffer memory. */
size_t size; /**< Current size of the buffer memory. */
size_t start; /**< Index at which the valid data of the buffer starts (inclusive). */
size_t end; /**< Index at which the valid data of the buffer ends (exclusive). */
};
/** Initializes the wire buffer with the specified `initial_size` and allocates
* the underlying memory. */
int wire_buf_init(struct wire_buf *wb, size_t initial_size);
/** De-allocates the wire buffer's underlying memory (the struct itself is left
* intact). */
void wire_buf_deinit(struct wire_buf *wb);
/** Ensures that the wire buffer's size is at least `size`. The memory at `wb`
* must be initialized, either to zero or via `wire_buf_init`. */
int wire_buf_reserve(struct wire_buf *wb, size_t size);
/** Adds `length` to the end index of the valid data, marking `length` more
* bytes as valid.
*
* Returns 0 on success.
* Assert-fails and/or returns `kr_error(EINVAL)` if the end index would exceed
* the buffer size. */
int wire_buf_consume(struct wire_buf *wb, size_t length);
/** Adds `length` to the start index of the valid data, marking `length` less
* bytes as valid.
*
* Returns 0 on success.
* Assert-fails and/or returns `kr_error(EINVAL)` if the start index would
* exceed the end index. */
int wire_buf_trim(struct wire_buf *wb, size_t length);
/** Moves the valid bytes of the buffer to the buffer's beginning. */
int wire_buf_movestart(struct wire_buf *wb);
/** Marks the wire buffer as empty. */
int wire_buf_reset(struct wire_buf *wb);
/** Gets a pointer to the data marked as valid in the wire buffer. */
static inline void *wire_buf_data(const struct wire_buf *wb)
{
return &wb->buf[wb->start];
}
/** Gets the length of the data marked as valid in the wire buffer. */
static inline size_t wire_buf_data_length(const struct wire_buf *wb)
{
return wb->end - wb->start;
}
/** Gets a pointer to the free space after the valid data of the wire buffer. */
static inline void *wire_buf_free_space(const struct wire_buf *wb)
{
return (wb->buf) ? &wb->buf[wb->end] : NULL;
}
/** Gets the length of the free space after the valid data of the wire buffer. */
static inline size_t wire_buf_free_space_length(const struct wire_buf *wb)
{
if (kr_fails_assert(wb->end <= wb->size))
return 0;
return (wb->buf) ? wb->size - wb->end : 0;
}
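/* A minimal usage sketch of the wire buffer (illustrative only; the 4096-byte
 * size, `received_bytes`, `msg_len` and `process_message()` are hypothetical):
 *
 *	struct wire_buf wb;
 *	wire_buf_init(&wb, 4096);
 *	// read at most wire_buf_free_space_length(&wb) bytes into
 *	// wire_buf_free_space(&wb), then mark them as valid:
 *	wire_buf_consume(&wb, received_bytes);
 *	// once a complete message of `msg_len` bytes is at the front,
 *	// process it and drop it from the buffer:
 *	process_message(wire_buf_data(&wb), msg_len);
 *	wire_buf_trim(&wb, msg_len);
 *	// compact the remaining bytes to make room for the next read:
 *	wire_buf_movestart(&wb);
 *	wire_buf_deinit(&wb);
 */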
/** Protocol layer types map - an enumeration of individual protocol layer
* implementations
*
* This macro is used to generate `enum protolayer_type` as well as other
* additional data on protocols, e.g. name string constants.
*
* To define a new protocol, add a new identifier to this macro, and, within
* some logical compilation unit (e.g. `daemon/worker.c` for DNS layers),
* initialize the protocol's `protolayer_globals[]`, ideally in a function
* called at the start of the program (e.g. `worker_init()`). See the docs of
* `struct protolayer_globals` for details on what data this structure should
* contain.
*
* To use protocols within sessions, protocol layer groups also need to be
* defined, to indicate the order in which individual protocols are to be
* processed. See `KR_PROTO_MAP` below for more details. */
#define PROTOLAYER_TYPE_MAP(XX) \
/* General transport protocols */\
XX(UDP)\
XX(TCP)\
XX(TLS)\
XX(HTTP)\
\
/* PROXYv2 */\
XX(PROXYV2_DGRAM)\
XX(PROXYV2_STREAM)\
\
/* DNS (`worker`) */\
XX(DNS_DGRAM) /**< Packets WITHOUT prepended size, one per (un)wrap,
* limited to UDP sizes, multiple sources (single
* session for multiple clients). */\
XX(DNS_UNSIZED_STREAM) /**< Singular packet WITHOUT prepended size, one
* per (un)wrap, no UDP limits, single source. */\
XX(DNS_MULTI_STREAM) /**< Multiple packets WITH prepended sizes in a
* stream (may span multiple (un)wraps). */\
XX(DNS_SINGLE_STREAM) /**< Singular packet WITH prepended size in a
* stream (may span multiple (un)wraps). */\
/* Prioritization of requests */\
XX(DEFER)
/** The identifiers of protocol layer types. */
enum protolayer_type {
PROTOLAYER_TYPE_NULL = 0,
#define XX(cid) PROTOLAYER_TYPE_ ## cid,
PROTOLAYER_TYPE_MAP(XX)
#undef XX
PROTOLAYER_TYPE_COUNT /* must be the last! */
};
/** Gets the constant string name of the specified protocol. */
const char *protolayer_layer_name(enum protolayer_type p);
/** Flow control indicators for protocol layer `wrap` and `unwrap` callbacks.
 * Set from layer callbacks via the `protolayer_continue()` and
 * `protolayer_break()` layer sequence return functions. */
enum protolayer_iter_action {
PROTOLAYER_ITER_ACTION_NULL = 0,
PROTOLAYER_ITER_ACTION_CONTINUE,
PROTOLAYER_ITER_ACTION_BREAK,
};
/** Direction of layer sequence processing. */
enum protolayer_direction {
/** Processes buffers in order of layers as defined in the layer group.
* In this direction, protocol ceremony data should be removed from the
* buffer, parsing additional data provided by the protocol. */
PROTOLAYER_UNWRAP,
/** Processes buffers in reverse order of layers as defined in the
* layer group. In this direction, protocol ceremony data should be
* added. */
PROTOLAYER_WRAP,
};
/** Returned by a successful call to `session2_wrap()` or `session2_unwrap()`
* functions. */
enum protolayer_ret {
/** Returned when a protolayer context iteration has finished
* processing, i.e. with `protolayer_break()`. */
PROTOLAYER_RET_NORMAL = 0,
/** Returned when a protolayer context iteration is waiting for an
* asynchronous callback to a continuation function. This will never be
* passed to `protolayer_finished_cb`, only returned by
* `session2_unwrap` or `session2_wrap`. */
PROTOLAYER_RET_ASYNC,
};
/** Called when a payload iteration (started by `session2_unwrap` or
* `session2_wrap`) has ended - i.e. the input buffer will not be processed any
* further.
*
* `status` may be one of `enum protolayer_ret` or a negative number indicating
* an error.
 * `session` and `comm` are the session and the communication info the
 * iteration was run with.
* `baton` is the `baton` parameter passed to the `session2_(un)wrap` function. */
typedef void (*protolayer_finished_cb)(int status, struct session2 *session,
const struct comm_info *comm, void *baton);
/** Protocol layer event type map
*
* This macro is used to generate `enum protolayer_event_type` as well as the
* relevant name string constants for each event type.
*
* Event types are used to distinguish different events that can be passed to
* sessions using `session2_event()`. */
#define PROTOLAYER_EVENT_MAP(XX) \
/** Closes the session gracefully - i.e. layers add their standard
* disconnection ceremony (e.g. `gnutls_bye()`). */\
XX(CLOSE) \
/** Closes the session forcefully - i.e. layers SHOULD NOT add any
* disconnection ceremony, if avoidable. */\
XX(FORCE_CLOSE) \
/** Signal that a connection could not be established due to a timeout. */\
XX(CONNECT_TIMEOUT) \
/** Signal that a general application-defined timeout has occurred. */\
XX(GENERAL_TIMEOUT) \
/** Signal that a connection has been established. */\
XX(CONNECT) \
/** Signal that a connection could not have been established. */\
XX(CONNECT_FAIL) \
/** Signal that a malformed request has been received. */\
XX(MALFORMED) \
/** Signal that a connection has ended. */\
XX(DISCONNECT) \
/** Signal EOF from peer (e.g. half-closed TCP connection). */\
XX(EOF) \
/** Failed task send - update stats. */\
XX(STATS_SEND_ERR) \
/** Outgoing query submission - update stats. */\
XX(STATS_QRY_OUT) \
/** OS buffers are full, so not sending any more data. */\
XX(OS_BUFFER_FULL) \
//
/** Event type, to be interpreted by a layer. */
enum protolayer_event_type {
PROTOLAYER_EVENT_NULL = 0,
#define XX(cid) PROTOLAYER_EVENT_ ## cid,
PROTOLAYER_EVENT_MAP(XX)
#undef XX
PROTOLAYER_EVENT_COUNT
};
/** Gets the constant string name of the specified event. */
const char *protolayer_event_name(enum protolayer_event_type e);
/** Payload types.
*
* Parameters are:
* 1. Constant name
* 2. Human-readable name for logging */
#define PROTOLAYER_PAYLOAD_MAP(XX) \
XX(BUFFER, "Buffer") \
XX(IOVEC, "IOVec") \
XX(WIRE_BUF, "Wire buffer")
/** Determines which union member of `struct protolayer_payload` is currently
* valid. */
enum protolayer_payload_type {
PROTOLAYER_PAYLOAD_NULL = 0,
#define XX(cid, name) PROTOLAYER_PAYLOAD_##cid,
PROTOLAYER_PAYLOAD_MAP(XX)
#undef XX
PROTOLAYER_PAYLOAD_COUNT
};
/** Gets the constant string name of the specified payload type. */
const char *protolayer_payload_name(enum protolayer_payload_type p);
/** Data processed by the sequence of layers. All pointed-to memory is always
* owned by its creator. It is also the layer (group) implementor's
* responsibility to keep data compatible in between layers. No payload memory
* is ever (de-)allocated by the protolayer manager! */
struct protolayer_payload {
enum protolayer_payload_type type;
/** Time-to-live hint (e.g. for HTTP Cache-Control) */
unsigned int ttl;
/** If `true`, signifies that the memory this payload points to may
* become invalid when we return from one of the functions in the
* current stack. That is fine as long as all the protocol layer
* processing for this payload takes place in a single `session2_wrap()`
* or `session2_unwrap()` call, but may become a problem, when a layer
* goes asynchronous (via `protolayer_async()`).
*
* Setting this to `true` will ensure that the payload will get copied
* into a separate memory buffer if and only if a layer goes
* asynchronous. It makes sure that if all processing for the payload is
* synchronous, no copies or reallocations for the payload are done. */
bool short_lived;
union {
/** Only valid if `type` is `_BUFFER`. */
struct {
void *buf;
size_t len;
} buffer;
/** Only valid if `type` is `_IOVEC`. */
struct {
struct iovec *iov;
int cnt;
} iovec;
/** Only valid if `type` is `_WIRE_BUF`. */
struct wire_buf *wire_buf;
};
};
/** Context for protocol layer iterations, containing payload data,
* layer-specific data, and internal information for the protocol layer
* manager. */
struct protolayer_iter_ctx {
/* read-write for layers: */
/** The payload */
struct protolayer_payload payload;
/** Pointer to communication information. For TCP, this will generally
* point to the storage in the session. For UDP, this will generally
* point to the storage in this context. */
struct comm_info *comm;
/** Communication information storage. This will generally be set by one
* of the first layers in the sequence, if used, e.g. UDP PROXYv2. */
struct comm_info comm_storage;
struct comm_addr_storage comm_addr_storage;
/** Per-iter memory pool. Has no `free` procedure, gets freed as a whole
* when the context is being destroyed. Initialized and destroyed
* automatically - layers may use it to allocate memory. */
knot_mm_t pool;
/* callback for when the layer iteration has ended - read-only for layers: */
protolayer_finished_cb finished_cb;
void *finished_cb_baton;
/* internal information for the manager - should only be used by the protolayer
* system, never by layers: */
enum protolayer_direction direction;
/** If `true`, the processing of the layer sequence has been paused and
* is waiting to be resumed (`protolayer_continue()`) or cancelled
* (`protolayer_break()`). */
bool async_mode;
/** The index of the layer that is currently being (or has just been)
* processed. */
unsigned int layer_ix;
struct session2 *session;
/** Status passed to the finish callback. */
int status;
enum protolayer_iter_action action;
/** Contains a sequence of variably-sized CPU-aligned layer-specific
* structs. See `struct session2::layer_data` for details. */
alignas(CPU_STRUCT_ALIGN) char data[];
};
/** Gets the total size of the data in the specified payload. */
size_t protolayer_payload_size(const struct protolayer_payload *payload);
/** Copies the specified payload to `dest`. Only `max_len` or the size of the
* payload is written, whichever is less.
*
* Returns the actual length of copied data. */
size_t protolayer_payload_copy(void *dest,
const struct protolayer_payload *payload,
size_t max_len);
/** Convenience function to get a buffer-type payload. */
static inline struct protolayer_payload protolayer_payload_buffer(
void *buf, size_t len, bool short_lived)
{
return (struct protolayer_payload){
.type = PROTOLAYER_PAYLOAD_BUFFER,
.short_lived = short_lived,
.buffer = {
.buf = buf,
.len = len
}
};
}
/** Convenience function to get an iovec-type payload. */
static inline struct protolayer_payload protolayer_payload_iovec(
struct iovec *iov, int iovcnt, bool short_lived)
{
return (struct protolayer_payload){
.type = PROTOLAYER_PAYLOAD_IOVEC,
.short_lived = short_lived,
.iovec = {
.iov = iov,
.cnt = iovcnt
}
};
}
/** Convenience function to get a wire-buf-type payload. */
static inline struct protolayer_payload protolayer_payload_wire_buf(
struct wire_buf *wire_buf, bool short_lived)
{
return (struct protolayer_payload){
.type = PROTOLAYER_PAYLOAD_WIRE_BUF,
.short_lived = short_lived,
.wire_buf = wire_buf
};
}
/** Convenience function to represent the specified payload as a buffer-type.
 * Supports only `_BUFFER` and `_WIRE_BUF` on the input; for other types it
 * returns a `_NULL`-type payload (and may trigger an assertion failure when
 * assertions are enabled).
*
* If the input payload is `_WIRE_BUF`, the pointed-to wire buffer is reset to
* indicate that all of its contents have been used up, and the buffer is ready
* to be reused. */
struct protolayer_payload protolayer_payload_as_buffer(
const struct protolayer_payload *payload);
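/* A minimal sketch of payload normalization inside a layer callback
 * (illustrative only): a layer that needs contiguous memory may convert its
 * input like this. Note that converting a `_WIRE_BUF` payload resets the
 * underlying wire buffer, so the resulting buffer must be consumed within
 * this iteration.
 *
 *	struct protolayer_payload p =
 *		protolayer_payload_as_buffer(&ctx->payload);
 *	if (p.type != PROTOLAYER_PAYLOAD_BUFFER)
 *		return protolayer_break(ctx, kr_error(EINVAL));
 *	// ... parse p.buffer.buf, which is p.buffer.len bytes long ...
 */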
/** A predefined queue type for iteration context. */
typedef queue_t(struct protolayer_iter_ctx *) protolayer_iter_ctx_queue_t;
/** Iterates through the specified `queue` and gets the sum of all payloads
* available in it. */
size_t protolayer_queue_count_payload(const protolayer_iter_ctx_queue_t *queue);
/** Checks if the specified `queue` has any payload data (i.e.
* `protolayer_queue_count_payload` would be non-zero). This optimizes calls to
* queue iterators, as it does not need to iterate through the whole queue. */
bool protolayer_queue_has_payload(const protolayer_iter_ctx_queue_t *queue);
/** Gets layer-specific session data for the specified protocol layer.
* Returns NULL if the layer is not present in the session. */
void *protolayer_sess_data_get_proto(struct session2 *s, enum protolayer_type protocol);
/** Gets layer-specific session data for the last processed layer.
 * Intended for asynchronous continuation: use it after the layer's callback
 * has returned, but before `protolayer_continue()` is called. */
void *protolayer_sess_data_get_current(struct protolayer_iter_ctx *ctx);
/** Gets layer-specific iteration data for the last processed layer.
 * Intended for asynchronous continuation: use it after the layer's callback
 * has returned, but before `protolayer_continue()` is called. */
void *protolayer_iter_data_get_current(struct protolayer_iter_ctx *ctx);
/** Gets rough memory footprint estimate of session/iteration for use in defer.
* Different, hopefully minor, allocations are not counted here;
* tasks and subsessions are also not counted;
* read the code before using elsewhere. */
size_t protolayer_sess_size_est(struct session2 *s);
size_t protolayer_iter_size_est(struct protolayer_iter_ctx *ctx, bool incl_payload);
/** Layer-specific data - the generic struct. To be added as the first member of
* each specific struct. */
struct protolayer_data {
	struct session2 *session; /**< Pointer to the owner session. */
};
/** Return value of `protolayer_iter_cb` callbacks. To be returned by *layer
* sequence return functions* (see glossary) as a sanity check. Not to be used
* directly by user code. */
enum protolayer_iter_cb_result {
PROTOLAYER_ITER_CB_RESULT_MAGIC = 0x364F392E,
};
/** Function type for `struct protolayer_globals::wrap` and `struct
* protolayer_globals::unwrap`. The function processes the provided
* `ctx->payload` and decides the next action for the currently processed
* sequence.
*
* The function (or another function, that the pointed-to function causes to be
* called, directly or through an asynchronous operation), must call one of the
* *layer sequence return functions* (see glossary) to advance (or end) the
* layer sequence. The function must return the result of such a return
* function. */
typedef enum protolayer_iter_cb_result (*protolayer_iter_cb)(
void *sess_data,
void *iter_data,
struct protolayer_iter_ctx *ctx);
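/* A minimal sketch of a synchronous `unwrap` callback (illustrative only;
 * `hypothetical_unwrap` and the parsing step are placeholders - a real layer
 * casts `sess_data`/`iter_data` to its own structs, which begin with
 * `struct protolayer_data`):
 *
 *	static enum protolayer_iter_cb_result hypothetical_unwrap(
 *			void *sess_data, void *iter_data,
 *			struct protolayer_iter_ctx *ctx)
 *	{
 *		int err = 0;
 *		// strip this layer's ceremony from ctx->payload here,
 *		// setting `err` on malformed input ...
 *		if (err)
 *			return protolayer_break(ctx, err);
 *		return protolayer_continue(ctx);
 *	}
 */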
/** Return value of `protolayer_event_cb` callbacks. Controls the flow of
* events. See `protolayer_event_cb` for details. */
enum protolayer_event_cb_result {
PROTOLAYER_EVENT_CONSUME = 0,
PROTOLAYER_EVENT_PROPAGATE = 1
};
/** Function type for `struct protolayer_globals::event_wrap` and `struct
* protolayer_globals::event_unwrap` callbacks of layers. The `baton` parameter
* points to the mutable, iteration-specific baton pointer, initialized by the
* `baton` parameter of one of the `session2_event` functions. The pointed-to
* value of `baton` may be modified to accommodate for the next layer in the
* sequence.
*
* When `PROTOLAYER_EVENT_PROPAGATE` is returned, iteration over the sequence
* of layers continues. When `PROTOLAYER_EVENT_CONSUME` is returned, iteration
* stops.
*
* **IMPORTANT:** A well-behaved layer will **ALWAYS** propagate events it knows
* nothing about. Only ever consume events you actually have good reason to
* consume (like TLS consumes `CONNECT` from TCP, because it needs to perform
* its own handshake first). */
typedef enum protolayer_event_cb_result (*protolayer_event_cb)(
enum protolayer_event_type event, void **baton,
struct session2 *session, void *sess_data);
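/* A minimal sketch of an `event_unwrap` callback (illustrative only): the
 * layer consumes the one event it needs to act upon and propagates everything
 * else, as required by the note above.
 *
 *	static enum protolayer_event_cb_result hypothetical_event_unwrap(
 *			enum protolayer_event_type event, void **baton,
 *			struct session2 *session, void *sess_data)
 *	{
 *		if (event == PROTOLAYER_EVENT_CONNECT) {
 *			// start this layer's own handshake here and re-emit
 *			// CONNECT only once it has finished
 *			return PROTOLAYER_EVENT_CONSUME;
 *		}
 *		return PROTOLAYER_EVENT_PROPAGATE;
 *	}
 */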
/** Function type for initialization callbacks of layer session data.
*
* The `param` value is the one associated with the currently initialized
* layer, from the `layer_param` array of `session2_new()` - may be NULL if
* none is provided for the current layer.
*
* `data` points to the layer-specific data struct.
*
* Returning 0 means success, other return values mean error and halt the
* initialization. */
typedef int (*protolayer_data_sess_init_cb)(struct session2 *session,
void *data, void *param);
/** Function type for determining the size of a layer's wire buffer overhead. */
typedef size_t (*protolayer_wire_buf_overhead_cb)(bool outgoing);
/** Function type for (de)initialization callback of layer iteration data.
*
* `ctx` points to the iteration context that `data` belongs to.
*
* `data` points to the layer-specific data struct.
*
* Returning 0 means success, other return values mean error and halt the
* initialization. */
typedef int (*protolayer_iter_data_cb)(struct protolayer_iter_ctx *ctx,
void *data);
/** Function type for (de)initialization callbacks of layers.
*
* `data` points to the layer-specific data struct.
*
* Returning 0 means success, other return values mean error and halt the
* initialization. */
typedef int (*protolayer_data_cb)(struct session2 *session, void *data);
/** Function type for (de)initialization callbacks of DNS requests.
*
* `req` points to the request for initialization.
* `sess_data` points to layer-specific session data struct. */
typedef void (*protolayer_request_cb)(struct session2 *session,
struct kr_request *req,
void *sess_data);
/** Initialization parameters for protocol layer session data. */
struct protolayer_data_param {
enum protolayer_type protocol; /**< Which protocol these parameters
* are meant for. */
void *param; /**< Pointer to protolayer-related initialization
* parameters. Only needs to be valid during session
* initialization. */
};
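/* A minimal sketch of passing layer-specific parameters to a new session
 * (illustrative only; `tls_sess_param` and `proto` are hypothetical - `proto`
 * stands for one of the `enum kr_proto` values whose layer group contains
 * `PROTOLAYER_TYPE_TLS`):
 *
 *	struct protolayer_data_param params[] = {
 *		{ .protocol = PROTOLAYER_TYPE_TLS, .param = &tls_sess_param },
 *	};
 *	struct session2 *s = session2_new(SESSION2_TRANSPORT_IO, proto,
 *			params, 1, false);  // not outgoing (client-facing)
 */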
/** Global data for a specific layered protocol. This is to be initialized in
 * the `protolayer_globals` global array (below) during the resolver's
* startup. It contains pointers to functions implementing a particular
* protocol, as well as other important data.
*
* Every member of this struct is allowed to be zero/NULL if a particular
* protocol has no use for it. */
struct protolayer_globals {
/** Size of the layer-specific data struct, valid per-session.
*
* The struct MUST begin with a `struct protolayer_data` member. If
* no session struct is used by the layer, the value may be zero. */
size_t sess_size;
/** Size of the layer-specific data struct, valid per-iteration. It
* gets created and destroyed together with a `struct
* protolayer_iter_ctx`.
*
* The struct MUST begin with a `struct protolayer_data` member. If
* no iteration struct is used by the layer, the value may be zero. */
size_t iter_size;
/** Number of bytes that this layer adds onto the session's wire buffer
* by default. All overheads in a group are summed together to form the
* resulting default wire buffer length.
*
* Ignored when `wire_buf_overhead_cb` is non-NULL. */
size_t wire_buf_overhead;
/** Called during session initialization to determine the number of
* bytes that this layer adds onto the session's wire buffer.
*
* It is the dynamic version of `wire_buf_overhead`, which is ignored
* when this is non-NULL. */
protolayer_wire_buf_overhead_cb wire_buf_overhead_cb;
/** Number of bytes that this layer adds onto the session's wire buffer
* at most. All overheads in a group are summed together to form the
* resulting default wire buffer length.
*
* If this is less than the default overhead, the default is used
* instead. */
size_t wire_buf_max_overhead;
/** Called during session creation to initialize
* layer-specific session data. The data is always provided
* zero-initialized to this function. */
protolayer_data_sess_init_cb sess_init;
/** Called during session destruction to deinitialize
* layer-specific session data. */
protolayer_data_cb sess_deinit;
/** Called at the beginning of a non-event layer sequence to initialize
* layer-specific iteration data. The data is always zero-initialized
* during iteration context initialization. */
protolayer_iter_data_cb iter_init;
/** Called at the end of a non-event layer sequence to deinitialize
* layer-specific iteration data. */
protolayer_iter_data_cb iter_deinit;
/** Strips the buffer of protocol-specific data. E.g. a HTTP layer
* removes HTTP status and headers. Optional - iteration continues
* automatically if this is NULL. */
protolayer_iter_cb unwrap;
/** Wraps the buffer into protocol-specific data. E.g. a HTTP layer
* adds HTTP status and headers. Optional - iteration continues
* automatically if this is NULL. */
protolayer_iter_cb wrap;
/** Processes events in the unwrap order (sent from the outside).
* Optional - iteration continues automatically if this is NULL. */
protolayer_event_cb event_unwrap;
/** Processes events in the wrap order (bounced back by the session).
* Optional - iteration continues automatically if this is NULL. */
protolayer_event_cb event_wrap;
/** Modifies the provided request for use with the layer. Mostly for
* setting `struct kr_request::qsource.comm_flags`. */
protolayer_request_cb request_init;
};
/** Global data about layered protocols. Mapped by `enum protolayer_type`.
* Individual protocols are to be initialized during resolver startup. */
extern struct protolayer_globals protolayer_globals[PROTOLAYER_TYPE_COUNT];
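/* A minimal sketch of registering a layer implementation during startup
 * (illustrative only; the `hypothetical_*` names are placeholders - real
 * layers fill their entry e.g. from `worker_init()`, as described at
 * `PROTOLAYER_TYPE_MAP`):
 *
 *	protolayer_globals[PROTOLAYER_TYPE_TLS] = (struct protolayer_globals){
 *		.sess_size = sizeof(struct hypothetical_tls_sess_data),
 *		.sess_init = hypothetical_tls_sess_init,
 *		.sess_deinit = hypothetical_tls_sess_deinit,
 *		.unwrap = hypothetical_tls_unwrap,
 *		.wrap = hypothetical_tls_wrap,
 *		.event_unwrap = hypothetical_tls_event_unwrap,
 *	};
 */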
/** *Layer sequence return function* (see glossary) - signals the protolayer
 * manager to continue processing with the next layer. */
enum protolayer_iter_cb_result protolayer_continue(struct protolayer_iter_ctx *ctx);
/** *Layer sequence return function* (see glossary) - signals that the layer
 * wants to stop processing the buffer and clean up, possibly due to an error
* (indicated by a non-zero `status`). */
enum protolayer_iter_cb_result protolayer_break(struct protolayer_iter_ctx *ctx, int status);
/** *Layer sequence return function* (see glossary) - signals that the
 * current sequence will continue in an asynchronous manner. The layer should
 * store the context and call another sequence return function at a later point.
 * This may be used in layers that work through libraries whose operation is
 * asynchronous, like GnuTLS.
 *
 * Note that this one is basically just a readability hint - another return
 * function may actually be called before it (generally during a call to an
 * external library function, e.g. GnuTLS or nghttp2). This is completely legal
 * and the sequence will continue correctly. */
static inline enum protolayer_iter_cb_result protolayer_async(void)
{
return PROTOLAYER_ITER_CB_RESULT_MAGIC;
}
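/* A hedged sketch of an `unwrap` callback using these return functions. The
 * layer name and its helpers are hypothetical, but the control flow mirrors
 * the TLS layer's `pl_tls_unwrap`:
 *
 *	static enum protolayer_iter_cb_result pl_foo_unwrap(
 *			void *sess_data, void *iter_data,
 *			struct protolayer_iter_ctx *ctx)
 *	{
 *		int err = foo_strip_header(&ctx->payload);  // hypothetical helper
 *		if (err == FOO_NEED_MORE_DATA)              // hypothetical constant
 *			return protolayer_async();   // finish later, keep ctx around
 *		if (err < 0)
 *			return protolayer_break(ctx, err);  // stop and clean up
 *		return protolayer_continue(ctx);  // hand over to the next layer
 *	}
 */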
/** Indicates how a session sends data in the `wrap` direction and receives
* data in the `unwrap` direction. */
enum session2_transport_type {
SESSION2_TRANSPORT_NULL = 0,
SESSION2_TRANSPORT_IO,
SESSION2_TRANSPORT_PARENT,
};
/** A data unit for a single sequential data source. The data may be organized
* as a stream or a sequence of datagrams - this is up to the actual individual
* protocols used by the session - see `enum kr_proto` and
* `protolayer_`-prefixed types and functions for more information.
*
* A session processes data in two directions:
*
* - `_UNWRAP` deals with raw data received from the session's transport. It
* strips the ceremony of individual protocols from the buffers, retaining any
* required metadata in an iteration context (`struct protolayer_iter_ctx`).
* The last layer (as defined by a `protolayer_grp_*` array in `session2.c`) in
* a sequence is generally responsible for submitting the unwrapped data to be
* processed by an internal system, e.g. to be resolved as a DNS query.
*
* - `_WRAP` deals with data generated by an internal system. It adds the
* required protocol ceremony to it (e.g. encryption). The first layer (as
* defined by a `protolayer_grp_*` array in `session2.c`) is responsible for
* preparing the data to be sent through the session's transport. */
struct session2 {
/** Data for sending data out in the `wrap` direction and receiving new
* data in the `unwrap` direction. */
struct {
enum session2_transport_type type; /**< See `enum session2_transport_type` */
union {
/** For `_IO` type transport. Contains a libuv handle
* and session-related address storage. */
struct {
uv_handle_t *handle;
union kr_sockaddr peer;
union kr_sockaddr sockname;
} io;
/** For `_PARENT` type transport. */
struct session2 *parent;
};
} transport;
uv_timer_t timer; /**< For session-wide timeout events. */
enum protolayer_event_type timer_event; /**< The event fired on timeout. */
trie_t *tasks; /**< List of tasks associated with given session. */
queue_t(struct qr_task *) waiting; /**< List of tasks waiting for
* sending to upstream. */
struct wire_buf wire_buf;
uint32_t log_id; /**< Session ID for logging. */
int ref_count; /**< Number of unclosed libUV handles owned by this
* session + iteration contexts referencing the session. */
/** Communication information. Typically written into by one of the
* first layers facilitating transport protocol processing.
* Zero-initialized by default. */
struct comm_info comm_storage;
/** Time of last IO activity (if any occurs). Otherwise session
* creation time. */
uint64_t last_activity;
/** If true, the session's transport is towards an upstream server.
* Otherwise, it is towards a client. */
bool outgoing : 1;
/** If true, the session is at the end of its lifecycle and is about
* to close. */
bool closing : 1;
/** If true, the session has done something useful,
* e.g. it has produced a packet. */
bool was_useful : 1;
/** If true, encryption takes place in this session. Layers may use
* this to determine whether padding should be applied. A layer that
* provides security shall set this to `true` during session
* initialization. */
bool secure : 1;
/** If true, the session contains a stream-based protocol layer.
* Set during protocol layer initialization by the stream-based layer. */
bool stream : 1;
/** If true, the session contains a protocol layer with custom handling
* of malformed queries. This is used e.g. by the HTTP layer, which will
* return a Bad Request status on a malformed query. */
bool custom_emalf_handling : 1;
/** If true, session is being rate-limited. One of the protocol layers
* is going to be the writer for this flag. */
bool throttled : 1;
/* Protocol layers */
/** The set of protocol layers used by this session. */
enum kr_proto proto;
/** The size of a single iteration context
* (`struct protolayer_iter_ctx`), including layer-specific data. */
size_t iter_ctx_size;
/** The size of this session struct. */
size_t session_size;
/** The following flexible array has basically this structure:
*
* struct {
* size_t sess_offsets[num_layers];
* size_t iter_offsets[num_layers];
* variably-sized-data sess_data[num_layers];
* }
*
* It is done this way, because different layer groups will have
* different numbers of layers and differently-sized layer-specific
* data. C does not have a convenient way to define this in structs, so
* we do it via this flexible array.
*
* `sess_data` is a sequence of variably-sized CPU-aligned
* layer-specific structs.
*
* `sess_offsets` determines data offsets in `sess_data` for pointer
* retrieval.
*
* `iter_offsets` determines data offsets in `struct
* protolayer_iter_ctx::data` for pointer retrieval. */
alignas(CPU_STRUCT_ALIGN) char layer_data[];
};
/** Allocates and initializes a new session with the specified protocol layer
* group, and the provided transport context.
*
* `layer_param` is a pointer to an array of size `layer_param_count`. The
* parameters are passed to the layer session initializers. The parameter array
* is only required to be valid before this function returns. It is up to the
* individual layer implementations to determine the lifetime of the data
* pointed to by the parameters. */
struct session2 *session2_new(enum session2_transport_type transport_type,
enum kr_proto proto,
struct protolayer_data_param *layer_param,
size_t layer_param_count,
bool outgoing);
/** Allocates and initializes a new session with the specified protocol layer
* group, using a *libuv handle* as its transport. */
static inline struct session2 *session2_new_io(uv_handle_t *handle,
enum kr_proto layer_grp,
struct protolayer_data_param *layer_param,
size_t layer_param_count,
bool outgoing)
{
struct session2 *s = session2_new(SESSION2_TRANSPORT_IO, layer_grp,
layer_param, layer_param_count, outgoing);
s->transport.io.handle = handle;
handle->data = s;
s->ref_count++; /* Session owns the handle */
return s;
}
/** Allocates and initializes a new session with the specified protocol layer
* group, using a *parent session* as its transport. */
static inline struct session2 *session2_new_child(struct session2 *parent,
enum kr_proto layer_grp,
struct protolayer_data_param *layer_param,
size_t layer_param_count,
bool outgoing)
{
struct session2 *s = session2_new(SESSION2_TRANSPORT_PARENT, layer_grp,
layer_param, layer_param_count, outgoing);
s->transport.parent = parent;
return s;
}
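/* For illustration only (hypothetical variable names): wrapping a freshly
 * accepted libuv TCP handle into an incoming DoT session, with no extra
 * layer parameters, might look roughly like this:
 *
 *	struct session2 *s = session2_new_io((uv_handle_t *)client_handle,
 *			KR_PROTO_DOT, NULL, 0, false);
 */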
/** Used when a libUV handle owned by the session is closed. Once all owned
* handles are closed, the session is freed. */
void session2_unhandle(struct session2 *s);
/** Start reading from the underlying transport. */
int session2_start_read(struct session2 *session);
/** Stop reading from the underlying transport. */
int session2_stop_read(struct session2 *session);
/** Gets the peer address from the specified session, iterating through the
* session hierarchy (child-to-parent) until an `_IO` session is found if
* needed.
*
* May return `NULL` if no peer is set. */
struct sockaddr *session2_get_peer(struct session2 *s);
/** Gets the sockname from the specified session, iterating through the
* session hierarchy (child-to-parent) until an `_IO` session is found if
* needed.
*
 * May return `NULL` if no sockname is set. */
struct sockaddr *session2_get_sockname(struct session2 *s);
/** Gets the libuv handle from the specified session, iterating through the
* session hierarchy (child-to-parent) until an `_IO` session is found if
* needed.
*
 * May return `NULL` if no handle is set. */
KR_EXPORT uv_handle_t *session2_get_handle(struct session2 *s);
/** Start the session timer. On timeout, the specified `event` is sent in the
* `_UNWRAP` direction. Only a single timeout can be active at a time. */
int session2_timer_start(struct session2 *s, enum protolayer_event_type event,
uint64_t timeout, uint64_t repeat);
/** Restart the session timer without changing any of its parameters. */
int session2_timer_restart(struct session2 *s);
/** Stop the session timer. */
int session2_timer_stop(struct session2 *s);
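/* Example usage, mirroring the TLS layer's handshake code: (re)arm an
 * inactivity timeout while waiting for more data:
 *
 *	session2_timer_stop(session);
 *	session2_timer_start(session, PROTOLAYER_EVENT_GENERAL_TIMEOUT,
 *			MAX_TCP_INACTIVITY, MAX_TCP_INACTIVITY);
 */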
int session2_tasklist_add(struct session2 *session, struct qr_task *task);
int session2_tasklist_del(struct session2 *session, struct qr_task *task);
struct qr_task *session2_tasklist_get_first(struct session2 *session);
struct qr_task *session2_tasklist_del_first(struct session2 *session, bool deref);
struct qr_task *session2_tasklist_find_msgid(const struct session2 *session, uint16_t msg_id);
struct qr_task *session2_tasklist_del_msgid(const struct session2 *session, uint16_t msg_id);
void session2_tasklist_finalize(struct session2 *session, int status);
int session2_tasklist_finalize_expired(struct session2 *session);
static inline size_t session2_tasklist_get_len(const struct session2 *session)
{
return trie_weight(session->tasks);
}
static inline bool session2_tasklist_is_empty(const struct session2 *session)
{
return session2_tasklist_get_len(session) == 0;
}
int session2_waitinglist_push(struct session2 *session, struct qr_task *task);
struct qr_task *session2_waitinglist_get(const struct session2 *session);
struct qr_task *session2_waitinglist_pop(struct session2 *session, bool deref);
void session2_waitinglist_retry(struct session2 *session, bool increase_timeout_cnt);
void session2_waitinglist_finalize(struct session2 *session, int status);
static inline size_t session2_waitinglist_get_len(const struct session2 *session)
{
return queue_len(session->waiting);
}
static inline bool session2_waitinglist_is_empty(const struct session2 *session)
{
return session2_waitinglist_get_len(session) == 0;
}
static inline bool session2_is_empty(const struct session2 *session)
{
return session2_tasklist_is_empty(session) &&
session2_waitinglist_is_empty(session);
}
/** Penalizes the server the specified `session` is connected to, if the session
* has not been useful (see `struct session2::was_useful`). Only applies to
* `outgoing` sessions, and the session should not be connection-less. */
void session2_penalize(struct session2 *session);
/** Sends the specified `payload` to be processed in the `_UNWRAP` direction by
* the session's protocol layers.
*
 * The `comm` parameter may contain a pointer to comm data; e.g. for UDP, the
 * comm data shall contain a pointer to the sender's `struct sockaddr_*`. If
 * `comm` is `NULL`, session-wide data shall be used.
*
* Note that the payload data may be modified by any of the layers, to avoid
* making copies. Once the payload is passed to this function, the content of
* the referenced data is undefined to the caller.
*
* Once all layers are processed, `cb` is called with `baton` passed as one
* of its parameters. `cb` may also be `NULL`. See `protolayer_finished_cb` for
* more info.
*
* Returns one of `enum protolayer_ret` or a negative number
* indicating an error. */
int session2_unwrap(struct session2 *s, struct protolayer_payload payload,
const struct comm_info *comm, protolayer_finished_cb cb,
void *baton);
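/* A hedged sketch (hypothetical variables) of feeding freshly received wire
 * data into the layer sequence; the real call sites live in the I/O code:
 *
 *	// passing NULL for `comm` makes the layers use session-wide comm data
 *	session2_unwrap(s, protolayer_payload_wire_buf(&s->wire_buf, false),
 *			NULL, NULL, NULL);
 */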
/** Same as `session2_unwrap`, but looks up the specified `protocol` in the
* session's assigned protocol group and sends the `payload` to the layer that
* is next in the sequence in the `_UNWRAP` direction.
*
* Layers may use this to generate their own data to send in the sequence, e.g.
* for protocol-specific ceremony. */
int session2_unwrap_after(struct session2 *s, enum protolayer_type protocol,
struct protolayer_payload payload,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton);
/** Sends the specified `payload` to be processed in the `_WRAP` direction by
 * the session's protocol layers. The `comm` parameter may contain a pointer
 * to some data specific to the bottommost layer of this session.
*
* Note that the payload data may be modified by any of the layers, to avoid
* making copies. Once the payload is passed to this function, the content of
* the referenced data is undefined to the caller.
*
* Once all layers are processed, `cb` is called with `baton` passed as one
* of its parameters. `cb` may also be `NULL`. See `protolayer_finished_cb` for
* more info.
*
* Returns one of `enum protolayer_ret` or a negative number
* indicating an error. */
int session2_wrap(struct session2 *s, struct protolayer_payload payload,
const struct comm_info *comm, protolayer_finished_cb cb,
void *baton);
/** Same as `session2_wrap`, but looks up the specified `protocol` in the
* session's assigned protocol group and sends the `payload` to the layer that
* is next in the sequence in the `_WRAP` direction.
*
* Layers may use this to generate their own data to send in the sequence, e.g.
* for protocol-specific ceremony. */
int session2_wrap_after(struct session2 *s, enum protolayer_type protocol,
struct protolayer_payload payload,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton);
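/* For example, this is how the TLS layer's gnutls push callback
 * (`kres_gnutls_vec_push`) ships already-encrypted records: it re-enters the
 * sequence just below itself, so the data skips the TLS layer and continues
 * towards the transport:
 *
 *	session2_wrap_after(tls->h.session, PROTOLAYER_TYPE_TLS,
 *			protolayer_payload_iovec(push_ctx->iov, iovcnt, true),
 *			NULL, kres_gnutls_push_finished, push_ctx);
 */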
/** Sends an event to be synchronously processed by the protocol layers of the
* specified session. The layers are first iterated through in the `_UNWRAP`
* direction, then bounced back in the `_WRAP` direction. */
void session2_event(struct session2 *s, enum protolayer_event_type event, void *baton);
/** Sends an event to be synchronously processed by the protocol layers of the
* specified session, starting from the specified `protocol` in the `_UNWRAP`
* direction. The layers are first iterated through in the `_UNWRAP` direction,
* then bounced back in the `_WRAP` direction.
*
* NOTE: The bounced iteration does not exclude any layers - the layer
* specified by `protocol` and those before it are only skipped in the
* `_UNWRAP` direction! */
void session2_event_after(struct session2 *s, enum protolayer_type protocol,
enum protolayer_event_type event, void *baton);
/** Sends a `PROTOLAYER_EVENT_CLOSE` event to be processed by the protocol
* layers of the specified session. This function exists for readability
* reasons, to signal the intent that sending this event is used to actually
* close the session. */
static inline void session2_close(struct session2 *s)
{
session2_event(s, PROTOLAYER_EVENT_CLOSE, NULL);
}
/** Sends a `PROTOLAYER_EVENT_FORCE_CLOSE` event to be processed by the
* protocol layers of the specified session. This function exists for
* readability reasons, to signal the intent that sending this event is used to
* actually close the session. */
static inline void session2_force_close(struct session2 *s)
{
session2_event(s, PROTOLAYER_EVENT_FORCE_CLOSE, NULL);
}
/** Performs initial setup of the specified `req`, using the session's protocol
* layers. Layers are processed in the `_UNWRAP` direction. */
void session2_init_request(struct session2 *s, struct kr_request *req);
/** Removes the specified request task from the session's tasklist. The session
* must be outgoing. If the session is UDP, a signal to close is also sent to it. */
void session2_kill_ioreq(struct session2 *session, struct qr_task *task);
/** Update `last_activity` to the current timestamp. */
static inline void session2_touch(struct session2 *session)
{
session->last_activity = kr_now();
}
/*
* Copyright (C) 2016 American Civil Liberties Union (ACLU)
* Copyright (C) CZ.NIC, z.s.p.o
*
* Initial Author: Daniel Kahn Gillmor <dkg@fifthhorseman.net>
* Ondřej Surý <ondrej@sury.org>
*
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include <gnutls/abstract.h>
#include <gnutls/crypto.h>
#include <gnutls/gnutls.h>
#include <gnutls/x509.h>
#include <uv.h>
#include <errno.h>
#include <stdalign.h>
#include <stdlib.h>
#include "contrib/ucw/lib.h"
#include "contrib/base64.h"
#include "daemon/tls.h"
#include "daemon/worker.h"
#include "daemon/session2.h"
#define EPHEMERAL_CERT_EXPIRATION_SECONDS_RENEW_BEFORE ((time_t)60*60*24*7)
#define GNUTLS_PIN_MIN_VERSION 0x030400
#define UNWRAP_BUF_SIZE 131072
#define TLS_CHUNK_SIZE ((size_t)16 * 1024)
#define VERBOSE_MSG(cl_side, ...)\
if (cl_side) \
kr_log_debug(TLSCLIENT, __VA_ARGS__); \
else \
kr_log_debug(TLS, __VA_ARGS__);
static const gnutls_datum_t tls_grp_alpn[KR_PROTO_COUNT] = {
[KR_PROTO_DOT] = { (uint8_t *)"dot", 3 },
[KR_PROTO_DOH] = { (uint8_t *)"h2", 2 },
};
typedef enum tls_client_hs_state {
TLS_HS_NOT_STARTED = 0,
TLS_HS_IN_PROGRESS,
TLS_HS_DONE,
TLS_HS_CLOSING,
TLS_HS_LAST
} tls_hs_state_t;
struct pl_tls_sess_data {
struct protolayer_data h;
bool client_side;
bool first_handshake_done;
gnutls_session_t tls_session;
tls_hs_state_t handshake_state;
protolayer_iter_ctx_queue_t unwrap_queue;
protolayer_iter_ctx_queue_t wrap_queue;
struct wire_buf unwrap_buf;
size_t write_queue_size;
union {
struct tls_credentials *server_credentials;
tls_client_param_t *client_params; /**< Ref-counted. */
};
};
struct tls_credentials * tls_get_ephemeral_credentials(void);
void tls_credentials_log_pins(struct tls_credentials *tls_credentials);
static int client_verify_certificate(gnutls_session_t tls_session);
static struct tls_credentials *tls_credentials_reserve(struct tls_credentials *tls_credentials);
/**
* Set restrictions on TLS features, in particular ciphers.
*
* We explicitly disable features according to:
* https://datatracker.ietf.org/doc/html/rfc8310#section-9
* in case the gnutls+OS defaults allow them.
* Performance optimizations are not implemented at the moment.
*
* OS defaults are taken into account, e.g. on Red Hat family there's
* /etc/crypto-policies/back-ends/gnutls.config and update-crypto-policies tool.
*/
static int kres_gnutls_set_priority(gnutls_session_t session) {
static const char * const extra_prio =
"-VERS-TLS1.0:-VERS-TLS1.1:-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
const char *errpos = NULL;
int err = gnutls_set_default_priority_append(session, extra_prio, &errpos, 0);
if (err != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "setting priority '%s' failed at character %zd (...'%s') with %s (%d)\n",
extra_prio, errpos - extra_prio, errpos, gnutls_strerror_name(err), err);
}
return err;
}
static ssize_t kres_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len)
{
struct pl_tls_sess_data *tls = h;
if (kr_fails_assert(tls)) {
errno = EFAULT;
return -1;
}
bool avail = protolayer_queue_has_payload(&tls->unwrap_queue);
VERBOSE_MSG(tls->client_side, "pull wanted: %zu avail: %s\n",
len, avail ? "yes" : "no");
if (!avail) {
errno = EAGAIN;
return -1;
}
char *dest = buf;
size_t transfer = 0;
while (queue_len(tls->unwrap_queue) > 0 && len > 0) {
struct protolayer_iter_ctx *ctx = queue_head(tls->unwrap_queue);
struct protolayer_payload *pld = &ctx->payload;
bool fully_consumed = false;
if (pld->type == PROTOLAYER_PAYLOAD_BUFFER) {
size_t to_copy = MIN(len, pld->buffer.len);
memcpy(dest, pld->buffer.buf, to_copy);
dest += to_copy;
len -= to_copy;
pld->buffer.buf = (char *)pld->buffer.buf + to_copy;
pld->buffer.len -= to_copy;
transfer += to_copy;
if (pld->buffer.len == 0)
fully_consumed = true;
} else if (pld->type == PROTOLAYER_PAYLOAD_IOVEC) {
while (pld->iovec.cnt && len > 0) {
struct iovec *iov = pld->iovec.iov;
size_t to_copy = MIN(len, iov->iov_len);
memcpy(dest, iov->iov_base, to_copy);
dest += to_copy;
len -= to_copy;
iov->iov_base = ((char *)iov->iov_base) + to_copy;
iov->iov_len -= to_copy;
transfer += to_copy;
if (iov->iov_len == 0) {
pld->iovec.iov++;
pld->iovec.cnt--;
}
}
if (pld->iovec.cnt == 0)
fully_consumed = true;
} else if (pld->type == PROTOLAYER_PAYLOAD_WIRE_BUF) {
size_t wbl = wire_buf_data_length(pld->wire_buf);
size_t to_copy = MIN(len, wbl);
memcpy(dest, wire_buf_data(pld->wire_buf), to_copy);
dest += to_copy;
len -= to_copy;
transfer += to_copy;
wire_buf_trim(pld->wire_buf, to_copy);
if (wire_buf_data_length(pld->wire_buf) == 0) {
wire_buf_reset(pld->wire_buf);
fully_consumed = true;
}
} else if (!pld->type) {
fully_consumed = true;
} else {
kr_assert(false && "Unsupported payload type");
errno = EFAULT;
return -1;
}
if (!fully_consumed) /* `len` was smaller than the sum of payloads */
break;
if (queue_len(tls->unwrap_queue) > 1) {
/* Finalize queued contexts, except for the last one. */
protolayer_break(ctx, kr_ok());
queue_pop(tls->unwrap_queue);
} else {
/* The last queued context will `continue` on the next
* `gnutls_record_recv`. */
ctx->payload.type = PROTOLAYER_PAYLOAD_NULL;
break;
}
}
VERBOSE_MSG(tls->client_side, "pull transfer: %zu\n", transfer);
return transfer;
}
struct kres_gnutls_push_ctx {
struct pl_tls_sess_data *sess_data;
struct iovec iov[];
};
static void kres_gnutls_push_finished(int status, struct session2 *session,
const struct comm_info *comm, void *baton)
{
struct kres_gnutls_push_ctx *push_ctx = baton;
struct pl_tls_sess_data *tls = push_ctx->sess_data;
while (queue_len(tls->wrap_queue)) {
struct protolayer_iter_ctx *ctx = queue_head(tls->wrap_queue);
protolayer_break(ctx, kr_ok());
queue_pop(tls->wrap_queue);
}
free(push_ctx);
}
static ssize_t kres_gnutls_vec_push(gnutls_transport_ptr_t h, const giovec_t * iov, int iovcnt)
{
struct pl_tls_sess_data *tls = h;
if (kr_fails_assert(tls)) {
errno = EFAULT;
return -1;
}
if (iovcnt == 0) {
return 0;
}
if (kr_fails_assert(iovcnt > 0)) {
errno = EINVAL;
return -1;
}
size_t total_len = 0;
for (int i = 0; i < iovcnt; i++)
total_len += iov[i].iov_len;
struct kres_gnutls_push_ctx *push_ctx =
malloc(sizeof(*push_ctx) + sizeof(struct iovec[iovcnt]));
kr_require(push_ctx);
push_ctx->sess_data = tls;
memcpy(push_ctx->iov, iov, sizeof(struct iovec[iovcnt]));
session2_wrap_after(tls->h.session, PROTOLAYER_TYPE_TLS,
protolayer_payload_iovec(push_ctx->iov, iovcnt, true),
NULL, kres_gnutls_push_finished, push_ctx);
return total_len;
}
static void tls_handshake_success(struct pl_tls_sess_data *tls,
struct session2 *session)
{
if (tls->client_side) {
tls_client_param_t *tls_params = tls->client_params;
gnutls_session_t tls_session = tls->tls_session;
if (gnutls_session_is_resumed(tls_session) != 0) {
kr_log_debug(TLSCLIENT, "TLS session has resumed\n");
} else {
kr_log_debug(TLSCLIENT, "TLS session has not resumed\n");
/* session wasn't resumed, delete old session data ... */
if (tls_params->session_data.data != NULL) {
gnutls_free(tls_params->session_data.data);
tls_params->session_data.data = NULL;
tls_params->session_data.size = 0;
}
/* ... and get the new session data */
gnutls_datum_t tls_session_data = { NULL, 0 };
int ret = gnutls_session_get_data2(tls_session, &tls_session_data);
if (ret == 0) {
tls_params->session_data = tls_session_data;
}
}
}
if (!tls->first_handshake_done) {
session2_event_after(session, PROTOLAYER_TYPE_TLS,
PROTOLAYER_EVENT_CONNECT, NULL);
tls->first_handshake_done = true;
}
}
/** Perform TLS handshake and handle error codes according to the documentation.
  * See https://gnutls.org/manual/html_node/TLS-handshake.html#TLS-handshake
  * Returns kr_ok() on success or a non-fatal error, kr_error(EAGAIN) when the
  * handshake would block, or kr_error(EIO) on a fatal error.
  */
static int tls_handshake(struct pl_tls_sess_data *tls, struct session2 *session)
{
int err = gnutls_handshake(tls->tls_session);
if (err == GNUTLS_E_SUCCESS) {
/* Handshake finished, return success */
tls->handshake_state = TLS_HS_DONE;
struct sockaddr *peer = session2_get_peer(session);
VERBOSE_MSG(tls->client_side, "TLS handshake with %s has completed\n",
kr_straddr(peer));
tls_handshake_success(tls, session);
} else if (err == GNUTLS_E_AGAIN) {
return kr_error(EAGAIN);
} else if (gnutls_error_is_fatal(err)) {
/* Fatal errors, return error as it's not recoverable */
VERBOSE_MSG(tls->client_side, "gnutls_handshake failed: %s (%d)\n",
gnutls_strerror_name(err), err);
/* Notify the peer about handshake failure via an alert. */
gnutls_alert_send_appropriate(tls->tls_session, err);
enum protolayer_event_type etype = (tls->first_handshake_done)
? PROTOLAYER_EVENT_DISCONNECT
: PROTOLAYER_EVENT_CONNECT_FAIL;
session2_event(session, etype,
(void *)KR_SELECTION_TLS_HANDSHAKE_FAILED);
return kr_error(EIO);
} else if (err == GNUTLS_E_WARNING_ALERT_RECEIVED) {
/* Handle warning when in verbose mode */
const char *alert_name = gnutls_alert_get_name(gnutls_alert_get(tls->tls_session));
if (alert_name != NULL) {
struct sockaddr *peer = session2_get_peer(session);
VERBOSE_MSG(tls->client_side, "TLS alert from %s received: %s\n",
kr_straddr(peer), alert_name);
}
}
return kr_ok();
}
/*! Close a TLS context (call gnutls_bye()) */
static void tls_close(struct pl_tls_sess_data *tls, struct session2 *session, bool allow_bye)
{
if (tls == NULL || tls->tls_session == NULL || kr_fails_assert(session))
return;
/* Store the current session data for potential resumption of this session */
if (session->outgoing && tls->client_params) {
gnutls_free(tls->client_params->session_data.data);
tls->client_params->session_data.data = NULL;
tls->client_params->session_data.size = 0;
gnutls_session_get_data2(
tls->tls_session,
&tls->client_params->session_data);
}
const struct sockaddr *peer = session2_get_peer(session);
if (allow_bye && tls->handshake_state == TLS_HS_DONE) {
VERBOSE_MSG(tls->client_side, "closing tls connection to `%s`\n",
kr_straddr(peer));
tls->handshake_state = TLS_HS_CLOSING;
gnutls_bye(tls->tls_session, GNUTLS_SHUT_RDWR);
} else {
VERBOSE_MSG(tls->client_side, "closing tls connection to `%s` (without bye)\n",
kr_straddr(peer));
}
}
/*
DNS-over-TLS Out of band key-pinned authentication profile uses the
same form of pins as HPKP:
e.g. pin-sha256="FHkyLhvI0n70E47cJlRTamTrnYVcsYdjUGbr79CfAVI="
DNS-over-TLS OOB key-pins: https://tools.ietf.org/html/rfc7858#appendix-A
HPKP pin reference: https://tools.ietf.org/html/rfc7469#appendix-A
*/
#define PINLEN ((((32) * 8 + 4)/6) + 3 + 1)
/* Compute pin_sha256 for the certificate.
* It may be in raw format - just TLS_SHA256_RAW_LEN bytes without termination,
* or it may be a base64 0-terminated string requiring up to
* TLS_SHA256_BASE64_BUFLEN bytes.
* \return error code */
static int get_oob_key_pin(gnutls_x509_crt_t crt, char *outchar, ssize_t outchar_len, bool raw)
{
/* TODO: simplify this function by using gnutls_x509_crt_get_key_id() */
if (kr_fails_assert(!raw || outchar_len >= TLS_SHA256_RAW_LEN)) {
return kr_error(ENOSPC);
		/* With !raw the size check happens inside kr_base64_encode. */
}
gnutls_pubkey_t key;
int err = gnutls_pubkey_init(&key);
if (err != GNUTLS_E_SUCCESS) return err;
gnutls_datum_t datum = { .data = NULL, .size = 0 };
err = gnutls_pubkey_import_x509(key, crt, 0);
if (err != GNUTLS_E_SUCCESS) goto leave;
err = gnutls_pubkey_export2(key, GNUTLS_X509_FMT_DER, &datum);
if (err != GNUTLS_E_SUCCESS) goto leave;
char raw_pin[TLS_SHA256_RAW_LEN]; /* TMP buffer if raw == false */
err = gnutls_hash_fast(GNUTLS_DIG_SHA256, datum.data, datum.size,
(raw ? outchar : raw_pin));
if (err != GNUTLS_E_SUCCESS || raw/*success*/)
goto leave;
/* Convert to non-raw. */
err = kr_base64_encode((uint8_t *)raw_pin, sizeof(raw_pin),
(uint8_t *)outchar, outchar_len);
if (err >= 0 && err < outchar_len) {
outchar[err] = '\0'; /* kr_base64_encode() doesn't do it */
err = GNUTLS_E_SUCCESS;
} else if (kr_fails_assert(err < 0)) {
outchar[outchar_len - 1] = '\0';
err = kr_error(ENOSPC); /* base64 fits but '\0' doesn't */
}
leave:
gnutls_free(datum.data);
gnutls_pubkey_deinit(key);
return err;
}
/*! Log DNS-over-TLS OOB key-pin form of current credentials:
* https://tools.ietf.org/html/rfc7858#appendix-A */
void tls_credentials_log_pins(struct tls_credentials *tls_credentials)
{
for (int index = 0;; index++) {
gnutls_x509_crt_t *certs = NULL;
unsigned int cert_count = 0;
int err = gnutls_certificate_get_x509_crt(tls_credentials->credentials,
index, &certs, &cert_count);
if (err != GNUTLS_E_SUCCESS) {
if (err != GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE) {
kr_log_error(TLS, "could not get X.509 certificates (%d) %s\n",
err, gnutls_strerror_name(err));
}
return;
}
for (int i = 0; i < cert_count; i++) {
char pin[TLS_SHA256_BASE64_BUFLEN] = { 0 };
err = get_oob_key_pin(certs[i], pin, sizeof(pin), false);
if (err != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "could not calculate RFC 7858 OOB key-pin from cert %d (%d) %s\n",
i, err, gnutls_strerror_name(err));
} else {
kr_log_info(TLS, "RFC 7858 OOB key-pin (%d): pin-sha256=\"%s\"\n",
i, pin);
}
gnutls_x509_crt_deinit(certs[i]);
}
gnutls_free(certs);
}
}
static int str_replace(char **where_ptr, const char *with)
{
char *copy = with ? strdup(with) : NULL;
if (with && !copy) {
return kr_error(ENOMEM);
}
free(*where_ptr);
*where_ptr = copy;
return kr_ok();
}
static time_t get_end_entity_expiration(gnutls_certificate_credentials_t creds)
{
gnutls_datum_t data;
gnutls_x509_crt_t cert = NULL;
int err;
time_t ret = GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION;
if ((err = gnutls_certificate_get_crt_raw(creds, 0, 0, &data)) != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "failed to get cert to check expiration: (%d) %s\n",
err, gnutls_strerror_name(err));
goto done;
}
if ((err = gnutls_x509_crt_init(&cert)) != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "failed to initialize cert: (%d) %s\n",
err, gnutls_strerror_name(err));
goto done;
}
if ((err = gnutls_x509_crt_import(cert, &data, GNUTLS_X509_FMT_DER)) != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "failed to construct cert while checking expiration: (%d) %s\n",
err, gnutls_strerror_name(err));
goto done;
}
ret = gnutls_x509_crt_get_expiration_time (cert);
done:
/* do not free data; g_c_get_crt_raw() says to treat it as
* constant. */
gnutls_x509_crt_deinit(cert);
return ret;
}
int tls_certificate_set(const char *tls_cert, const char *tls_key)
{
if (kr_fails_assert(the_network)) {
return kr_error(EINVAL);
}
struct tls_credentials *tls_credentials = calloc(1, sizeof(*tls_credentials));
if (tls_credentials == NULL) {
return kr_error(ENOMEM);
}
int err = 0;
if ((err = gnutls_certificate_allocate_credentials(&tls_credentials->credentials)) != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "gnutls_certificate_allocate_credentials() failed: (%d) %s\n",
err, gnutls_strerror_name(err));
tls_credentials_free(tls_credentials);
return kr_error(ENOMEM);
}
if ((err = gnutls_certificate_set_x509_system_trust(tls_credentials->credentials)) < 0) {
if (err != GNUTLS_E_UNIMPLEMENTED_FEATURE) {
kr_log_warning(TLS, "warning: gnutls_certificate_set_x509_system_trust() failed: (%d) %s\n",
err, gnutls_strerror_name(err));
tls_credentials_free(tls_credentials);
return err;
}
}
if ((str_replace(&tls_credentials->tls_cert, tls_cert) != 0) ||
(str_replace(&tls_credentials->tls_key, tls_key) != 0)) {
tls_credentials_free(tls_credentials);
return kr_error(ENOMEM);
}
if ((err = gnutls_certificate_set_x509_key_file(tls_credentials->credentials,
tls_cert, tls_key, GNUTLS_X509_FMT_PEM)) != GNUTLS_E_SUCCESS) {
tls_credentials_free(tls_credentials);
kr_log_error(TLS, "gnutls_certificate_set_x509_key_file(%s,%s) failed: %d (%s)\n",
tls_cert, tls_key, err, gnutls_strerror_name(err));
return kr_error(EINVAL);
}
/* record the expiration date: */
tls_credentials->valid_until = get_end_entity_expiration(tls_credentials->credentials);
/* Exchange the x509 credentials */
struct tls_credentials *old_credentials = the_network->tls_credentials;
/* Start using the new x509_credentials */
the_network->tls_credentials = tls_credentials;
tls_credentials_log_pins(the_network->tls_credentials);
if (old_credentials) {
err = tls_credentials_release(old_credentials);
if (err != kr_error(EBUSY)) {
return err;
}
}
return kr_ok();
}
/*! Borrow TLS credentials for context. */
static struct tls_credentials *tls_credentials_reserve(struct tls_credentials *tls_credentials)
{
if (!tls_credentials) {
return NULL;
}
tls_credentials->count++;
return tls_credentials;
}
/*! Release TLS credentials for context (decrements refcount or frees). */
int tls_credentials_release(struct tls_credentials *tls_credentials)
{
if (!tls_credentials) {
return kr_error(EINVAL);
}
if (--tls_credentials->count < 0) {
tls_credentials_free(tls_credentials);
} else {
return kr_error(EBUSY);
}
return kr_ok();
}
/*! Free TLS credentials, must not be called if it holds positive refcount. */
void tls_credentials_free(struct tls_credentials *tls_credentials)
{
if (!tls_credentials) {
return;
}
if (tls_credentials->credentials) {
gnutls_certificate_free_credentials(tls_credentials->credentials);
}
if (tls_credentials->tls_cert) {
free(tls_credentials->tls_cert);
}
if (tls_credentials->tls_key) {
free(tls_credentials->tls_key);
}
if (tls_credentials->ephemeral_servicename) {
free(tls_credentials->ephemeral_servicename);
}
free(tls_credentials);
}
void tls_client_param_unref(tls_client_param_t *entry)
{
if (!entry || kr_fails_assert(entry->refs)) return;
--(entry->refs);
if (entry->refs) return;
VERBOSE_MSG(true, "freeing TLS parameters %p\n", (void *)entry);
for (int i = 0; i < entry->ca_files.len; ++i) {
free_const(entry->ca_files.at[i]);
}
array_clear(entry->ca_files);
free_const(entry->hostname);
for (int i = 0; i < entry->pins.len; ++i) {
free_const(entry->pins.at[i]);
}
array_clear(entry->pins);
if (entry->credentials) {
gnutls_certificate_free_credentials(entry->credentials);
}
if (entry->session_data.data) {
gnutls_free(entry->session_data.data);
}
free(entry);
}
static int param_free(void **param, void *null)
{
if (kr_fails_assert(param && *param))
return -1;
tls_client_param_unref(*param);
return 0;
}
void tls_client_params_free(tls_client_params_t *params)
{
if (!params) return;
trie_apply(params, param_free, NULL);
trie_free(params);
}
tls_client_param_t * tls_client_param_new(void)
{
tls_client_param_t *e = calloc(1, sizeof(*e));
if (kr_fails_assert(e))
return NULL;
/* Note: those array_t don't need further initialization. */
e->refs = 1;
int ret = gnutls_certificate_allocate_credentials(&e->credentials);
if (ret != GNUTLS_E_SUCCESS) {
kr_log_error(TLSCLIENT, "error: gnutls_certificate_allocate_credentials() fails (%s)\n",
gnutls_strerror_name(ret));
free(e);
return NULL;
}
gnutls_certificate_set_verify_function(e->credentials, client_verify_certificate);
return e;
}
/**
 * Convert an IP address and port number to a binary key.
 *
 * \precond the \a key buffer must have sufficient size
 * \param[in] addr address to convert
 * \param[out] len output length
 * \param[out] key output buffer
 * \return true on success, false on an unsupported address family
*/
static bool construct_key(const union kr_sockaddr *addr, uint32_t *len, char *key)
{
switch (addr->ip.sa_family) {
case AF_INET:
memcpy(key, &addr->ip4.sin_port, sizeof(addr->ip4.sin_port));
memcpy(key + sizeof(addr->ip4.sin_port), &addr->ip4.sin_addr,
sizeof(addr->ip4.sin_addr));
*len = sizeof(addr->ip4.sin_port) + sizeof(addr->ip4.sin_addr);
return true;
case AF_INET6:
memcpy(key, &addr->ip6.sin6_port, sizeof(addr->ip6.sin6_port));
memcpy(key + sizeof(addr->ip6.sin6_port), &addr->ip6.sin6_addr,
sizeof(addr->ip6.sin6_addr));
*len = sizeof(addr->ip6.sin6_port) + sizeof(addr->ip6.sin6_addr);
return true;
default:
kr_assert(!EINVAL);
return false;
}
}
tls_client_param_t **tls_client_param_getptr(tls_client_params_t **params,
const struct sockaddr *addr, bool do_insert)
{
if (kr_fails_assert(params && addr))
return NULL;
/* We accept NULL for empty map; ensure the map exists if needed. */
if (!*params) {
if (!do_insert) return NULL;
*params = trie_create(NULL);
if (kr_fails_assert(*params))
return NULL;
}
/* Construct the key. */
const union kr_sockaddr *ia = (const union kr_sockaddr *)addr;
char key[sizeof(ia->ip6.sin6_port) + sizeof(ia->ip6.sin6_addr)];
uint32_t len;
if (!construct_key(ia, &len, key))
return NULL;
/* Get the entry. */
return (tls_client_param_t **)
(do_insert ? trie_get_ins : trie_get_try)(*params, key, len);
}
int tls_client_param_remove(tls_client_params_t *params, const struct sockaddr *addr)
{
const union kr_sockaddr *ia = (const union kr_sockaddr *)addr;
char key[sizeof(ia->ip6.sin6_port) + sizeof(ia->ip6.sin6_addr)];
uint32_t len;
if (!construct_key(ia, &len, key))
return kr_error(EINVAL);
trie_val_t param_ptr;
int ret = trie_del(params, key, len, &param_ptr);
if (ret != KNOT_EOK)
return kr_error(ret);
tls_client_param_unref(param_ptr);
return kr_ok();
}
static void log_all_pins(tls_client_param_t *params)
{
uint8_t buffer[TLS_SHA256_BASE64_BUFLEN + 1];
for (int i = 0; i < params->pins.len; i++) {
int len = kr_base64_encode(params->pins.at[i], TLS_SHA256_RAW_LEN,
buffer, TLS_SHA256_BASE64_BUFLEN);
if (!kr_fails_assert(len > 0)) {
buffer[len] = '\0';
kr_log_error(TLSCLIENT, "pin no. %d: %s\n", i, buffer);
}
}
}
static void log_all_certificates(const unsigned int cert_list_size,
const gnutls_datum_t *cert_list)
{
for (int i = 0; i < cert_list_size; i++) {
gnutls_x509_crt_t cert;
if (gnutls_x509_crt_init(&cert) != GNUTLS_E_SUCCESS) {
return;
}
if (gnutls_x509_crt_import(cert, &cert_list[i], GNUTLS_X509_FMT_DER) != GNUTLS_E_SUCCESS) {
gnutls_x509_crt_deinit(cert);
return;
}
char cert_pin[TLS_SHA256_BASE64_BUFLEN];
if (get_oob_key_pin(cert, cert_pin, sizeof(cert_pin), false) != GNUTLS_E_SUCCESS) {
gnutls_x509_crt_deinit(cert);
return;
}
kr_log_error(TLSCLIENT, "Certificate: %s\n", cert_pin);
gnutls_x509_crt_deinit(cert);
}
}
/**
* Verify that at least one certificate in the certificate chain matches
* at least one certificate pin in the non-empty params->pins array.
* \returns GNUTLS_E_SUCCESS if pin matches, any other value is an error
*/
static int client_verify_pin(const unsigned int cert_list_size,
const gnutls_datum_t *cert_list,
tls_client_param_t *params)
{
if (kr_fails_assert(params->pins.len > 0))
return GNUTLS_E_CERTIFICATE_ERROR;
for (int i = 0; i < cert_list_size; i++) {
gnutls_x509_crt_t cert;
int ret = gnutls_x509_crt_init(&cert);
if (ret != GNUTLS_E_SUCCESS) {
return ret;
}
ret = gnutls_x509_crt_import(cert, &cert_list[i], GNUTLS_X509_FMT_DER);
if (ret != GNUTLS_E_SUCCESS) {
gnutls_x509_crt_deinit(cert);
return ret;
}
char cert_pin[TLS_SHA256_RAW_LEN];
/* Get raw pin and compare. */
ret = get_oob_key_pin(cert, cert_pin, sizeof(cert_pin), true);
gnutls_x509_crt_deinit(cert);
if (ret != GNUTLS_E_SUCCESS) {
return ret;
}
for (size_t j = 0; j < params->pins.len; ++j) {
const uint8_t *pin = params->pins.at[j];
if (memcmp(cert_pin, pin, TLS_SHA256_RAW_LEN) != 0)
continue; /* mismatch */
VERBOSE_MSG(true, "matched a configured pin no. %zd\n", j);
return GNUTLS_E_SUCCESS;
}
VERBOSE_MSG(true, "none of %zd configured pin(s) matched\n",
params->pins.len);
}
kr_log_error(TLSCLIENT, "no pin matched: %zu pins * %d certificates\n",
params->pins.len, cert_list_size);
log_all_pins(params);
log_all_certificates(cert_list_size, cert_list);
return GNUTLS_E_CERTIFICATE_ERROR;
}
/**
* Verify that \param tls_session contains a valid X.509 certificate chain
* with given hostname.
*
* \returns GNUTLS_E_SUCCESS if certificate chain is valid, any other value is an error
*/
static int client_verify_certchain(struct pl_tls_sess_data *tls, const char *hostname)
{
if (kr_fails_assert(hostname)) {
kr_log_error(TLSCLIENT, "internal config inconsistency: no hostname set\n");
return GNUTLS_E_CERTIFICATE_ERROR;
}
unsigned int status;
int ret = gnutls_certificate_verify_peers3(tls->tls_session, hostname, &status);
if ((ret == GNUTLS_E_SUCCESS) && (status == 0)) {
return GNUTLS_E_SUCCESS;
}
const char *addr_str = kr_straddr(session2_get_peer(tls->h.session));
if (ret == GNUTLS_E_SUCCESS) {
gnutls_datum_t msg;
ret = gnutls_certificate_verification_status_print(
status, gnutls_certificate_type_get(tls->tls_session), &msg, 0);
if (ret == GNUTLS_E_SUCCESS) {
kr_log_error(TLSCLIENT, "failed to verify peer certificate of %s: "
"%s\n", addr_str, msg.data);
gnutls_free(msg.data);
} else {
kr_log_error(TLSCLIENT, "failed to verify peer certificate of %s: "
"unable to print reason: %s (%s)\n",
addr_str,
gnutls_strerror(ret), gnutls_strerror_name(ret));
} /* gnutls_certificate_verification_status_print end */
} else {
kr_log_error(TLSCLIENT, "failed to verify peer certificate of %s: "
"gnutls_certificate_verify_peers3 error: %s (%s)\n",
addr_str,
gnutls_strerror(ret), gnutls_strerror_name(ret));
} /* gnutls_certificate_verify_peers3 end */
return GNUTLS_E_CERTIFICATE_ERROR;
}
/**
* Verify that actual TLS security parameters of \param tls_session
* match requirements provided by user in tls_session->params.
* \returns GNUTLS_E_SUCCESS if requirements were met, any other value is an error
*/
static int client_verify_certificate(gnutls_session_t tls_session)
{
struct pl_tls_sess_data *tls = gnutls_session_get_ptr(tls_session);
if (kr_fails_assert(tls->client_params))
return GNUTLS_E_CERTIFICATE_ERROR;
if (tls->client_params->insecure) {
return GNUTLS_E_SUCCESS;
}
gnutls_certificate_type_t cert_type = gnutls_certificate_type_get(tls_session);
if (cert_type != GNUTLS_CRT_X509) {
kr_log_error(TLSCLIENT, "invalid certificate type %i has been received\n",
cert_type);
return GNUTLS_E_CERTIFICATE_ERROR;
}
unsigned int cert_list_size = 0;
const gnutls_datum_t *cert_list =
gnutls_certificate_get_peers(tls_session, &cert_list_size);
if (cert_list == NULL || cert_list_size == 0) {
kr_log_error(TLSCLIENT, "empty certificate list\n");
return GNUTLS_E_CERTIFICATE_ERROR;
}
if (tls->client_params->pins.len > 0)
/* check hash of the certificate but ignore everything else */
return client_verify_pin(cert_list_size, cert_list, tls->client_params);
else
return client_verify_certchain(tls, tls->client_params->hostname);
}
static int tls_pull_timeout_func(gnutls_transport_ptr_t h, unsigned int ms)
{
struct pl_tls_sess_data *tls = h;
if (kr_fails_assert(tls)) {
errno = EFAULT;
return -1;
}
size_t avail = protolayer_queue_count_payload(&tls->unwrap_queue);
VERBOSE_MSG(tls->client_side, "timeout check: available: %zu\n", avail);
if (!avail) {
errno = EAGAIN;
return -1;
}
return avail;
}
static int pl_tls_sess_data_deinit(struct pl_tls_sess_data *tls)
{
if (tls->tls_session) {
/* Don't terminate TLS connection, just tear it down */
gnutls_deinit(tls->tls_session);
tls->tls_session = NULL;
}
if (tls->client_side) {
tls_client_param_unref(tls->client_params);
} else {
tls_credentials_release(tls->server_credentials);
}
wire_buf_deinit(&tls->unwrap_buf);
while (queue_len(tls->unwrap_queue) > 0) {
struct protolayer_iter_ctx *ctx = queue_head(tls->unwrap_queue);
protolayer_break(ctx, kr_error(EIO));
queue_pop(tls->unwrap_queue);
}
queue_deinit(tls->unwrap_queue);
while (queue_len(tls->wrap_queue)) {
struct protolayer_iter_ctx *ctx = queue_head(tls->wrap_queue);
protolayer_break(ctx, kr_error(EIO));
queue_pop(tls->wrap_queue);
}
queue_deinit(tls->wrap_queue);
return kr_ok();
}
static int pl_tls_sess_server_init(struct session2 *session,
struct pl_tls_sess_data *tls)
{
if (kr_fails_assert(the_worker && the_engine))
return kr_error(EINVAL);
if (!the_network->tls_credentials) {
the_network->tls_credentials = tls_get_ephemeral_credentials();
if (!the_network->tls_credentials) {
kr_log_error(TLS, "X.509 credentials are missing, and ephemeral credentials failed; no TLS\n");
return kr_error(EINVAL);
}
kr_log_info(TLS, "Using ephemeral TLS credentials\n");
tls_credentials_log_pins(the_network->tls_credentials);
}
time_t now = time(NULL);
if (the_network->tls_credentials->valid_until != GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION) {
if (the_network->tls_credentials->ephemeral_servicename) {
/* ephemeral cert: refresh if due to expire within a week */
if (now >= the_network->tls_credentials->valid_until - EPHEMERAL_CERT_EXPIRATION_SECONDS_RENEW_BEFORE) {
struct tls_credentials *newcreds = tls_get_ephemeral_credentials();
if (newcreds) {
tls_credentials_release(the_network->tls_credentials);
the_network->tls_credentials = newcreds;
kr_log_info(TLS, "Renewed expiring ephemeral X.509 cert\n");
} else {
kr_log_error(TLS, "Failed to renew expiring ephemeral X.509 cert, using existing one\n");
}
}
} else {
/* non-ephemeral cert: warn once when certificate expires */
if (now >= the_network->tls_credentials->valid_until) {
kr_log_error(TLS, "X.509 certificate has expired!\n");
the_network->tls_credentials->valid_until = GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION;
}
}
}
int flags = GNUTLS_SERVER | GNUTLS_NONBLOCK;
#if GNUTLS_VERSION_NUMBER >= 0x030705
if (gnutls_check_version("3.7.5"))
flags |= GNUTLS_NO_TICKETS_TLS12;
#endif
int ret = gnutls_init(&tls->tls_session, flags);
if (ret != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "gnutls_init(): %s (%d)\n", gnutls_strerror_name(ret), ret);
pl_tls_sess_data_deinit(tls);
return ret;
}
tls->server_credentials = tls_credentials_reserve(the_network->tls_credentials);
ret = gnutls_credentials_set(tls->tls_session, GNUTLS_CRD_CERTIFICATE,
tls->server_credentials->credentials);
if (ret != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "gnutls_credentials_set(): %s (%d)\n", gnutls_strerror_name(ret), ret);
pl_tls_sess_data_deinit(tls);
return ret;
}
ret = kres_gnutls_set_priority(tls->tls_session);
if (ret != GNUTLS_E_SUCCESS) {
pl_tls_sess_data_deinit(tls);
return ret;
}
tls->client_side = false;
wire_buf_init(&tls->unwrap_buf, UNWRAP_BUF_SIZE);
gnutls_transport_set_pull_function(tls->tls_session, kres_gnutls_pull);
gnutls_transport_set_vec_push_function(tls->tls_session, kres_gnutls_vec_push);
gnutls_transport_set_ptr(tls->tls_session, tls);
if (the_network->tls_session_ticket_ctx) {
tls_session_ticket_enable(the_network->tls_session_ticket_ctx,
tls->tls_session);
}
const gnutls_datum_t *alpn = &tls_grp_alpn[session->proto];
if (alpn->size) { /* ALPN is a non-empty string */
flags = 0;
#if GNUTLS_VERSION_NUMBER >= 0x030500
		/* Mandatory ALPN means the protocol must match, but only
		 * if the ALPN extension is actually used by the client. */
flags |= GNUTLS_ALPN_MANDATORY;
#endif
ret = gnutls_alpn_set_protocols(tls->tls_session, alpn, 1, flags);
if (ret != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "gnutls_alpn_set_protocols(): %s (%d)\n", gnutls_strerror_name(ret), ret);
pl_tls_sess_data_deinit(tls);
return ret;
}
}
return kr_ok();
}
static int pl_tls_sess_client_init(struct session2 *session,
struct pl_tls_sess_data *tls,
tls_client_param_t *param)
{
unsigned int flags = GNUTLS_CLIENT | GNUTLS_NONBLOCK
#ifdef GNUTLS_ENABLE_FALSE_START
| GNUTLS_ENABLE_FALSE_START
#endif
;
#if GNUTLS_VERSION_NUMBER >= 0x030705
if (gnutls_check_version("3.7.5"))
flags |= GNUTLS_NO_TICKETS_TLS12;
#endif
int ret = gnutls_init(&tls->tls_session, flags);
if (ret != GNUTLS_E_SUCCESS) {
pl_tls_sess_data_deinit(tls);
return ret;
}
ret = kres_gnutls_set_priority(tls->tls_session);
if (ret != GNUTLS_E_SUCCESS) {
pl_tls_sess_data_deinit(tls);
return ret;
}
/* Must take a reference on parameters as the credentials are owned by it
* and must not be freed while the session is active. */
++(param->refs);
tls->client_params = param;
ret = gnutls_credentials_set(tls->tls_session, GNUTLS_CRD_CERTIFICATE,
param->credentials);
if (ret == GNUTLS_E_SUCCESS && param->hostname) {
ret = gnutls_server_name_set(tls->tls_session, GNUTLS_NAME_DNS,
param->hostname, strlen(param->hostname));
kr_log_debug(TLSCLIENT, "set hostname, ret = %d\n", ret);
} else if (!param->hostname) {
kr_log_debug(TLSCLIENT, "no hostname\n");
}
if (ret != GNUTLS_E_SUCCESS) {
pl_tls_sess_data_deinit(tls);
return ret;
}
tls->client_side = true;
wire_buf_init(&tls->unwrap_buf, UNWRAP_BUF_SIZE);
gnutls_transport_set_pull_function(tls->tls_session, kres_gnutls_pull);
gnutls_transport_set_vec_push_function(tls->tls_session, kres_gnutls_vec_push);
gnutls_transport_set_ptr(tls->tls_session, tls);
return kr_ok();
}
static int pl_tls_sess_init(struct session2 *session,
void *sess_data,
void *param)
{
struct pl_tls_sess_data *tls = sess_data;
session->secure = true;
queue_init(tls->unwrap_queue);
queue_init(tls->wrap_queue);
if (session->outgoing)
return pl_tls_sess_client_init(session, tls, param);
else
return pl_tls_sess_server_init(session, tls);
}
static int pl_tls_sess_deinit(struct session2 *session,
void *sess_data)
{
return pl_tls_sess_data_deinit(sess_data);
}
static enum protolayer_iter_cb_result pl_tls_unwrap(void *sess_data, void *iter_data,
struct protolayer_iter_ctx *ctx)
{
int brstatus = kr_ok();
struct pl_tls_sess_data *tls = sess_data;
struct session2 *s = ctx->session;
queue_push(tls->unwrap_queue, ctx);
/* Ensure TLS handshake is performed before receiving data.
* See https://www.gnutls.org/manual/html_node/TLS-handshake.html */
while (tls->handshake_state <= TLS_HS_IN_PROGRESS) {
int err = tls_handshake(tls, s);
if (err == kr_error(EAGAIN)) {
return protolayer_async(); /* Wait for more data */
} else if (err != kr_ok()) {
brstatus = err;
goto exit_break;
}
}
/* See https://gnutls.org/manual/html_node/Data-transfer-and-termination.html#Data-transfer-and-termination */
while (true) {
ssize_t count = gnutls_record_recv(tls->tls_session,
wire_buf_free_space(&tls->unwrap_buf),
wire_buf_free_space_length(&tls->unwrap_buf));
if (count == GNUTLS_E_AGAIN) {
if (!protolayer_queue_has_payload(&tls->unwrap_queue)) {
/* See https://www.gnutls.org/manual/html_node/Asynchronous-operation.html */
break;
}
continue;
} else if (count == GNUTLS_E_INTERRUPTED) {
continue;
} else if (count == GNUTLS_E_REHANDSHAKE) {
/* See https://www.gnutls.org/manual/html_node/Re_002dauthentication.html */
struct sockaddr *peer = session2_get_peer(s);
VERBOSE_MSG(tls->client_side, "TLS rehandshake with %s has started\n",
kr_straddr(peer));
tls->handshake_state = TLS_HS_IN_PROGRESS;
int err = kr_ok();
while (tls->handshake_state <= TLS_HS_IN_PROGRESS) {
err = tls_handshake(tls, s);
if (err == kr_error(EAGAIN)) {
break;
} else if (err != kr_ok()) {
brstatus = err;
goto exit_break;
}
}
if (err == kr_error(EAGAIN)) {
/* pull function is out of data */
break;
}
			/* There may still be data available, check it. */
continue;
} else if (count < 0) {
VERBOSE_MSG(tls->client_side, "gnutls_record_recv failed: %s (%zd)\n",
gnutls_strerror_name(count), count);
brstatus = kr_error(EIO);
goto exit_break;
} else if (count == 0) {
break;
}
VERBOSE_MSG(tls->client_side, "received %zd data\n", count);
wire_buf_consume(&tls->unwrap_buf, count);
if (wire_buf_free_space_length(&tls->unwrap_buf) == 0 && protolayer_queue_has_payload(&tls->unwrap_queue) > 0) {
/* wire buffer is full but not all data was consumed */
brstatus = kr_error(ENOSPC);
goto exit_break;
}
if (kr_fails_assert(queue_len(tls->unwrap_queue) == 1)) {
brstatus = kr_error(EINVAL);
goto exit_break;
}
struct protolayer_iter_ctx *ctx_head = queue_head(tls->unwrap_queue);
if (kr_fails_assert(ctx == ctx_head)) {
protolayer_break(ctx, kr_error(EINVAL));
ctx = ctx_head;
}
}
/* Here all data must be consumed. */
if (protolayer_queue_has_payload(&tls->unwrap_queue)) {
		/* Something went wrong, better return an error.
		 * This most probably means that gnutls_record_recv() did not
		 * consume all available network data via kres_gnutls_pull().
		 * TODO: assess the need for additional buffering.
*/
brstatus = kr_error(ENOSPC);
goto exit_break;
}
struct protolayer_iter_ctx *ctx_head = queue_head(tls->unwrap_queue);
if (!kr_fails_assert(ctx == ctx_head))
queue_pop(tls->unwrap_queue);
ctx->payload = protolayer_payload_wire_buf(&tls->unwrap_buf, false);
return protolayer_continue(ctx);
exit_break:
ctx_head = queue_head(tls->unwrap_queue);
if (!kr_fails_assert(ctx == ctx_head))
queue_pop(tls->unwrap_queue);
return protolayer_break(ctx, brstatus);
}
static ssize_t pl_tls_submit(gnutls_session_t tls_session,
struct protolayer_payload payload)
{
if (payload.type == PROTOLAYER_PAYLOAD_WIRE_BUF)
payload = protolayer_payload_as_buffer(&payload);
// TODO: the handling of positive gnutls_record_send() is weird/confusing,
// but it seems caught later when checking gnutls_record_uncork()
if (payload.type == PROTOLAYER_PAYLOAD_BUFFER) {
ssize_t count = gnutls_record_send(tls_session,
payload.buffer.buf, payload.buffer.len);
if (count < 0)
return count;
return payload.buffer.len;
} else if (payload.type == PROTOLAYER_PAYLOAD_IOVEC) {
ssize_t total_submitted = 0;
for (int i = 0; i < payload.iovec.cnt; i++) {
struct iovec iov = payload.iovec.iov[i];
ssize_t count = gnutls_record_send(tls_session,
iov.iov_base, iov.iov_len);
if (count < 0)
return count;
total_submitted += iov.iov_len;
}
return total_submitted;
}
kr_assert(false && "Invalid payload");
return kr_error(EINVAL);
}
static enum protolayer_iter_cb_result pl_tls_wrap(
void *sess_data, void *iter_data,
struct protolayer_iter_ctx *ctx)
{
struct pl_tls_sess_data *tls = sess_data;
gnutls_session_t tls_session = tls->tls_session;
gnutls_record_cork(tls_session);
ssize_t submitted = pl_tls_submit(tls_session, ctx->payload);
if (submitted < 0) {
VERBOSE_MSG(tls->client_side, "pl_tls_submit failed: %s (%zd)\n",
gnutls_strerror_name(submitted), submitted);
return protolayer_break(ctx, submitted);
}
queue_push(tls->wrap_queue, ctx);
int ret = gnutls_record_uncork(tls_session, GNUTLS_RECORD_WAIT);
if (ret < 0) {
if (!gnutls_error_is_fatal(ret)) {
queue_pop(tls->wrap_queue);
return protolayer_break(ctx, kr_error(EAGAIN));
} else {
queue_pop(tls->wrap_queue);
VERBOSE_MSG(tls->client_side, "gnutls_record_uncork failed: %s (%d)\n",
gnutls_strerror_name(ret), ret);
return protolayer_break(ctx, kr_error(EIO));
}
}
if (ret != submitted) {
kr_log_error(TLS, "gnutls_record_uncork didn't send all data (%d of %zd)\n", ret, submitted);
return protolayer_break(ctx, kr_error(EIO));
}
return protolayer_async();
}
static enum protolayer_event_cb_result pl_tls_client_connect_start(
struct pl_tls_sess_data *tls, struct session2 *session)
{
if (tls->handshake_state != TLS_HS_NOT_STARTED)
return PROTOLAYER_EVENT_CONSUME;
if (kr_fails_assert(session->outgoing))
return PROTOLAYER_EVENT_CONSUME;
gnutls_session_set_ptr(tls->tls_session, tls);
gnutls_handshake_set_timeout(tls->tls_session, the_network->tcp.tls_handshake_timeout);
gnutls_transport_set_pull_timeout_function(tls->tls_session, tls_pull_timeout_func);
tls->handshake_state = TLS_HS_IN_PROGRESS;
tls_client_param_t *tls_params = tls->client_params;
if (tls_params->session_data.data != NULL) {
gnutls_session_set_data(tls->tls_session, tls_params->session_data.data,
tls_params->session_data.size);
}
/* See https://www.gnutls.org/manual/html_node/Asynchronous-operation.html */
while (tls->handshake_state <= TLS_HS_IN_PROGRESS) {
int ret = tls_handshake(tls, session);
if (ret != kr_ok()) {
if (ret == kr_error(EAGAIN)) {
session2_timer_stop(session);
session2_timer_start(session,
PROTOLAYER_EVENT_GENERAL_TIMEOUT,
MAX_TCP_INACTIVITY, MAX_TCP_INACTIVITY);
}
return PROTOLAYER_EVENT_CONSUME;
}
}
return PROTOLAYER_EVENT_CONSUME;
}
static enum protolayer_event_cb_result pl_tls_event_unwrap(
enum protolayer_event_type event, void **baton,
struct session2 *s, void *sess_data)
{
struct pl_tls_sess_data *tls = sess_data;
if (event == PROTOLAYER_EVENT_CLOSE) {
tls_close(tls, s, true); /* WITH gnutls_bye */
return PROTOLAYER_EVENT_PROPAGATE;
}
if (event == PROTOLAYER_EVENT_FORCE_CLOSE) {
tls_close(tls, s, false); /* WITHOUT gnutls_bye */
return PROTOLAYER_EVENT_PROPAGATE;
}
if (event == PROTOLAYER_EVENT_EOF) {
// TCP half-closed state not allowed
session2_force_close(s);
return PROTOLAYER_EVENT_CONSUME;
}
if (tls->client_side) {
if (event == PROTOLAYER_EVENT_CONNECT)
return pl_tls_client_connect_start(tls, s);
} else {
if (event == PROTOLAYER_EVENT_CONNECT) {
/* TLS sends its own _CONNECT event when the handshake
* is finished. */
return PROTOLAYER_EVENT_CONSUME;
}
}
return PROTOLAYER_EVENT_PROPAGATE;
}
static enum protolayer_event_cb_result pl_tls_event_wrap(
enum protolayer_event_type event, void **baton,
struct session2 *session, void *sess_data)
{
if (event == PROTOLAYER_EVENT_STATS_SEND_ERR) {
the_worker->stats.err_tls += 1;
return PROTOLAYER_EVENT_CONSUME;
} else if (event == PROTOLAYER_EVENT_STATS_QRY_OUT) {
the_worker->stats.tls += 1;
return PROTOLAYER_EVENT_CONSUME;
}
return PROTOLAYER_EVENT_PROPAGATE;
}
static void pl_tls_request_init(struct session2 *session,
struct kr_request *req,
void *sess_data)
{
req->qsource.comm_flags.tls = true;
}
__attribute__((constructor))
static void tls_protolayers_init(void)
{
protolayer_globals[PROTOLAYER_TYPE_TLS] = (struct protolayer_globals){
.sess_size = sizeof(struct pl_tls_sess_data),
.sess_deinit = pl_tls_sess_deinit,
.wire_buf_overhead = TLS_CHUNK_SIZE,
.sess_init = pl_tls_sess_init,
.unwrap = pl_tls_unwrap,
.wrap = pl_tls_wrap,
.event_unwrap = pl_tls_event_unwrap,
.event_wrap = pl_tls_event_wrap,
.request_init = pl_tls_request_init
};
}
#undef VERBOSE_MSG
/* Copyright (C) 2016 American Civil Liberties Union (ACLU)
* Copyright (C) CZ.NIC, z.s.p.o
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#pragma once
#include <uv.h>
#include <gnutls/gnutls.h>
#include <libknot/packet/pkt.h>
#include "lib/defines.h"
#include "lib/generic/array.h"
#include "lib/generic/trie.h"
#include "lib/utils.h"
#define MAX_TLS_PADDING KR_EDNS_PAYLOAD
#define TLS_MAX_UNCORK_RETRIES 100
/* RFC 5246, 7.3 - Handshake Protocol overview
 * https://tools.ietf.org/html/rfc5246#page-33
 * Message flow for a full handshake (only mandatory messages):
 *
 *      ClientHello          -------->
 *                                           ServerHello
 *                           <--------       ServerHelloDone
 *      ClientKeyExchange
 *      Finished             -------->
 *                           <--------       Finished
 *
 * See also https://blog.cloudflare.com/keyless-ssl-the-nitty-gritty-technical-details/
 * So it takes 2 RTT.
 * As we use session tickets, there are additional messages; add one more RTT.
 */
#define TLS_MAX_HANDSHAKE_TIME (KR_CONN_RTT_MAX * (uint64_t)3)
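/* Illustrative arithmetic only (the concrete RTT value is an assumption, not
 * taken from this header): if KR_CONN_RTT_MAX were e.g. 2000 ms, the macro
 * above would budget 3 * 2000 = 6000 ms for the whole handshake, which is the
 * kind of value a session then hands to GnuTLS:
 *
 *	gnutls_handshake_set_timeout(session, TLS_MAX_HANDSHAKE_TIME);
 *
 * gnutls_handshake_set_timeout() is a standard GnuTLS call; only the example
 * RTT number above is hypothetical. */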
/** Transport session (opaque). */
struct session2;
struct tls_ctx;
struct tls_client_ctx;
struct tls_credentials {
int count;
char *tls_cert;
char *tls_key;
gnutls_certificate_credentials_t credentials;
time_t valid_until;
char *ephemeral_servicename;
};
#define TLS_SHA256_RAW_LEN 32 /* gnutls_hash_get_len(GNUTLS_DIG_SHA256) */
/** Required buffer length for pin_sha256, including the zero terminator. */
#define TLS_SHA256_BASE64_BUFLEN (((TLS_SHA256_RAW_LEN * 8 + 4) / 6) + 3 + 1)
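/* Worked example of the formula above (arithmetic only): base64 of 32 raw
 * bytes needs 4 * ceil(32 / 3) = 44 characters including '=' padding.
 * The macro computes (32 * 8 + 4) / 6 = 43, + 3 = 46, + 1 = 47 bytes,
 * i.e. enough for the 44 characters plus the zero terminator, with a small
 * safety margin. */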
/** TLS authentication parameters for a single address-port pair. */
typedef struct {
uint32_t refs; /**< Reference count; consider TLS sessions in progress. */
bool insecure; /**< Use no authentication. */
const char *hostname; /**< Server name for SNI and certificate check, lowercased. */
array_t(const char *) ca_files; /**< Paths to certificate files; not really used. */
array_t(const uint8_t *) pins; /**< Certificate pins as raw unterminated strings.*/
gnutls_certificate_credentials_t credentials; /**< CA creds. in gnutls format. */
gnutls_datum_t session_data; /**< Session-resumption data gets stored here. */
} tls_client_param_t;
/** Holds configuration for TLS authentication for all potential servers.
* Special case: NULL pointer also means empty. */
typedef trie_t tls_client_params_t;
/** Get a pointer-to-pointer to TLS auth params.
 * If the entry doesn't exist, it returns NULL (if !do_insert) or a pointer to a NULL entry. */
tls_client_param_t **tls_client_param_getptr(tls_client_params_t **params,
const struct sockaddr *addr, bool do_insert);
/** Get a pointer to TLS auth params or NULL. */
static inline tls_client_param_t *
tls_client_param_get(tls_client_params_t *params, const struct sockaddr *addr)
{
tls_client_param_t **pe = tls_client_param_getptr(&params, addr, false);
return pe ? *pe : NULL;
}
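/* Minimal usage sketch (illustrative, with assumptions): inserting parameters
 * for one upstream address.  `addr` is assumed to be a sockaddr prepared by
 * the caller; error handling is omitted.
 *
 *	tls_client_params_t *params = NULL;    // NULL also means "empty"
 *	tls_client_param_t **pe = tls_client_param_getptr(&params, addr, true);
 *	if (pe != NULL && *pe == NULL)
 *		*pe = tls_client_param_new();  // starts with ->refs == 1
 *	if (pe != NULL && *pe != NULL)
 *		(*pe)->insecure = true;        // e.g. no authentication
 */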
/** Allocate and initialize the structure (with ->ref = 1). */
tls_client_param_t * tls_client_param_new(void);
/** Reference-counted free(); any inside data is freed alongside. */
void tls_client_param_unref(tls_client_param_t *entry);
int tls_client_param_remove(tls_client_params_t *params, const struct sockaddr *addr);
/** Free TLS authentication parameters. */
void tls_client_params_free(tls_client_params_t *params);
/*! Set TLS certificate and key from files. */
int tls_certificate_set(const char *tls_cert, const char *tls_key);
/*! Release TLS credentials for context (decrements refcount or frees). */
int tls_credentials_release(struct tls_credentials *tls_credentials);
/*! Generate new ephemeral TLS credentials. */
struct tls_credentials * tls_get_ephemeral_credentials(void);
/* Session tickets, server side. Implementation in ./tls_session_ticket-srv.c */
/*! Opaque struct used by tls_session_ticket_* functions. */
struct tls_session_ticket_ctx;
/*! Suggested maximum reasonable secret length. */
#define TLS_SESSION_TICKET_SECRET_MAX_LEN 1024
/*! Create a session ticket context and initialize it (secret gets copied inside).
*
* Passing zero-length secret implies using a random key, i.e. not synchronized
* between multiple instances.
*
* Beware that knowledge of the secret (if nonempty) breaks forward secrecy,
* so you should rotate the secret regularly and securely erase all past secrets.
* With TLS < 1.3 it's probably too risky to set nonempty secret.
*/
struct tls_session_ticket_ctx * tls_session_ticket_ctx_create(
uv_loop_t *loop, const char *secret, size_t secret_len);
/*! Try to enable session tickets for a server session. */
void tls_session_ticket_enable(struct tls_session_ticket_ctx *ctx, gnutls_session_t session);
/*! Free all resources of the session ticket context. NULL is accepted as well. */
void tls_session_ticket_ctx_destroy(struct tls_session_ticket_ctx *ctx);
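/* Illustrative lifecycle sketch (not part of the build): a single instance
 * that does not share tickets with others can pass an empty secret, so a
 * random per-epoch key is used.
 *
 *	struct tls_session_ticket_ctx *tst =
 *		tls_session_ticket_ctx_create(uv_default_loop(), NULL, 0);
 *	// for each accepted server-side gnutls_session_t `session`:
 *	if (tst)
 *		tls_session_ticket_enable(tst, session);
 *	// on shutdown:
 *	tls_session_ticket_ctx_destroy(tst);
 */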
/*! Free TLS credentials. */
void tls_credentials_free(struct tls_credentials *tls_credentials);
/*
* Copyright (C) 2016 American Civil Liberties Union (ACLU)
* Copyright (C) CZ.NIC, z.s.p.o.
*
* Initial Author: Daniel Kahn Gillmor <dkg@fifthhorseman.net>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include <sys/file.h>
#include <unistd.h>
#include <gnutls/gnutls.h>
#include <gnutls/x509.h>
#include <gnutls/crypto.h>
#include "daemon/engine.h"
#include "daemon/tls.h"
#define EPHEMERAL_PRIVKEY_FILENAME "ephemeral_key.pem"
#define INVALID_HOSTNAME "dns-over-tls.invalid"
#define EPHEMERAL_CERT_EXPIRATION_SECONDS ((time_t)60*60*24*90)
/* This is an attempt to grab an exclusive, advisory, non-blocking
* lock based on a filename. At the moment it's POSIX-only, but it
* should be abstract enough of an interface to make an implementation
* for non-posix systems if anyone cares. */
typedef int lock_t;
static bool lock_is_invalid(lock_t lock)
{
return lock == -1;
}
/* Try to take a non-blocking exclusive lock on a given filename. */
static lock_t lock_filename(const char *fname)
{
lock_t lockfd = open(fname, O_RDONLY|O_CREAT, 0400);
if (lockfd == -1)
return lockfd;
/* this should be a non-blocking lock */
if (flock(lockfd, LOCK_EX | LOCK_NB) != 0) {
close(lockfd);
return -1;
}
return lockfd; /* for cleanup later */
}
static void lock_unlock(lock_t *lock, const char *fname)
{
if (lock && !lock_is_invalid(*lock)) {
flock(*lock, LOCK_UN);
close(*lock);
*lock = -1;
unlink(fname); /* ignore errors */
}
}
static gnutls_x509_privkey_t get_ephemeral_privkey (void)
{
gnutls_x509_privkey_t privkey = NULL;
int err;
gnutls_datum_t data = { .data = NULL, .size = 0 };
lock_t lock;
int datafd = -1;
/* Take a lock to ensure that two daemons started concurrently
* with a shared cache don't both create the same privkey: */
lock = lock_filename(EPHEMERAL_PRIVKEY_FILENAME ".lock");
if (lock_is_invalid(lock)) {
kr_log_error(TLS, "unable to lock lockfile " EPHEMERAL_PRIVKEY_FILENAME ".lock\n");
goto done;
}
if ((err = gnutls_x509_privkey_init (&privkey)) < 0) {
kr_log_error(TLS, "gnutls_x509_privkey_init() failed: %d (%s)\n",
err, gnutls_strerror_name(err));
goto done;
}
	/* Read from the cache file (we assume that we've chdir'ed
	 * already, so we're just looking for the file in the
	 * cachedir). */
datafd = open(EPHEMERAL_PRIVKEY_FILENAME, O_RDONLY);
if (datafd != -1) {
struct stat stat;
ssize_t bytes_read;
if (fstat(datafd, &stat)) {
kr_log_error(TLS, "unable to stat ephemeral private key " EPHEMERAL_PRIVKEY_FILENAME "\n");
goto bad_data;
}
data.data = gnutls_malloc(stat.st_size);
if (data.data == NULL) {
kr_log_error(TLS, "unable to allocate memory for reading ephemeral private key\n");
goto bad_data;
}
data.size = stat.st_size;
bytes_read = read(datafd, data.data, stat.st_size);
if (bytes_read < 0 || bytes_read != stat.st_size) {
kr_log_error(TLS, "unable to read ephemeral private key\n");
goto bad_data;
}
if ((err = gnutls_x509_privkey_import (privkey, &data, GNUTLS_X509_FMT_PEM)) < 0) {
kr_log_error(TLS, "gnutls_x509_privkey_import() failed: %d (%s)\n",
err, gnutls_strerror_name(err));
/* goto bad_data; */
bad_data:
close(datafd);
datafd = -1;
}
if (data.data != NULL) {
gnutls_free(data.data);
data.data = NULL;
}
}
if (datafd == -1) {
/* if loading failed, then generate ... */
#if GNUTLS_VERSION_NUMBER >= 0x030500
if ((err = gnutls_x509_privkey_generate(privkey, GNUTLS_PK_ECDSA, GNUTLS_CURVE_TO_BITS(GNUTLS_ECC_CURVE_SECP256R1), 0)) < 0) {
#else
if ((err = gnutls_x509_privkey_generate(privkey, GNUTLS_PK_RSA, gnutls_sec_param_to_pk_bits(GNUTLS_PK_RSA, GNUTLS_SEC_PARAM_MEDIUM), 0)) < 0) {
#endif
kr_log_error(TLS, "gnutls_x509_privkey_init() failed: %d (%s)\n",
err, gnutls_strerror_name(err));
gnutls_x509_privkey_deinit(privkey);
goto done;
}
/* ... and save */
kr_log_info(TLS, "Stashing ephemeral private key in " EPHEMERAL_PRIVKEY_FILENAME "\n");
if ((err = gnutls_x509_privkey_export2(privkey, GNUTLS_X509_FMT_PEM, &data)) < 0) {
kr_log_error(TLS, "gnutls_x509_privkey_export2() failed: %d (%s), not storing\n",
err, gnutls_strerror_name(err));
} else {
datafd = open(EPHEMERAL_PRIVKEY_FILENAME, O_WRONLY|O_CREAT, 0600);
if (datafd == -1) {
kr_log_error(TLS, "failed to open " EPHEMERAL_PRIVKEY_FILENAME " to store the ephemeral key\n");
} else {
ssize_t bytes_written;
bytes_written = write(datafd, data.data, data.size);
if (bytes_written != data.size)
kr_log_error(TLS, "failed to write %d octets to "
EPHEMERAL_PRIVKEY_FILENAME
" (%zd written)\n",
data.size, bytes_written);
}
}
}
done:
lock_unlock(&lock, EPHEMERAL_PRIVKEY_FILENAME ".lock");
if (datafd != -1) {
close(datafd);
}
if (data.data != NULL) {
gnutls_free(data.data);
}
return privkey;
}
static gnutls_x509_crt_t get_ephemeral_cert(gnutls_x509_privkey_t privkey, const char *servicename, time_t invalid_before, time_t valid_until)
{
gnutls_x509_crt_t cert = NULL;
int err;
/* need a random buffer of bytes */
uint8_t serial[16];
gnutls_rnd(GNUTLS_RND_NONCE, serial, sizeof(serial));
/* clear the left-most bit to avoid signedness confusion: */
serial[0] &= 0x7f;
size_t namelen = strlen(servicename);
#define gtx(fn, ...) \
if ((err = fn ( __VA_ARGS__ )) != GNUTLS_E_SUCCESS) { \
kr_log_error(TLS, #fn "() failed: %d (%s)\n", \
err, gnutls_strerror_name(err)); \
goto bad; }
gtx(gnutls_x509_crt_init, &cert);
gtx(gnutls_x509_crt_set_activation_time, cert, invalid_before);
gtx(gnutls_x509_crt_set_ca_status, cert, 0);
gtx(gnutls_x509_crt_set_expiration_time, cert, valid_until);
gtx(gnutls_x509_crt_set_key, cert, privkey);
gtx(gnutls_x509_crt_set_key_purpose_oid, cert, GNUTLS_KP_TLS_WWW_CLIENT, 0);
gtx(gnutls_x509_crt_set_key_purpose_oid, cert, GNUTLS_KP_TLS_WWW_SERVER, 0);
gtx(gnutls_x509_crt_set_key_usage, cert, GNUTLS_KEY_DIGITAL_SIGNATURE);
gtx(gnutls_x509_crt_set_serial, cert, serial, sizeof(serial));
gtx(gnutls_x509_crt_set_subject_alt_name, cert, GNUTLS_SAN_DNSNAME, servicename, namelen, GNUTLS_FSAN_SET);
gtx(gnutls_x509_crt_set_dn_by_oid,cert, GNUTLS_OID_X520_COMMON_NAME, 0, servicename, namelen);
gtx(gnutls_x509_crt_set_version, cert, 3);
gtx(gnutls_x509_crt_sign2,cert, cert, privkey, GNUTLS_DIG_SHA256, 0); /* self-sign, since it doesn't look like we can just stub-sign */
#undef gtx
return cert;
bad:
gnutls_x509_crt_deinit(cert);
return NULL;
}
/*! Generate new ephemeral TLS credentials. */
struct tls_credentials * tls_get_ephemeral_credentials(void)
{
struct tls_credentials *creds = NULL;
gnutls_x509_privkey_t privkey = NULL;
gnutls_x509_crt_t cert = NULL;
int err;
time_t now = time(NULL);
creds = calloc(1, sizeof(*creds));
if (!creds) {
kr_log_error(TLS, "failed to allocate memory for ephemeral credentials\n");
return NULL;
}
if ((err = gnutls_certificate_allocate_credentials(&(creds->credentials))) < 0) {
kr_log_error(TLS, "failed to allocate memory for ephemeral credentials\n");
goto failure;
}
creds->valid_until = now + EPHEMERAL_CERT_EXPIRATION_SECONDS;
creds->ephemeral_servicename = strdup(engine_get_hostname());
if (creds->ephemeral_servicename == NULL) {
kr_log_error(TLS, "could not get server's hostname, using '" INVALID_HOSTNAME "' instead\n");
if ((creds->ephemeral_servicename = strdup(INVALID_HOSTNAME)) == NULL) {
kr_log_error(TLS, "failed to allocate memory for ephemeral credentials\n");
goto failure;
}
}
if ((privkey = get_ephemeral_privkey()) == NULL) {
goto failure;
}
if ((cert = get_ephemeral_cert(privkey, creds->ephemeral_servicename, now - ((time_t)60 * 15), creds->valid_until)) == NULL) {
goto failure;
}
if ((err = gnutls_certificate_set_x509_key(creds->credentials, &cert, 1, privkey)) < 0) {
kr_log_error(TLS, "failed to set up ephemeral credentials\n");
goto failure;
}
gnutls_x509_privkey_deinit(privkey);
gnutls_x509_crt_deinit(cert);
return creds;
failure:
gnutls_x509_privkey_deinit(privkey);
gnutls_x509_crt_deinit(cert);
tls_credentials_free(creds);
return NULL;
}
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include <inttypes.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <gnutls/gnutls.h>
#include <gnutls/crypto.h>
#include <uv.h>
#include "lib/utils.h"
/* Style: "local/static" identifiers are usually named tst_* */
/** The number of seconds between synchronized rotation of TLS session ticket key. */
#define TST_KEY_LIFETIME 4096
/** Value from gnutls:lib/ext/session_ticket.c
* Beware: changing this needs to change the hashing implementation. */
#define SESSION_KEY_SIZE 64
/** Compile-time support for setting the secret. */
/* This is not secure with TLS <= 1.2 but TLS 1.3 and secure configuration
* is not available in GnuTLS yet. See https://gitlab.com/gnutls/gnutls/issues/477 */
#define TLS_SESSION_RESUMPTION_SYNC (GNUTLS_VERSION_NUMBER >= 0x030603)
#if TLS_SESSION_RESUMPTION_SYNC
#define TST_HASH GNUTLS_DIG_SHA3_512
#else
#define TST_HASH abort()
#endif
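/* Sketch of the derivation implemented below, assuming a nonempty secret was
 * configured: every instance computes
 *
 *	epoch = now / TST_KEY_LIFETIME;        // integer division
 *	key   = SHA3-512(epoch || secret);     // SESSION_KEY_SIZE bytes
 *
 * so instances sharing the secret (and the same time_t width and endianness)
 * derive identical ticket keys and rotate them in the same 4096-second
 * windows.  With an empty secret the key is simply random per epoch. */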
/** Fields are internal to tst_key_* functions. */
typedef struct tls_session_ticket_ctx {
uv_timer_t timer; /**< timer for rotation of the key */
unsigned char key[SESSION_KEY_SIZE]; /**< the key itself */
bool has_secret; /**< false -> key is random for each epoch */
uint16_t hash_len; /**< length of `hash_data` */
char hash_data[]; /**< data to hash to obtain `key`;
* it's `time_t epoch` and then the secret string */
} tst_ctx_t;
/** Check invariants, based on gnutls version. */
static bool tst_key_invariants(void)
{
static int result = 0; /*< cache for multiple invocations */
if (result) return result > 0;
bool ok = true;
#if TLS_SESSION_RESUMPTION_SYNC
/* SHA3-512 output size may never change, but let's check it anyway :-) */
ok = ok && gnutls_hash_get_len(TST_HASH) == SESSION_KEY_SIZE;
#endif
/* The ticket key size might change in a different gnutls version. */
gnutls_datum_t key = { 0, 0 };
ok = ok && gnutls_session_ticket_key_generate(&key) == 0
&& key.size == SESSION_KEY_SIZE;
free(key.data);
result = ok ? 1 : -1;
return ok;
}
/** Create the internal structures and copy the secret. Beware: secret must be kept secure. */
static tst_ctx_t * tst_key_create(const char *secret, size_t secret_len, uv_loop_t *loop)
{
const size_t hash_len = sizeof(time_t) + secret_len;
if (kr_fails_assert(!secret_len || (secret && hash_len >= secret_len && hash_len <= UINT16_MAX))) {
return NULL;
/* reasonable secret_len is best enforced in config API */
}
if (kr_fails_assert(tst_key_invariants()))
return NULL;
#if !TLS_SESSION_RESUMPTION_SYNC
if (secret_len) {
kr_log_error(TLS, "session ticket: secrets were not enabled at compile-time (your GnuTLS version is not supported)\n");
return NULL; /* ENOTSUP */
}
#endif
tst_ctx_t *ctx = malloc(sizeof(*ctx) + hash_len); /* can be slightly longer */
if (!ctx) return NULL;
ctx->has_secret = secret_len > 0;
ctx->hash_len = hash_len;
if (secret_len) {
memcpy(ctx->hash_data + sizeof(time_t), secret, secret_len);
}
if (uv_timer_init(loop, &ctx->timer) != 0) {
free(ctx);
return NULL;
}
ctx->timer.data = ctx;
return ctx;
}
/** Random variant of secret rotation: generate into key_tmp and copy. */
static int tst_key_get_random(tst_ctx_t *ctx)
{
gnutls_datum_t key_tmp = { NULL, 0 };
int err = gnutls_session_ticket_key_generate(&key_tmp);
if (err) return kr_error(err);
if (kr_fails_assert(key_tmp.size == SESSION_KEY_SIZE))
return kr_error(EFAULT);
memcpy(ctx->key, key_tmp.data, SESSION_KEY_SIZE);
gnutls_memset(key_tmp.data, 0, SESSION_KEY_SIZE);
free(key_tmp.data);
return kr_ok();
}
/** Recompute the session ticket key, if epoch has changed or forced. */
static int tst_key_update(tst_ctx_t *ctx, time_t epoch, bool force_update)
{
if (kr_fails_assert(ctx && ctx->hash_len >= sizeof(epoch)))
return kr_error(EINVAL);
/* documented limitation: time_t and endianness must match
* on instances sharing a secret */
if (!force_update && memcmp(ctx->hash_data, &epoch, sizeof(epoch)) == 0) {
return kr_ok(); /* we are up to date */
}
memcpy(ctx->hash_data, &epoch, sizeof(epoch));
if (!ctx->has_secret) {
return tst_key_get_random(ctx);
}
/* Otherwise, deterministic variant of secret rotation, if supported. */
#if !TLS_SESSION_RESUMPTION_SYNC
kr_assert(!ENOTSUP);
return kr_error(ENOTSUP);
#else
int err = gnutls_hash_fast(TST_HASH, ctx->hash_data,
ctx->hash_len, ctx->key);
return err == 0 ? kr_ok() : kr_error(err);
#endif
}
/** Free all resources of the key (securely). */
static void tst_key_destroy(uv_handle_t *timer)
{
if (kr_fails_assert(timer))
return;
tst_ctx_t *ctx = timer->data;
if (kr_fails_assert(ctx))
return;
gnutls_memset(ctx, 0, offsetof(tst_ctx_t, hash_data) + ctx->hash_len);
free(ctx);
}
static void tst_key_check(uv_timer_t *timer, bool force_update);
static void tst_timer_callback(uv_timer_t *timer)
{
tst_key_check(timer, false);
}
/** Update the ST key if needed and reschedule itself via the timer. */
static void tst_key_check(uv_timer_t *timer, bool force_update)
{
tst_ctx_t *stst = (tst_ctx_t *)timer->data;
/* Compute the current epoch. */
struct timeval now;
if (gettimeofday(&now, NULL)) {
kr_log_error(TLS, "session ticket: gettimeofday failed, %s\n",
strerror(errno));
return;
}
uv_update_time(timer->loop); /* to have sync. between real and mono time */
const time_t epoch = now.tv_sec / TST_KEY_LIFETIME;
/* Update the key; new sessions will fetch it from the location.
* Old ones hopefully can't get broken by that; documentation
* for gnutls_session_ticket_enable_server() doesn't say. */
int err = tst_key_update(stst, epoch, force_update);
if (err) {
kr_log_error(TLS, "session ticket: failed rotation, %s\n",
kr_strerror(err));
if (kr_fails_assert(err != kr_error(EINVAL)))
return;
}
/* Reschedule. */
const time_t tv_sec_next = (epoch + 1) * TST_KEY_LIFETIME;
const uint64_t ms_until_second = 1000 - (now.tv_usec + 501) / 1000;
const uint64_t remain_ms = (tv_sec_next - now.tv_sec - 1) * (uint64_t)1000
+ ms_until_second + 1;
/* ^ +1 because we don't want to wake up half a millisecond before the epoch! */
if (kr_fails_assert(remain_ms < ((uint64_t)TST_KEY_LIFETIME + 1 /*rounding tolerance*/) * 1000))
return;
kr_log_debug(TLS, "session ticket: epoch %"PRIu64
", scheduling rotation check in %"PRIu64" ms\n",
(uint64_t)epoch, remain_ms);
err = uv_timer_start(timer, &tst_timer_callback, remain_ms, 0);
if (kr_fails_assert(err == 0)) {
kr_log_error(TLS, "session ticket: failed to schedule, %s\n",
uv_strerror(err));
return;
}
}
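/* Worked example of the rescheduling arithmetic above (numbers are made up):
 * with now.tv_sec = 10000, now.tv_usec = 250000 and TST_KEY_LIFETIME = 4096,
 * epoch = 10000 / 4096 = 2 and tv_sec_next = 3 * 4096 = 12288;
 * ms_until_second = 1000 - (250000 + 501) / 1000 = 750, so
 * remain_ms = (12288 - 10000 - 1) * 1000 + 750 + 1 = 2287751 ms,
 * i.e. the timer fires about 1 ms after the next epoch boundary at 12288 s. */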
/* Implementation for prototypes from ./tls.h */
void tls_session_ticket_enable(struct tls_session_ticket_ctx *ctx, gnutls_session_t session)
{
if (kr_fails_assert(ctx && session))
return;
const gnutls_datum_t gd = {
.size = SESSION_KEY_SIZE,
.data = ctx->key,
};
int err = gnutls_session_ticket_enable_server(session, &gd);
if (err) {
kr_log_error(TLS, "failed to enable session tickets: %s (%d)\n",
gnutls_strerror_name(err), err);
/* but continue without tickets */
}
}
tst_ctx_t * tls_session_ticket_ctx_create(uv_loop_t *loop, const char *secret,
size_t secret_len)
{
if (kr_fails_assert(loop && (!secret_len || secret)))
return NULL;
#if GNUTLS_VERSION_NUMBER < 0x030500
/* We would need different SESSION_KEY_SIZE; avoid an error. */
return NULL;
#endif
tst_ctx_t *ctx = tst_key_create(secret, secret_len, loop);
if (ctx) {
tst_key_check(&ctx->timer, true);
}
return ctx;
}
void tls_session_ticket_ctx_destroy(tst_ctx_t *ctx)
{
if (ctx == NULL) {
return;
}
uv_close((uv_handle_t *)&ctx->timer, &tst_key_destroy);
}
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include "kresconfig.h"
#include "daemon/udp_queue.h"
#include "daemon/session2.h"
#include "lib/generic/array.h"
#include "lib/utils.h"
struct qr_task;
#include <sys/socket.h>
#if !ENABLE_SENDMMSG
int udp_queue_init_global(uv_loop_t *loop)
{
return 0;
}
/* Appease the linker in case this unused call isn't optimized out. */
void udp_queue_push(int fd, const struct sockaddr *sa, char *buf, size_t buf_len,
udp_queue_cb cb, void *baton)
{
abort();
}
void udp_queue_send_all(void)
{
}
#else
/* LATER: it might be useful to have this configurable during runtime,
* but the structures below would have to change a little (broken up). */
#define UDP_QUEUE_LEN 64
/** A queue of up to UDP_QUEUE_LEN messages, meant for the same socket. */
typedef struct {
int len; /**< The number of messages in the queue: 0..UDP_QUEUE_LEN */
struct mmsghdr msgvec[UDP_QUEUE_LEN]; /**< Parameter for sendmmsg() */
struct {
udp_queue_cb cb;
void *cb_baton;
struct iovec msg_iov[1]; /**< storage for .msgvec[i].msg_iov */
} items[UDP_QUEUE_LEN];
} udp_queue_t;
static udp_queue_t * udp_queue_create(void)
{
udp_queue_t *q = calloc(1, sizeof(*q));
kr_require(q != NULL);
for (int i = 0; i < UDP_QUEUE_LEN; ++i) {
struct msghdr *mhi = &q->msgvec[i].msg_hdr;
/* These shall remain always the same. */
mhi->msg_iov = q->items[i].msg_iov;
mhi->msg_iovlen = 1;
/* msg_name and msg_namelen will be per-call,
* and the rest is OK to remain zeroed all the time. */
}
return q;
}
/** Global state for udp_queue_*. Note: we never free the pointed-to memory. */
struct state {
/** Singleton map: fd -> udp_queue_t, as a simple array of pointers. */
udp_queue_t **udp_queues;
int udp_queues_len;
/** List of FD numbers that might have a non-empty queue. */
array_t(int) waiting_fds;
uv_check_t check_handle;
};
static struct state state = {0};
/** Empty the given queue. The queue is assumed to exist (but may be empty). */
static void udp_queue_send(int fd)
{
udp_queue_t *const q = state.udp_queues[fd];
if (!q->len) return;
int sent_len = sendmmsg(fd, q->msgvec, q->len, 0);
/* ATM we don't really do anything about failures. */
int err = sent_len < 0 ? errno : EAGAIN /* unknown error, really */;
for (int i = 0; i < q->len; ++i) {
if (q->items[i].cb)
q->items[i].cb(i < sent_len ? 0 : err, q->items[i].cb_baton);
}
q->len = 0;
}
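/* Status mapping example (follows directly from the loop above): if
 * sendmmsg() reports that 3 of 5 queued messages were sent, callbacks for
 * items 0..2 receive status 0 and items 3..4 receive `err` (EAGAIN when the
 * failure reason is unknown); the queue is emptied either way. */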
/** Send all queued packets. */
void udp_queue_send_all(void)
{
for (int i = 0; i < state.waiting_fds.len; ++i) {
udp_queue_send(state.waiting_fds.at[i]);
}
state.waiting_fds.len = 0;
}
/** Periodic callback to send all queued packets. */
static void udp_queue_check(uv_check_t *handle)
{
udp_queue_send_all();
}
int udp_queue_init_global(uv_loop_t *loop)
{
int ret = uv_check_init(loop, &state.check_handle);
if (!ret) ret = uv_check_start(&state.check_handle, udp_queue_check);
return ret;
}
void udp_queue_push(int fd, const struct sockaddr *sa, char *buf, size_t buf_len,
udp_queue_cb cb, void *baton)
{
if (fd < 0) {
kr_log_error(SYSTEM, "ERROR: called udp_queue_push(fd = %d, ...)\n", fd);
abort();
}
	/* Get a valid queue for the fd, growing the array if needed. */
if (fd >= state.udp_queues_len) {
const int new_len = fd + 1;
state.udp_queues = realloc(state.udp_queues, // NOLINT(bugprone-suspicious-realloc-usage): we just abort() below, so it's fine
sizeof(state.udp_queues[0]) * new_len); // NOLINT(bugprone-sizeof-expression): false-positive
if (!state.udp_queues) abort();
memset(state.udp_queues + state.udp_queues_len, 0,
sizeof(state.udp_queues[0]) * (new_len - state.udp_queues_len)); // NOLINT(bugprone-sizeof-expression): false-positive
state.udp_queues_len = new_len;
}
if (unlikely(state.udp_queues[fd] == NULL))
state.udp_queues[fd] = udp_queue_create();
udp_queue_t *const q = state.udp_queues[fd];
/* Append to the queue */
q->msgvec[q->len].msg_hdr.msg_name = (void *)sa;
q->msgvec[q->len].msg_hdr.msg_namelen = kr_sockaddr_len(sa);
q->items[q->len].cb = cb;
q->items[q->len].cb_baton = baton;
q->items[q->len].msg_iov[0] = (struct iovec){
.iov_base = buf,
.iov_len = buf_len,
};
if (q->len == 0)
array_push(state.waiting_fds, fd);
++(q->len);
if (q->len >= UDP_QUEUE_LEN) {
kr_assert(q->len == UDP_QUEUE_LEN);
udp_queue_send(fd);
		/* We leave the fd in state.waiting_fds; the periodic hook
		 * will simply find this queue empty, which is cheaper than
		 * removing the entry here. */
}
}
#endif
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#pragma once
#include <uv.h>
struct kr_request;
struct qr_task;
typedef void (*udp_queue_cb)(int status, void *baton);
/** Initialize the global state for udp_queue. */
int udp_queue_init_global(uv_loop_t *loop);
/** Send req->answer via UDP, possibly not immediately. */
void udp_queue_push(int fd, const struct sockaddr *sa, char *buf, size_t buf_len,
udp_queue_cb cb, void *baton);
/** Send all queued packets immediately. */
void udp_queue_send_all(void);
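/* Minimal usage sketch (illustrative; names below are hypothetical): queue
 * one UDP answer and optionally flush right away.  The wire buffer and the
 * destination sockaddr are assumed to stay valid until the packet is sent.
 *
 *	static void my_done_cb(int status, void *baton)   // hypothetical callback
 *	{
 *		if (status != 0)
 *			kr_log_error(SYSTEM, "udp send failed: %d\n", status);
 *	}
 *	...
 *	udp_queue_push(fd, dest_sa, wire, wire_len, my_done_cb, NULL); // cb may be NULL
 *	udp_queue_send_all();   // otherwise sent from a uv_check_t hook per loop turn
 */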
/* Copyright (C) 2014 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include "kresconfig.h"
#include "daemon/worker.h"
#include <uv.h>
#include <lua.h>
#include <lauxlib.h>
#include <libknot/packet/pkt.h>
#include <libknot/internal/net.h>
#include <libknot/descriptor.h>
#include <contrib/cleanup.h>
#include <contrib/ucw/lib.h>
#include <contrib/ucw/mempool.h>
#if defined(__GLIBC__) && defined(_GNU_SOURCE)
#include <malloc.h>
#endif
#include <sys/types.h>
#include <unistd.h>
#include <gnutls/gnutls.h>
#include "daemon/worker.h"
#if ENABLE_XDP
#include <libknot/xdp/xdp.h>
#endif
#include "daemon/bindings/api.h"
#include "daemon/engine.h"
#include "daemon/io.h"
#include "daemon/proxyv2.h"
#include "daemon/session2.h"
#include "daemon/tls.h"
#include "lib/cache/util.h" /* packet_ttl */
#include "lib/layer.h"
#include "lib/layer/iterate.h" /* kr_response_classify */
#include "lib/utils.h"
#include "daemon/defer.h"
/* @internal IO request entry. */
struct ioreq
/* Magic defaults for the worker. */
#ifndef MAX_PIPELINED
#define MAX_PIPELINED 100
#endif
#define MAX_DGRAM_LEN UINT16_MAX
#define VERBOSE_MSG(qry, ...) kr_log_q(qry, WORKER, __VA_ARGS__)
/** Client request state. */
struct request_ctx
{
union {
uv_udp_t udp;
uv_tcp_t tcp;
uv_udp_send_t send;
uv_write_t write;
uv_connect_t connect;
} as;
struct kr_request req;
struct qr_task *task;
struct {
/** NULL if the request didn't come over network. */
struct session2 *session;
/** Requestor's address; separate because of UDP session "sharing". */
union kr_sockaddr addr;
/** Request communication address; if not from a proxy, same as addr. */
union kr_sockaddr comm_addr;
/** Local address. For AF_XDP we couldn't use session's,
* as the address might be different every time. */
union kr_sockaddr dst_addr;
/** Router's MAC address for XDP. */
ethaddr_t eth_from;
/** Our MAC address for XDP. */
ethaddr_t eth_to;
/** Whether XDP was used. */
bool xdp : 1;
} source;
};
static inline struct ioreq *ioreq_take(struct worker_ctx *worker)
/** Query resolution task. */
struct qr_task
{
struct ioreq *req = NULL;
if (worker->ioreqs.len > 0) {
req = array_tail(worker->ioreqs);
array_pop(worker->ioreqs);
} else {
req = malloc(sizeof(*req));
struct request_ctx *ctx;
knot_pkt_t *pktbuf;
qr_tasklist_t waiting;
struct session2 *pending[MAX_PENDING];
uint16_t pending_count;
uint16_t timeouts;
uint16_t iter_count;
uint32_t refs;
bool finished : 1;
bool leading : 1;
uint64_t creation_time;
uint64_t send_time;
uint64_t recv_time;
struct kr_transport *transport;
};
/* Convenience macros */
#define qr_task_ref(task) \
do { ++(task)->refs; } while(0)
#define qr_task_unref(task) \
do { \
if (task) \
kr_require((task)->refs > 0); \
if ((task) && --(task)->refs == 0) \
qr_task_free((task)); \
} while (0)
struct pl_dns_stream_sess_data {
struct protolayer_data h;
bool single : 1; /**< True: Stream only allows a single packet */
bool produced : 1; /**< True: At least one packet has been produced */
bool connected : 1; /**< True: The stream is connected */
bool half_closed : 1; /**< True: EOF was received, the stream is half-closed */
};
/* Forward decls */
static void qr_task_free(struct qr_task *task);
static int qr_task_step(struct qr_task *task,
const struct sockaddr *packet_source,
knot_pkt_t *packet);
static int qr_task_send(struct qr_task *task, struct session2 *session,
const struct comm_info *comm, knot_pkt_t *pkt);
static int qr_task_finalize(struct qr_task *task, int state);
static void qr_task_complete(struct qr_task *task);
static int worker_add_tcp_connected(const struct sockaddr* addr, struct session2 *session);
static int worker_del_tcp_connected(const struct sockaddr* addr);
static struct session2* worker_find_tcp_connected(const struct sockaddr* addr);
static int worker_add_tcp_waiting(const struct sockaddr* addr,
struct session2 *session);
static int worker_del_tcp_waiting(const struct sockaddr* addr);
static struct session2* worker_find_tcp_waiting(const struct sockaddr* addr);
static void subreq_finalize(struct qr_task *task, const struct sockaddr *packet_source, knot_pkt_t *pkt);
struct worker_ctx the_worker_value; /**< Static allocation is suitable for the singleton. */
struct worker_ctx *the_worker = NULL;
static inline void defer_sample_task(const struct qr_task *task)
{
if (task && task->ctx->source.session) {
defer_sample_addr(&task->ctx->source.addr, task->ctx->source.session->stream);
defer_sample_state.price_factor16 = task->ctx->req.qsource.price_factor16;
}
return req;
}
static inline void ioreq_release(struct worker_ctx *worker, struct ioreq *req)
/*! @internal Create a UDP/TCP handle for an outgoing AF_INET* connection.
* socktype is SOCK_* */
static struct session2 *ioreq_spawn(int socktype, sa_family_t family,
enum kr_proto grp,
struct protolayer_data_param *layer_param,
size_t layer_param_count)
{
if (!req || worker->ioreqs.len < 4 * MP_FREELIST_SIZE) {
array_push(worker->ioreqs, req);
bool precond = (socktype == SOCK_DGRAM || socktype == SOCK_STREAM)
&& (family == AF_INET || family == AF_INET6);
if (kr_fails_assert(precond)) {
kr_log_debug(WORKER, "ioreq_spawn: pre-condition failed\n");
return NULL;
}
/* Create connection for iterative query */
struct session2 *s;
int ret = io_create(the_worker->loop, &s, socktype, family, grp,
layer_param, layer_param_count, true);
if (ret) {
if (ret == UV_EMFILE) {
the_worker->too_many_open = true;
the_worker->rconcurrent_highwatermark = the_worker->stats.rconcurrent;
}
return NULL;
}
/* Bind to outgoing address, according to IP v4/v6. */
union kr_sockaddr *addr;
if (family == AF_INET) {
addr = (union kr_sockaddr *)&the_worker->out_addr4;
} else {
free(req);
addr = (union kr_sockaddr *)&the_worker->out_addr6;
}
if (addr->ip.sa_family != AF_UNSPEC) {
if (kr_fails_assert(addr->ip.sa_family == family)) {
session2_force_close(s);
return NULL;
}
if (socktype == SOCK_DGRAM) {
uv_udp_t *udp = (uv_udp_t *)session2_get_handle(s);
ret = uv_udp_bind(udp, &addr->ip, 0);
} else if (socktype == SOCK_STREAM){
uv_tcp_t *tcp = (uv_tcp_t *)session2_get_handle(s);
ret = uv_tcp_bind(tcp, &addr->ip, 0);
}
}
if (ret != 0) {
session2_force_close(s);
return NULL;
}
/* Connect or issue query datagram */
return s;
}
/** @internal Query resolution task. */
struct qr_task
static void ioreq_kill_pending(struct qr_task *task)
{
struct kr_request req;
struct worker_ctx *worker;
knot_pkt_t *pktbuf;
uv_req_t *ioreq;
uv_handle_t *iohandle;
uv_timer_t timeout;
struct {
union {
struct sockaddr_in ip4;
struct sockaddr_in6 ip6;
} addr;
uv_handle_t *handle;
} source;
uint16_t iter_count;
uint16_t flags;
};
for (uint16_t i = 0; i < task->pending_count; ++i) {
session2_kill_ioreq(task->pending[i], task);
}
task->pending_count = 0;
}
/* Forward decls */
static int qr_task_step(struct qr_task *task, knot_pkt_t *packet);
/** Get a mempool. */
static inline struct mempool *pool_borrow(void)
{
	/* The implementation used to have an extra caching layer,
	 * but it didn't work well.  Now it's very simple. */
return mp_new((size_t)16 * 1024);
}
/** Return a mempool. */
static inline void pool_release(struct mempool *mp)
{
mp_delete(mp);
}
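/* Usage sketch (mirrors request_create()/request_free() below): the borrowed
 * mempool is wrapped into a knot_mm_t so the usual allocators work on it.
 *
 *	knot_mm_t pool = {
 *		.ctx = pool_borrow(),
 *		.alloc = (knot_mm_alloc_t) mp_alloc
 *	};
 *	void *obj = mm_alloc(&pool, 64);   // per-request allocation
 *	// ... when the request is torn down:
 *	pool_release(pool.ctx);
 */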
/** @internal Get singleton worker. */
static inline struct worker_ctx *get_worker(void)
/** Create a key for an outgoing subrequest: qname, qclass, qtype.
* @param key Destination buffer for key size, MUST be SUBREQ_KEY_LEN or larger.
* @return key length if successful or an error
*/
static const size_t SUBREQ_KEY_LEN = KR_RRKEY_LEN;
static int subreq_key(char *dst, knot_pkt_t *pkt)
{
return uv_default_loop()->data;
kr_require(pkt);
return kr_rrkey(dst, knot_pkt_qclass(pkt), knot_pkt_qname(pkt),
knot_pkt_qtype(pkt), knot_pkt_qtype(pkt));
}
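/* Illustrative sketch (with assumptions): the key built above indexes
 * the_worker->subreq_out so that identical outgoing (qclass, qname, qtype)
 * subrequests can share one leading task, roughly:
 *
 *	char key[SUBREQ_KEY_LEN];
 *	const int klen = subreq_key(key, task->pktbuf);
 *	if (klen > 0) {
 *		trie_val_t *val = trie_get_ins(the_worker->subreq_out, key, klen);
 *		// if *val is already set, join that leader instead of resolving again
 *	}
 *
 * trie_get_ins() is the insert-or-get call from lib/generic/trie.h; this is
 * the same table that subreq_finalize() below deletes entries from. */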
static struct qr_task *qr_task_create(struct worker_ctx *worker, uv_handle_t *handle, knot_pkt_t *query, const struct sockaddr *addr)
#if ENABLE_XDP
static uint8_t *alloc_wire_cb(struct kr_request *req, uint16_t *maxlen)
{
/* How much can client handle? */
size_t answer_max = KNOT_WIRE_MIN_PKTSIZE;
size_t pktbuf_max = KR_EDNS_PAYLOAD;
if (!addr && handle) { /* TCP */
answer_max = KNOT_WIRE_MAX_PKTSIZE;
pktbuf_max = KNOT_WIRE_MAX_PKTSIZE;
} else if (knot_pkt_has_edns(query)) { /* EDNS */
answer_max = MAX(knot_edns_get_payload(query->opt_rr), KNOT_WIRE_MIN_PKTSIZE);
if (kr_fails_assert(maxlen))
return NULL;
struct request_ctx *ctx = (struct request_ctx *)req;
/* We know it's an AF_XDP socket; otherwise this CB isn't assigned. */
uv_handle_t *handle = session2_get_handle(ctx->source.session);
if (kr_fails_assert(handle->type == UV_POLL))
return NULL;
xdp_handle_data_t *xhd = handle->data;
knot_xdp_msg_t out;
bool ipv6 = ctx->source.comm_addr.ip.sa_family == AF_INET6;
int ret = knot_xdp_send_alloc(xhd->socket, ipv6 ? KNOT_XDP_MSG_IPV6 : 0, &out);
if (ret != KNOT_EOK) {
kr_assert(ret == KNOT_ENOMEM);
*maxlen = 0;
return NULL;
}
*maxlen = MIN(*maxlen, out.payload.iov_len);
return out.payload.iov_base;
}
static void free_wire(const struct request_ctx *ctx)
{
if (kr_fails_assert(ctx->req.alloc_wire_cb == alloc_wire_cb))
return;
knot_pkt_t *ans = ctx->req.answer;
if (unlikely(ans == NULL)) /* dropped */
return;
if (likely(ans->wire == NULL)) /* sent most likely */
return;
if (!ctx->source.session)
return;
/* We know it's an AF_XDP socket; otherwise alloc_wire_cb isn't assigned. */
uv_handle_t *handle = session2_get_handle(ctx->source.session);
if (!handle || kr_fails_assert(handle->type == UV_POLL))
return;
xdp_handle_data_t *xhd = handle->data;
/* Freeing is done by sending an empty packet (the API won't really send it). */
knot_xdp_msg_t out;
out.payload.iov_base = ans->wire;
out.payload.iov_len = 0;
uint32_t sent = 0;
int ret = 0;
knot_xdp_send_free(xhd->socket, &out, 1);
kr_assert(ret == KNOT_EOK && sent == 0);
kr_log_debug(XDP, "freed unsent buffer, ret = %d\n", ret);
}
#endif
/* Helper functions for transport selection */
static inline bool is_tls_capable(struct sockaddr *address) {
tls_client_param_t *tls_entry = tls_client_param_get(
the_network->tls_client_params, address);
return tls_entry;
}
static inline bool is_tcp_connected(struct sockaddr *address) {
return worker_find_tcp_connected(address);
}
static inline bool is_tcp_waiting(struct sockaddr *address) {
return worker_find_tcp_waiting(address);
}
/* Recycle available mempool if possible */
mm_ctx_t pool = {
.ctx = NULL,
.alloc = (mm_alloc_t) mp_alloc
/** Create and initialize a request_ctx (on a fresh mempool).
*
* session and addr point to the source of the request, and they are NULL
* in case the request didn't come from network.
*/
static struct request_ctx *request_create(struct session2 *session,
struct comm_info *comm,
uint32_t uid)
{
knot_mm_t pool = {
.ctx = pool_borrow(),
.alloc = (knot_mm_alloc_t) mp_alloc
};
if (worker->pools.len > 0) {
pool.ctx = array_tail(worker->pools);
array_pop(worker->pools);
} else { /* No mempool on the freelist, create new one */
pool.ctx = mp_new (4 * CPU_PAGE_SIZE);
/* Create request context */
struct request_ctx *ctx = mm_calloc(&pool, 1, sizeof(*ctx));
if (!ctx) {
pool_release(pool.ctx);
return NULL;
}
/* Create resolution task */
struct engine *engine = worker->engine;
struct qr_task *task = mm_alloc(&pool, sizeof(*task));
if (!task) {
mp_delete(pool.ctx);
/* TODO Relocate pool to struct request */
if (session && kr_fails_assert(session->outgoing == false)) {
pool_release(pool.ctx);
return NULL;
}
/* Create packet buffers for answer and subrequests */
task->req.pool = pool;
knot_pkt_t *pktbuf = knot_pkt_new(NULL, pktbuf_max, &task->req.pool);
knot_pkt_t *answer = knot_pkt_new(NULL, answer_max, &task->req.pool);
if (!pktbuf || !answer) {
mp_delete(pool.ctx);
ctx->source.session = session;
if (comm && comm->xdp) {
#if ENABLE_XDP
if (kr_fails_assert(session)) {
pool_release(pool.ctx);
return NULL;
}
memcpy(ctx->source.eth_to, comm->eth_to, sizeof(ctx->source.eth_to));
memcpy(ctx->source.eth_from, comm->eth_from, sizeof(ctx->source.eth_from));
ctx->req.alloc_wire_cb = alloc_wire_cb;
#else
kr_assert(!EINVAL);
pool_release(pool.ctx);
return NULL;
#endif
}
task->req.answer = answer;
task->pktbuf = pktbuf;
task->ioreq = NULL;
task->iohandle = NULL;
task->iter_count = 0;
task->flags = 0;
task->worker = worker;
task->source.handle = handle;
uv_timer_init(worker->loop, &task->timeout);
task->timeout.data = task;
if (addr) {
memcpy(&task->source.addr, addr, sockaddr_len(addr));
} else {
task->source.addr.ip4.sin_family = AF_UNSPEC;
struct kr_request *req = &ctx->req;
req->pool = pool;
req->vars_ref = LUA_NOREF;
req->uid = uid;
req->qsource.comm_flags.xdp = comm && comm->xdp;
req->qsource.price_factor16 = 1 << 16; // meaning *1.0
kr_request_set_extended_error(req, KNOT_EDNS_EDE_NONE, NULL);
array_init(req->qsource.headers);
if (session) {
kr_require(comm);
const struct sockaddr *src_addr = comm->src_addr;
const struct sockaddr *comm_addr = comm->comm_addr;
const struct sockaddr *dst_addr = comm->dst_addr;
const struct proxy_result *proxy = comm->proxy;
req->qsource.stream_id = -1;
session2_init_request(session, req);
req->qsource.flags = req->qsource.comm_flags;
if (proxy) {
req->qsource.flags.tcp = proxy->protocol == SOCK_STREAM;
req->qsource.flags.tls = proxy->has_tls;
}
/* We need to store a copy of peer address. */
memcpy(&ctx->source.addr.ip, src_addr, kr_sockaddr_len(src_addr));
req->qsource.addr = &ctx->source.addr.ip;
if (!comm_addr)
comm_addr = src_addr;
memcpy(&ctx->source.comm_addr.ip, comm_addr, kr_sockaddr_len(comm_addr));
req->qsource.comm_addr = &ctx->source.comm_addr.ip;
if (!dst_addr) /* We wouldn't have to copy in this case, but for consistency. */
dst_addr = session2_get_sockname(session);
memcpy(&ctx->source.dst_addr.ip, dst_addr, kr_sockaddr_len(dst_addr));
req->qsource.dst_addr = &ctx->source.dst_addr.ip;
}
/* Start resolution */
kr_resolve_begin(&task->req, &engine->resolver, answer);
worker->stats.concurrent += 1;
return task;
req->selection_context.is_tls_capable = is_tls_capable;
req->selection_context.is_tcp_connected = is_tcp_connected;
req->selection_context.is_tcp_waiting = is_tcp_waiting;
array_init(req->selection_context.forwarding_targets);
array_reserve_mm(req->selection_context.forwarding_targets, 1, kr_memreserve, &req->pool);
the_worker->stats.rconcurrent += 1;
return ctx;
}
static void qr_task_free(uv_handle_t *handle)
/** More initialization, related to the particular incoming query/packet. */
static int request_start(struct request_ctx *ctx, knot_pkt_t *query)
{
struct qr_task *task = handle->data;
/* Return handle to the event loop in case
* it was exclusively taken by this task. */
if (task->source.handle && !uv_has_ref(task->source.handle)) {
uv_ref(task->source.handle);
io_start_read(task->source.handle);
if (kr_fails_assert(query && ctx))
return kr_error(EINVAL);
struct kr_request *req = &ctx->req;
req->qsource.size = query->size;
if (knot_pkt_has_tsig(query)) {
req->qsource.size += query->tsig_wire.len;
}
/* Return mempool to ring or free it if it's full */
struct worker_ctx *worker = task->worker;
void *mp_context = task->req.pool.ctx;
if (worker->pools.len < MP_FREELIST_SIZE) {
mp_flush(mp_context);
array_push(worker->pools, mp_context);
} else {
mp_delete(mp_context);
knot_pkt_t *pkt = knot_pkt_new(NULL, req->qsource.size, &req->pool);
if (!pkt) {
return kr_error(ENOMEM);
}
/* Decommit memory every once in a while */
static int mp_delete_count = 0;
if (++mp_delete_count == 100000) {
lua_gc(worker->engine->L, LUA_GCCOLLECT, 0);
#if defined(__GLIBC__) && defined(_GNU_SOURCE)
malloc_trim(0);
#endif
mp_delete_count = 0;
int ret = knot_pkt_copy(pkt, query);
if (ret != KNOT_EOK && ret != KNOT_ETRAIL) {
return kr_error(ENOMEM);
}
req->qsource.packet = pkt;
/* Update stats */
worker->stats.concurrent -= 1;
/* Start resolution */
kr_resolve_begin(req, the_resolver);
the_worker->stats.queries += 1;
return kr_ok();
}
static void qr_task_timeout(uv_timer_t *req)
static void request_free(struct request_ctx *ctx)
{
struct qr_task *task = req->data;
if (!uv_is_closing((uv_handle_t *)req)) {
if (task->ioreq) { /* Invalidate pending IO request. */
task->ioreq->data = NULL;
}
qr_task_step(task, NULL);
/* Dereference any Lua vars table if exists */
if (ctx->req.vars_ref != LUA_NOREF) {
lua_State *L = the_engine->L;
/* Get worker variables table */
lua_rawgeti(L, LUA_REGISTRYINDEX, the_worker->vars_table_ref);
/* Get next free element (position 0) and store it under current reference (forming a list) */
lua_rawgeti(L, -1, 0);
lua_rawseti(L, -2, ctx->req.vars_ref);
/* Set current reference as the next free element */
lua_pushinteger(L, ctx->req.vars_ref);
lua_rawseti(L, -2, 0);
lua_pop(L, 1);
ctx->req.vars_ref = LUA_NOREF;
}
/* Free HTTP/2 headers for DoH requests. */
for(int i = 0; i < ctx->req.qsource.headers.len; i++) {
free(ctx->req.qsource.headers.at[i].name);
free(ctx->req.qsource.headers.at[i].value);
}
array_clear(ctx->req.qsource.headers);
/* Make sure to free XDP buffer in case it wasn't sent. */
if (ctx->req.alloc_wire_cb) {
#if ENABLE_XDP
free_wire(ctx);
#else
kr_assert(!EINVAL);
#endif
}
/* Return mempool to ring or free it if it's full */
pool_release(ctx->req.pool.ctx);
/* @note The 'task' is invalidated from now on. */
the_worker->stats.rconcurrent -= 1;
}
static int qr_task_on_send(struct qr_task *task, uv_handle_t *handle, int status)
static struct qr_task *qr_task_create(struct request_ctx *ctx)
{
if (task->req.state != KNOT_STATE_NOOP) {
if (status == 0 && handle) {
io_start_read(handle); /* Start reading answer */
}
} else { /* Finalize task */
uv_timer_stop(&task->timeout);
uv_close((uv_handle_t *)&task->timeout, qr_task_free);
/* Choose (initial) pktbuf size. As it is now, pktbuf can be used
* for UDP answers from upstream *and* from cache
* and for sending queries upstream */
uint16_t pktbuf_max = KR_EDNS_PAYLOAD;
const knot_rrset_t *opt_our = the_resolver->upstream_opt_rr;
if (opt_our) {
pktbuf_max = MAX(pktbuf_max, knot_edns_get_payload(opt_our));
}
return status;
/* Create resolution task */
struct qr_task *task = mm_calloc(&ctx->req.pool, 1, sizeof(*task));
if (!task) {
return NULL;
}
/* Create packet buffers for answer and subrequests */
knot_pkt_t *pktbuf = knot_pkt_new(NULL, pktbuf_max, &ctx->req.pool);
if (!pktbuf) {
mm_free(&ctx->req.pool, task);
return NULL;
}
pktbuf->size = 0;
task->ctx = ctx;
task->pktbuf = pktbuf;
array_init(task->waiting);
task->refs = 0;
kr_assert(ctx->task == NULL);
ctx->task = task;
/* Make the primary reference to task. */
qr_task_ref(task);
task->creation_time = kr_now();
the_worker->stats.concurrent += 1;
return task;
}
static void on_close(uv_handle_t *handle)
/* This is called when the task refcount is zero, free memory. */
static void qr_task_free(struct qr_task *task)
{
struct worker_ctx *worker = get_worker();
ioreq_release(worker, (struct ioreq *)handle);
struct request_ctx *ctx = task->ctx;
if (kr_fails_assert(ctx))
return;
kr_require(ctx->task == NULL);
request_free(ctx);
/* Update stats */
the_worker->stats.concurrent -= 1;
}
static void on_send(uv_udp_send_t *req, int status)
/** Register a new qr_task within a session. */
static int qr_task_register(struct qr_task *task, struct session2 *session)
{
struct worker_ctx *worker = get_worker();
struct qr_task *task = req->data;
if (task) {
qr_task_on_send(task, (uv_handle_t *)req->handle, status);
task->ioreq = NULL;
if (kr_fails_assert(!session->outgoing && session->stream))
return kr_error(EINVAL);
session2_tasklist_add(session, task);
struct request_ctx *ctx = task->ctx;
if (kr_fails_assert(ctx && (ctx->source.session == NULL || ctx->source.session == session)))
return kr_error(EINVAL);
ctx->source.session = session;
/* Soft-limit on parallel queries, there is no "slow down" RCODE
* that we could use to signalize to client, but we can stop reading,
* an in effect shrink TCP window size. To get more precise throttling,
* we would need to copy remainder of the unread buffer and reassemble
* when resuming reading. This is NYI. */
if (session2_tasklist_get_len(session) >= the_worker->tcp_pipeline_max &&
!session->throttled && !session->closing) {
session2_stop_read(session);
session->throttled = true;
}
ioreq_release(worker, (struct ioreq *)req);
return 0;
}
static void on_write(uv_write_t *req, int status)
static void qr_task_complete(struct qr_task *task)
{
struct worker_ctx *worker = get_worker();
struct qr_task *task = req->data;
if (task) {
qr_task_on_send(task, (uv_handle_t *)req->handle, status);
task->ioreq = NULL;
struct request_ctx *ctx = task->ctx;
/* Kill pending I/O requests */
ioreq_kill_pending(task);
kr_require(task->waiting.len == 0);
kr_require(task->leading == false);
struct session2 *s = ctx->source.session;
if (s) {
kr_require(!s->outgoing && session2_waitinglist_is_empty(s));
ctx->source.session = NULL;
session2_tasklist_del(s, task);
}
/* Release primary reference to task. */
if (ctx->task == task) {
ctx->task = NULL;
qr_task_unref(task);
}
ioreq_release(worker, (struct ioreq *)req);
}
static int qr_task_send(struct qr_task *task, uv_handle_t *handle, struct sockaddr *addr, knot_pkt_t *pkt)
/* This is called when we send subrequest / answer */
int qr_task_on_send(struct qr_task *task, struct session2 *s, int status)
{
int ret = 0;
if (!handle) {
return qr_task_on_send(task, handle, kr_error(EIO));
if (task->finished) {
kr_require(task->leading == false);
qr_task_complete(task);
}
struct ioreq *send_req = ioreq_take(task->worker);
if (!send_req) {
return qr_task_on_send(task, handle, kr_error(ENOMEM));
if (!s)
return status;
if (!s->stream && s->outgoing) {
// This should ensure that we are only dealing with our question to upstream
if (kr_fails_assert(!knot_wire_get_qr(task->pktbuf->wire)))
return status;
// start the timer
struct kr_query *qry = array_tail(task->ctx->req.rplan.pending);
if (kr_fails_assert(qry && task->transport))
return status;
size_t timeout = task->transport->timeout;
int ret = session2_timer_start(s, PROTOLAYER_EVENT_GENERAL_TIMEOUT,
timeout, 0);
/* Start next step with timeout, fatal if can't start a timer. */
if (ret != 0) {
subreq_finalize(task, &task->transport->address.ip, task->pktbuf);
qr_task_finalize(task, KR_STATE_FAIL);
}
}
/* Send using given protocol */
if (handle->type == UV_UDP) {
uv_buf_t buf = { (char *)pkt->wire, pkt->size };
send_req->as.send.data = task;
ret = uv_udp_send(&send_req->as.send, (uv_udp_t *)handle, &buf, 1, addr, &on_send);
} else {
uint16_t pkt_size = htons(pkt->size);
uv_buf_t buf[2] = {
{ (char *)&pkt_size, sizeof(pkt_size) },
{ (char *)pkt->wire, pkt->size }
};
send_req->as.write.data = task;
ret = uv_write(&send_req->as.write, (uv_stream_t *)handle, buf, 2, &on_write);
if (s->stream) {
if (status != 0) { // session probably not usable anymore; typically: ECONNRESET
const struct kr_request *req = &task->ctx->req;
if (kr_log_is_debug(WORKER, req)) {
const char *peer_str = NULL;
if (!s->outgoing) {
peer_str = "hidden"; // avoid logging downstream IPs
} else if (task->transport) {
peer_str = kr_straddr(&task->transport->address.ip);
}
if (!peer_str)
peer_str = "unknown"; // probably shouldn't happen
kr_log_req(req, 0, 0, WORKER,
"=> disconnected from '%s': %s\n",
peer_str, uv_strerror(status));
}
session2_force_close(s);
return status;
}
if (s->outgoing || s->closing)
return status;
if (s->throttled &&
session2_tasklist_get_len(s) < the_worker->tcp_pipeline_max/2) {
/* Start reading again if the session is throttled and
* the number of outgoing requests is below watermark. */
session2_start_read(s);
s->throttled = false;
}
}
if (ret == 0) {
task->ioreq = (uv_req_t *)send_req;
return status;
}
static void qr_task_wrap_finished(int status, struct session2 *session,
const struct comm_info *comm, void *baton)
{
struct qr_task *task = baton;
qr_task_on_send(task, session, status);
qr_task_unref(task);
}
static int qr_task_send(struct qr_task *task, struct session2 *session,
const struct comm_info *comm, knot_pkt_t *pkt)
{
if (!session)
return qr_task_on_send(task, NULL, kr_error(EIO));
int ret = 0;
if (pkt == NULL)
pkt = worker_task_get_pktbuf(task);
if (session->outgoing && session->stream) {
size_t try_limit = session2_tasklist_get_len(session) + 1;
uint16_t msg_id = knot_wire_get_id(pkt->wire);
size_t try_count = 0;
while (session2_tasklist_find_msgid(session, msg_id) &&
try_count <= try_limit) {
++msg_id;
++try_count;
}
if (try_count > try_limit)
return kr_error(ENOENT);
worker_task_pkt_set_msgid(task, msg_id);
}
/* Note time for upstream RTT */
task->send_time = kr_now();
task->recv_time = 0; // task structure is being reused so we have to zero this out here
/* Send using given protocol */
if (kr_fails_assert(!session->closing))
return qr_task_on_send(task, NULL, kr_error(EIO));
/* Pending '_finished' callback on current task */
qr_task_ref(task);
struct protolayer_payload payload = protolayer_payload_buffer(
(char *)pkt->wire, pkt->size, false);
payload.ttl = packet_ttl(pkt);
ret = session2_wrap(session, payload, comm, qr_task_wrap_finished, task);
if (ret >= 0) {
session2_touch(session);
if (session->outgoing) {
session2_tasklist_add(session, task);
}
if (the_worker->too_many_open &&
the_worker->stats.rconcurrent <
the_worker->rconcurrent_highwatermark - 10) {
the_worker->too_many_open = false;
}
ret = kr_ok();
} else {
ioreq_release(task->worker, send_req);
if (ret == UV_EMFILE) {
the_worker->too_many_open = true;
the_worker->rconcurrent_highwatermark = the_worker->stats.rconcurrent;
ret = kr_error(UV_EMFILE);
}
session2_event(session, PROTOLAYER_EVENT_STATS_SEND_ERR, NULL);
}
/* Update statistics */
if (handle != task->source.handle && addr) {
if (handle->type == UV_UDP)
task->worker->stats.udp += 1;
else
task->worker->stats.tcp += 1;
if (addr->sa_family == AF_INET6)
task->worker->stats.ipv6 += 1;
else
task->worker->stats.ipv4 += 1;
/* Update outgoing query statistics */
if (session->outgoing && comm) {
session2_event(session, PROTOLAYER_EVENT_STATS_QRY_OUT, NULL);
if (comm->comm_addr->sa_family == AF_INET6)
the_worker->stats.ipv6 += 1;
else if (comm->comm_addr->sa_family == AF_INET)
the_worker->stats.ipv4 += 1;
}
return ret;
}
static void on_connect(uv_connect_t *req, int status)
static struct kr_query *task_get_last_pending_query(struct qr_task *task)
{
struct worker_ctx *worker = get_worker();
struct qr_task *task = req->data;
if (task) {
task->ioreq = NULL;
if (status == 0) {
struct sockaddr_in6 addr;
int addrlen = sizeof(addr); /* Retrieve endpoint IP for statistics */
uv_stream_t *handle = req->handle;
uv_tcp_getpeername((uv_tcp_t *)handle, (struct sockaddr *)&addr, &addrlen);
qr_task_send(task, (uv_handle_t *)handle, (struct sockaddr *)&addr, task->pktbuf);
} else {
qr_task_step(task, NULL);
}
if (!task || task->ctx->req.rplan.pending.len == 0) {
return NULL;
}
ioreq_release(worker, (struct ioreq *)req);
return array_tail(task->ctx->req.rplan.pending);
}
static int qr_task_finalize(struct qr_task *task, int state)
static int send_waiting(struct session2 *session)
{
kr_resolve_finish(&task->req, state);
/* Send back answer */
(void) qr_task_send(task, task->source.handle, (struct sockaddr *)&task->source.addr, task->req.answer);
return state == KNOT_STATE_DONE ? 0 : kr_error(EIO);
if (session2_waitinglist_is_empty(session))
return 0;
defer_sample_state_t defer_prev_sample_state;
defer_sample_start(&defer_prev_sample_state);
int ret = 0;
do {
struct qr_task *t = session2_waitinglist_get(session);
defer_sample_task(t);
ret = qr_task_send(t, session, NULL, NULL);
defer_sample_restart();
if (ret != 0) {
struct sockaddr *peer = session2_get_peer(session);
session2_waitinglist_finalize(session, KR_STATE_FAIL);
session2_tasklist_finalize(session, KR_STATE_FAIL);
worker_del_tcp_connected(peer);
session2_close(session);
break;
}
session2_waitinglist_pop(session, true);
} while (!session2_waitinglist_is_empty(session));
defer_sample_stop(&defer_prev_sample_state, true);
return ret;
}
static int qr_task_step(struct qr_task *task, knot_pkt_t *packet)
static void on_connect(uv_connect_t *req, int status)
{
/* Close subrequest handle. */
uv_timer_stop(&task->timeout);
if (task->iohandle && !uv_is_closing(task->iohandle)) {
io_stop_read(task->iohandle);
uv_close(task->iohandle, on_close);
task->iohandle = NULL;
}
kr_require(the_worker);
uv_stream_t *handle = req->handle;
struct session2 *session = handle->data;
struct sockaddr *peer = session2_get_peer(session);
free(req);
/* Consume input and produce next query */
int sock_type = -1;
struct sockaddr *addr = NULL;
knot_pkt_t *pktbuf = task->pktbuf;
int state = kr_resolve_consume(&task->req, packet);
while (state == KNOT_STATE_PRODUCE) {
state = kr_resolve_produce(&task->req, &addr, &sock_type, pktbuf);
if (unlikely(++task->iter_count > KR_ITER_LIMIT)) {
return qr_task_finalize(task, KNOT_STATE_FAIL);
}
}
if (kr_fails_assert(session->outgoing))
return;
/* We're done, no more iterations needed */
if (state & (KNOT_STATE_DONE|KNOT_STATE_FAIL)) {
return qr_task_finalize(task, state);
} else if (!addr || sock_type < 0) {
return qr_task_step(task, NULL);
if (session->closing) {
worker_del_tcp_waiting(peer);
kr_assert(session2_is_empty(session));
return;
}
/* Create connection for iterative query */
uv_handle_t *subreq = (uv_handle_t *)ioreq_take(task->worker);
if (!subreq) {
return qr_task_finalize(task, KNOT_STATE_FAIL);
const bool log_debug = kr_log_is_debug(WORKER, NULL);
	/* Check whether the connection is in the waiting list.
	 * If not, this is most likely a timed-out connection which was
	 * removed from the waiting list by the
	 * on_tcp_connect_timeout() callback. */
struct session2 *found_session = worker_find_tcp_waiting(peer);
if (!found_session || found_session != session) {
		/* The session isn't on the waiting list,
		 * so it has already timed out. */
if (log_debug) {
const char *peer_str = kr_straddr(peer);
kr_log_debug(WORKER, "=> connected to '%s', but session "
"is already timed out, close\n",
peer_str ? peer_str : "");
}
kr_assert(session2_tasklist_is_empty(session));
session2_waitinglist_retry(session, false);
session2_close(session);
return;
}
io_create(task->worker->loop, subreq, sock_type);
subreq->data = task;
/* Connect or issue query datagram */
task->iohandle = subreq;
if (sock_type == SOCK_DGRAM) {
if (qr_task_send(task, subreq, addr, pktbuf) != 0) {
return qr_task_step(task, NULL);
found_session = worker_find_tcp_connected(peer);
if (found_session) {
		/* The session is already in the connected list.
		 * Something went wrong; this can happen due to races when kresd
		 * has tried to reconnect to the upstream after an unsuccessful attempt. */
if (log_debug) {
const char *peer_str = kr_straddr(peer);
kr_log_debug(WORKER, "=> connected to '%s', but peer "
"is already connected, close\n",
peer_str ? peer_str : "");
}
} else {
struct ioreq *conn_req = ioreq_take(task->worker);
if (!conn_req) {
return qr_task_step(task, NULL);
kr_assert(session2_tasklist_is_empty(session));
session2_waitinglist_retry(session, false);
session2_close(session);
return;
}
if (status != 0) {
if (log_debug) {
const char *peer_str = kr_straddr(peer);
kr_log_debug(WORKER, "=> connection to '%s' failed (%s), flagged as 'bad'\n",
peer_str ? peer_str : "", uv_strerror(status));
}
conn_req->as.connect.data = task;
task->ioreq = (uv_req_t *)conn_req;
if (uv_tcp_connect(&conn_req->as.connect, (uv_tcp_t *)subreq, addr, on_connect) != 0) {
ioreq_release(task->worker, conn_req);
return qr_task_step(task, NULL);
worker_del_tcp_waiting(peer);
if (status != UV_ETIMEDOUT) {
/* In case of UV_ETIMEDOUT upstream has been
* already penalized in on_tcp_connect_timeout() */
session2_event(session, PROTOLAYER_EVENT_CONNECT_FAIL, NULL);
}
kr_assert(session2_tasklist_is_empty(session));
session2_close(session);
return;
}
/* Start next step with timeout */
uv_timer_start(&task->timeout, qr_task_timeout, KR_CONN_RTT_MAX, 0);
return kr_ok();
if (log_debug) {
const char *peer_str = kr_straddr(peer);
kr_log_debug(WORKER, "=> connected to '%s'\n", peer_str ? peer_str : "");
}
session2_event(session, PROTOLAYER_EVENT_CONNECT, NULL);
session2_start_read(session);
session2_timer_stop(session);
session2_timer_start(session, PROTOLAYER_EVENT_GENERAL_TIMEOUT,
MAX_TCP_INACTIVITY, MAX_TCP_INACTIVITY);
}
static int transmit(struct qr_task *task)
{
	if (!task)
		return kr_error(EINVAL);

	struct kr_transport* transport = task->transport;
	struct sockaddr_in6 *choice = (struct sockaddr_in6 *)&transport->address;

	if (!choice)
		return kr_error(EINVAL);
	if (task->pending_count >= MAX_PENDING)
		return kr_error(EBUSY);

	/* Checkout answer before sending it */
	struct request_ctx *ctx = task->ctx;
	int ret = kr_resolve_checkout(&ctx->req, NULL, transport, task->pktbuf);
	if (ret)
		return ret;

	struct session2 *session = ioreq_spawn(SOCK_DGRAM, choice->sin6_family,
			KR_PROTO_UDP53, NULL, 0);
	if (!session)
		return kr_error(EINVAL);

	struct sockaddr *addr = (struct sockaddr *)choice;
	struct sockaddr *peer = session2_get_peer(session);
	kr_assert(peer->sa_family == AF_UNSPEC && session->outgoing);
	kr_require(addr->sa_family == AF_INET || addr->sa_family == AF_INET6);
	memcpy(peer, addr, kr_sockaddr_len(addr));

	struct comm_info out_comm = {
		.comm_addr = (struct sockaddr *)choice
	};
	if (the_network->enable_connect_udp && session->outgoing && !session->stream) {
		uv_udp_t *udp = (uv_udp_t *)session2_get_handle(session);
		int connect_tries = 3;
		do {
			ret = uv_udp_connect(udp, out_comm.comm_addr);
		} while (ret == UV_EADDRINUSE && --connect_tries > 0);
		if (ret < 0) {
			kr_log_info(IO, "Failed to establish udp connection to %s: %s\n",
					kr_straddr(out_comm.comm_addr), uv_strerror(ret));
		}
	}
	ret = qr_task_send(task, session, &out_comm, task->pktbuf);
	if (ret) {
		session2_close(session);
		return ret;
	}

	task->pending[task->pending_count] = session;
	task->pending_count += 1;
	session2_start_read(session); /* Start reading answer */
	return kr_ok();
}
static void subreq_finalize(struct qr_task *task, const struct sockaddr *packet_source, knot_pkt_t *pkt)
{
	if (!task || task->finished) {
		return;
	}
	/* Close pending timer */
	ioreq_kill_pending(task);
	/* Clear from outgoing table. */
	if (!task->leading)
		return;
	char key[SUBREQ_KEY_LEN];
	const int klen = subreq_key(key, task->pktbuf);
	if (klen > 0) {
		void *val_deleted;
		int ret = trie_del(the_worker->subreq_out, key, klen, &val_deleted);
		kr_assert(ret == KNOT_EOK && val_deleted == task);
	}
	/* Notify waiting tasks. */
	if (task->waiting.len > 0) {
		struct kr_query *leader_qry = array_tail(task->ctx->req.rplan.pending);
		defer_sample_state_t defer_prev_sample_state;
		defer_sample_start(&defer_prev_sample_state);
		for (size_t i = task->waiting.len; i > 0; i--) {
			struct qr_task *follower = task->waiting.at[i - 1];
			/* Reuse MSGID and 0x20 secret */
			if (follower->ctx->req.rplan.pending.len > 0) {
				struct kr_query *qry = array_tail(follower->ctx->req.rplan.pending);
				qry->id = leader_qry->id;
				qry->secret = leader_qry->secret;

				// Note that this transport may not be present in `leader_qry`'s server selection
				follower->transport = task->transport;
				if (follower->transport) {
					follower->transport->deduplicated = true;
				}
				leader_qry->secret = 0; /* Next will be already decoded */
			}
			qr_task_step(follower, packet_source, pkt);
			qr_task_unref(follower);
			defer_sample_restart();
		}
		defer_sample_stop(&defer_prev_sample_state, true);
		task->waiting.len = 0;
	}
	task->leading = false;
}
static void subreq_lead(struct qr_task *task)
{
	if (kr_fails_assert(task))
		return;
	char key[SUBREQ_KEY_LEN];
	const int klen = subreq_key(key, task->pktbuf);
	if (klen < 0)
		return;
	struct qr_task **tvp = (struct qr_task **)
		trie_get_ins(the_worker->subreq_out, key, klen);
	if (unlikely(!tvp))
		return; /*ENOMEM*/
	if (kr_fails_assert(*tvp == NULL))
		return;
	*tvp = task;
	task->leading = true;
}
static bool subreq_enqueue(struct qr_task *task)
{
	if (kr_fails_assert(task))
		return false;
	char key[SUBREQ_KEY_LEN];
	const int klen = subreq_key(key, task->pktbuf);
	if (klen < 0)
		return false;
	struct qr_task **leader = (struct qr_task **)
		trie_get_try(the_worker->subreq_out, key, klen);
	if (!leader /*ENOMEM*/ || !*leader)
		return false;
	/* Enqueue itself to leader for this subrequest. */
	int ret = array_push_mm((*leader)->waiting, task,
				kr_memreserve, &(*leader)->ctx->req.pool);
	if (unlikely(ret < 0)) /*ENOMEM*/
		return false;
	qr_task_ref(task);
	return true;
}
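/* A rough sketch of how the three subrequest helpers above cooperate; this is
 * illustrative only (the real call sites are udp_task_step() below and
 * subreq_finalize() above), not additional logic:
 *
 *	if (subreq_enqueue(task))
 *		return kr_ok();      // an identical outgoing query is already in flight
 *	transmit(task);              // we are the first; send the query ourselves
 *	subreq_lead(task);           // register as leader so later duplicates attach to us
 *	// ... once the answer (or an error) arrives:
 *	subreq_finalize(task, packet_source, pkt);  // step and release all followers
 */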
static int qr_task_finalize(struct qr_task *task, int state)
{
kr_require(task && task->leading == false);
if (task->finished) {
return kr_ok();
}
defer_sample_task(task);
struct request_ctx *ctx = task->ctx;
struct session2 *source_session = ctx->source.session;
kr_resolve_finish(&ctx->req, state);
task->finished = true;
if (source_session == NULL) {
(void) qr_task_on_send(task, NULL, kr_error(EIO));
return state == KR_STATE_DONE ? kr_ok() : kr_error(EIO);
}
/* meant to be dropped */
if (unlikely(ctx->req.answer == NULL || ctx->req.options.NO_ANSWER)) {
/* For NO_ANSWER, a well-behaved layer should set the state to FAIL */
kr_assert(!ctx->req.options.NO_ANSWER || (ctx->req.state & KR_STATE_FAIL));
(void) qr_task_on_send(task, NULL, kr_ok());
return kr_ok();
}
if (source_session->closing ||
ctx->source.addr.ip.sa_family == AF_UNSPEC)
return kr_error(EINVAL);
/* Reference task as the callback handler can close it */
qr_task_ref(task);
/* Send back answer */
struct comm_info out_comm = {
.src_addr = &ctx->source.addr.ip,
.dst_addr = &ctx->source.dst_addr.ip,
.comm_addr = &ctx->source.comm_addr.ip,
.xdp = ctx->source.xdp
};
if (ctx->source.xdp) {
memcpy(out_comm.eth_from, ctx->source.eth_from, sizeof(out_comm.eth_from));
memcpy(out_comm.eth_to, ctx->source.eth_to, sizeof(out_comm.eth_to));
}
int ret = qr_task_send(task, source_session, &out_comm, ctx->req.answer);
if (ret != kr_ok()) {
(void) qr_task_on_send(task, NULL, kr_error(EIO));
/* Since source session is erroneous detach all tasks. */
while (!session2_tasklist_is_empty(source_session)) {
struct qr_task *t = session2_tasklist_del_first(source_session, false);
struct request_ctx *c = t->ctx;
kr_assert(c->source.session == source_session);
c->source.session = NULL;
/* Don't finalize them as there can be other tasks
* waiting for answer to this particular task.
* (ie. task->leading is true) */
worker_task_unref(t);
}
session2_close(source_session);
}
if (source_session->stream && !source_session->closing) {
struct pl_dns_stream_sess_data *stream =
protolayer_sess_data_get_proto(source_session, PROTOLAYER_TYPE_DNS_MULTI_STREAM);
if (!stream)
stream = protolayer_sess_data_get_proto(source_session, PROTOLAYER_TYPE_DNS_UNSIZED_STREAM);
if (!stream)
stream = protolayer_sess_data_get_proto(source_session, PROTOLAYER_TYPE_DNS_SINGLE_STREAM);
if (stream && stream->half_closed) {
session2_force_close(source_session);
}
}
qr_task_unref(task);
if (ret != kr_ok() || state != KR_STATE_DONE)
return kr_error(EIO);
return kr_ok();
}
static int udp_task_step(struct qr_task *task,
const struct sockaddr *packet_source, knot_pkt_t *packet)
{
/* If there is already outgoing query, enqueue to it. */
if (subreq_enqueue(task)) {
return kr_ok(); /* Will be notified when outgoing query finishes. */
}
/* Start transmitting */
int err = transmit(task);
if (err) {
subreq_finalize(task, packet_source, packet);
return qr_task_finalize(task, KR_STATE_FAIL);
}
/* Announce and start subrequest.
* @note Only UDP can lead I/O as it doesn't touch 'task->pktbuf' for reassembly.
*/
subreq_lead(task);
return kr_ok();
}
static int tcp_task_waiting_connection(struct session2 *session, struct qr_task *task)
{
if (kr_fails_assert(session->outgoing && !session->closing))
return kr_error(EINVAL);
/* Add task to the end of list of waiting tasks.
* It will be notified in on_connect() or qr_task_on_send(). */
int ret = session2_waitinglist_push(session, task);
if (ret < 0) {
return kr_error(EINVAL);
}
return kr_ok();
}
static int tcp_task_existing_connection(struct session2 *session, struct qr_task *task)
{
if (kr_fails_assert(session->outgoing && !session->closing))
return kr_error(EINVAL);
/* If there are any unsent queries, send it first. */
int ret = send_waiting(session);
if (ret != 0) {
return kr_error(EINVAL);
}
/* No unsent queries at that point. */
if (session2_tasklist_get_len(session) >= the_worker->tcp_pipeline_max) {
/* Too many outstanding queries, answer with SERVFAIL, */
return kr_error(EINVAL);
}
/* Send query to upstream. */
ret = qr_task_send(task, session, NULL, NULL);
if (ret != 0) {
/* Error, finalize task with SERVFAIL and
* close connection to upstream. */
session2_tasklist_finalize(session, KR_STATE_FAIL);
worker_del_tcp_connected(session2_get_peer(session));
session2_close(session);
return kr_error(EINVAL);
}
return kr_ok();
}
static int tcp_task_make_connection(struct qr_task *task, const struct sockaddr *addr)
{
/* Check if there must be TLS */
tls_client_param_t *tls_entry = tls_client_param_get(
the_network->tls_client_params, addr);
uv_connect_t *conn = malloc(sizeof(uv_connect_t));
if (!conn) {
return kr_error(EINVAL);
}
struct session2 *session;
bool has_tls = tls_entry;
if (has_tls) {
struct protolayer_data_param param = {
.protocol = PROTOLAYER_TYPE_TLS,
.param = tls_entry
};
session = ioreq_spawn(SOCK_STREAM, addr->sa_family,
KR_PROTO_DOT, &param, 1);
} else {
session = ioreq_spawn(SOCK_STREAM, addr->sa_family,
KR_PROTO_TCP53, NULL, 0);
}
if (!session) {
free(conn);
return kr_error(EINVAL);
}
if (kr_fails_assert(session->secure == has_tls)) {
free(conn);
return kr_error(EINVAL);
}
/* Add address to the waiting list.
* Now it "is waiting to be connected to." */
int ret = worker_add_tcp_waiting(addr, session);
if (ret < 0) {
free(conn);
session2_close(session);
return kr_error(EINVAL);
}
conn->data = session;
/* Store peer address for the session. */
struct sockaddr *peer = session2_get_peer(session);
memcpy(peer, addr, kr_sockaddr_len(addr));
/* Start watchdog to catch eventual connection timeout. */
ret = session2_timer_start(session, PROTOLAYER_EVENT_CONNECT_TIMEOUT,
KR_CONN_RTT_MAX, 0);
if (ret != 0) {
worker_del_tcp_waiting(addr);
free(conn);
session2_close(session);
return kr_error(EINVAL);
}
struct kr_query *qry = task_get_last_pending_query(task);
if (kr_log_is_debug_qry(WORKER, qry)) {
const char *peer_str = kr_straddr(peer);
VERBOSE_MSG(qry, "=> connecting to: '%s'\n", peer_str ? peer_str : "");
}
/* Start connection process to upstream. */
ret = uv_tcp_connect(conn, (uv_tcp_t *)session2_get_handle(session),
addr , on_connect);
if (ret != 0) {
session2_timer_stop(session);
worker_del_tcp_waiting(addr);
free(conn);
session2_close(session);
qry->server_selection.error(qry, task->transport, KR_SELECTION_TCP_CONNECT_FAILED);
return kr_error(EAGAIN);
}
/* Add task to the end of list of waiting tasks.
* Will be notified either in on_connect() or in qr_task_on_send(). */
ret = session2_waitinglist_push(session, task);
if (ret < 0) {
session2_timer_stop(session);
worker_del_tcp_waiting(addr);
free(conn);
session2_close(session);
return kr_error(EINVAL);
}
return kr_ok();
}
static int tcp_task_step(struct qr_task *task,
const struct sockaddr *packet_source, knot_pkt_t *packet)
{
if (kr_fails_assert(task->pending_count == 0)) {
subreq_finalize(task, packet_source, packet);
return qr_task_finalize(task, KR_STATE_FAIL);
}
/* target */
const struct sockaddr *addr = &task->transport->address.ip;
if (addr->sa_family == AF_UNSPEC) {
		/* Target isn't defined. Finalize task with SERVFAIL.
		 * Although task->pending_count is zero, there can still be followers,
		 * so we need to call subreq_finalize() to handle them properly. */
subreq_finalize(task, packet_source, packet);
return qr_task_finalize(task, KR_STATE_FAIL);
}
/* Checkout task before connecting */
struct request_ctx *ctx = task->ctx;
if (kr_resolve_checkout(&ctx->req, NULL, task->transport, task->pktbuf) != 0) {
subreq_finalize(task, packet_source, packet);
return qr_task_finalize(task, KR_STATE_FAIL);
}
int ret;
struct session2* session = NULL;
if ((session = worker_find_tcp_waiting(addr)) != NULL) {
/* Connection is in the list of waiting connections.
* It means that connection establishing is coming right now. */
ret = tcp_task_waiting_connection(session, task);
} else if ((session = worker_find_tcp_connected(addr)) != NULL) {
/* Connection has been already established. */
ret = tcp_task_existing_connection(session, task);
} else {
/* Make connection. */
ret = tcp_task_make_connection(task, addr);
}
if (ret != kr_ok()) {
subreq_finalize(task, addr, packet);
if (ret == kr_error(EAGAIN)) {
ret = qr_task_step(task, addr, NULL);
} else {
ret = qr_task_finalize(task, KR_STATE_FAIL);
}
}
return ret;
}
static int qr_task_step(struct qr_task *task,
const struct sockaddr *packet_source, knot_pkt_t *packet)
{
defer_sample_task(task);
/* No more steps after we're finished. */
if (!task || task->finished) {
return kr_error(ESTALE);
}
/* Close pending I/O requests */
subreq_finalize(task, packet_source, packet);
if ((kr_now() - task->creation_time) >= KR_RESOLVE_TIME_LIMIT) {
struct kr_request *req = worker_task_request(task);
if (!kr_fails_assert(req))
kr_query_inform_timeout(req, req->current_query);
return qr_task_finalize(task, KR_STATE_FAIL);
}
/* Consume input and produce next query */
struct request_ctx *ctx = task->ctx;
if (kr_fails_assert(ctx))
return qr_task_finalize(task, KR_STATE_FAIL);
struct kr_request *req = &ctx->req;
if (the_worker->too_many_open) {
/* */
struct kr_rplan *rplan = &req->rplan;
if (the_worker->stats.rconcurrent <
the_worker->rconcurrent_highwatermark - 10) {
the_worker->too_many_open = false;
} else {
if (packet && kr_rplan_empty(rplan)) {
/* new query; TODO - make this detection more obvious */
kr_resolve_consume(req, &task->transport, packet);
}
return qr_task_finalize(task, KR_STATE_FAIL);
}
}
// Report network RTT back to server selection
if (packet && task->send_time && task->recv_time) {
struct kr_query *qry = array_tail(req->rplan.pending);
qry->server_selection.update_rtt(qry, task->transport, task->recv_time - task->send_time);
}
int state = kr_resolve_consume(req, &task->transport, packet);
task->transport = NULL;
while (state == KR_STATE_PRODUCE) {
state = kr_resolve_produce(req, &task->transport, task->pktbuf);
if (unlikely(++task->iter_count > KR_ITER_LIMIT ||
task->timeouts >= KR_TIMEOUT_LIMIT)) {
struct kr_rplan *rplan = &req->rplan;
struct kr_query *last = kr_rplan_last(rplan);
if (task->iter_count > KR_ITER_LIMIT) {
char *msg = "cancelling query due to exceeded iteration count limit";
VERBOSE_MSG(last, "%s of %d\n", msg, KR_ITER_LIMIT);
kr_request_set_extended_error(req, KNOT_EDNS_EDE_OTHER,
"OGHD: exceeded iteration count limit");
}
if (task->timeouts >= KR_TIMEOUT_LIMIT) {
char *msg = "cancelling query due to exceeded timeout retries limit";
VERBOSE_MSG(last, "%s of %d\n", msg, KR_TIMEOUT_LIMIT);
kr_request_set_extended_error(req, KNOT_EDNS_EDE_NREACH_AUTH, "QLPL");
}
return qr_task_finalize(task, KR_STATE_FAIL);
}
}
/* We're done, no more iterations needed */
if (state & (KR_STATE_DONE|KR_STATE_FAIL)) {
return qr_task_finalize(task, state);
} else if (!task->transport || !task->transport->protocol) {
return qr_task_step(task, NULL, NULL);
}
switch (task->transport->protocol)
{
case KR_TRANSPORT_UDP:
return udp_task_step(task, packet_source, packet);
case KR_TRANSPORT_TCP: // fall through
case KR_TRANSPORT_TLS:
return tcp_task_step(task, packet_source, packet);
default:
kr_assert(!EINVAL);
return kr_error(EINVAL);
}
}
static int worker_submit(struct session2 *session, struct comm_info *comm, knot_pkt_t *pkt)
{
if (!session || !pkt || session->closing)
return kr_error(EINVAL);
const bool is_query = pkt->size > KNOT_WIRE_OFFSET_FLAGS1
&& knot_wire_get_qr(pkt->wire) == 0;
const bool is_outgoing = session->outgoing;
int ret = 0;
if (is_query == is_outgoing)
ret = KNOT_ENOENT;
// For responses from upstream, try to find associated task and query.
// In case of errors, at least try to guess.
struct qr_task *task = NULL;
bool task_matched_id = false;
if (is_outgoing && pkt->size >= 2) {
const uint16_t id = knot_wire_get_id(pkt->wire);
task = session2_tasklist_del_msgid(session, id);
task_matched_id = task != NULL;
if (task_matched_id) // Note receive time for RTT calculation
task->recv_time = kr_now();
if (!task_matched_id) {
ret = KNOT_ENOENT;
VERBOSE_MSG(NULL, "=> DNS message with mismatching ID %d\n",
(int)id);
}
}
if (!task && is_outgoing && session->stream) {
// Source address of the reply got somewhat validated,
// so we try to at least guess which query, for error reporting.
task = session2_tasklist_get_first(session);
}
struct kr_query *qry = NULL;
if (task)
qry = array_tail(task->ctx->req.rplan.pending);
// Parse the packet, unless it's useless anyway.
if (ret == 0) {
ret = knot_pkt_parse(pkt, 0);
if (ret == KNOT_ETRAIL && is_outgoing
&& !kr_fails_assert(pkt->parsed < pkt->size)) {
// We deal with this later, so that RCODE takes priority.
ret = 0;
}
if (ret && kr_log_is_debug_qry(WORKER, qry)) {
VERBOSE_MSG(qry, "=> DNS message failed to parse, %s\n",
knot_strerror(ret));
}
}
/* Badly formed query when using DoH leads to a Bad Request */
if (session->custom_emalf_handling && !is_outgoing && ret) {
session2_event(session, PROTOLAYER_EVENT_MALFORMED, NULL);
return ret;
}
const struct sockaddr *addr = comm ? comm->src_addr : NULL;
/* Ignore badly formed queries. */
if (ret) {
		if (is_outgoing && qry) // unusable response from a somewhat validated IP
qry->server_selection.error(qry, task->transport, KR_SELECTION_MALFORMED);
if (!is_outgoing)
the_worker->stats.dropped += 1;
if (task_matched_id) // notify task that answer won't be coming anymore
qr_task_step(task, addr, NULL);
return kr_error(EILSEQ);
}
	/* Start a new task on listening sockets,
	 * or resume if this is a subrequest. */
if (!is_outgoing) { /* request from a client */
struct request_ctx *ctx =
request_create(session, comm, knot_wire_get_id(pkt->wire));
if (!ctx)
return kr_error(ENOMEM);
ret = request_start(ctx, pkt);
if (ret != 0) {
request_free(ctx);
return kr_error(ENOMEM);
}
task = qr_task_create(ctx);
if (!task) {
request_free(ctx);
return kr_error(ENOMEM);
}
if (session->stream && qr_task_register(task, session)) {
return kr_error(ENOMEM);
}
} else { /* response from upstream */
if (task == NULL) {
return kr_error(ENOENT);
}
if (kr_fails_assert(!session->closing))
return kr_error(EINVAL);
}
if (kr_fails_assert(!session->closing))
return kr_error(EINVAL);
/* Packet was successfully parsed.
* Task was created (found). */
session2_touch(session);
/* Consume input and produce next message */
return qr_task_step(task, addr, pkt);
}
static int trie_add_tcp_session(trie_t *trie, const struct sockaddr *addr,
struct session2 *session)
{
if (kr_fails_assert(trie && addr))
return kr_error(EINVAL);
struct kr_sockaddr_key_storage key;
ssize_t keylen = kr_sockaddr_key(&key, addr);
if (keylen < 0)
return keylen;
trie_val_t *val = trie_get_ins(trie, key.bytes, keylen);
if (*val != NULL)
return kr_error(EEXIST);
*val = session;
return kr_ok();
}
static int trie_del_tcp_session(trie_t *trie, const struct sockaddr *addr)
{
if (kr_fails_assert(trie && addr))
return kr_error(EINVAL);
struct kr_sockaddr_key_storage key;
ssize_t keylen = kr_sockaddr_key(&key, addr);
if (keylen < 0)
return keylen;
int ret = trie_del(trie, key.bytes, keylen, NULL);
return ret ? kr_error(ENOENT) : kr_ok();
}
static struct session2 *trie_find_tcp_session(trie_t *trie,
const struct sockaddr *addr)
{
if (kr_fails_assert(trie && addr))
return NULL;
struct kr_sockaddr_key_storage key;
ssize_t keylen = kr_sockaddr_key(&key, addr);
if (keylen < 0)
return NULL;
trie_val_t *val = trie_get_try(trie, key.bytes, keylen);
return val ? *val : NULL;
}
static int worker_add_tcp_connected(const struct sockaddr* addr, struct session2 *session)
{
return trie_add_tcp_session(the_worker->tcp_connected, addr, session);
}
static int worker_del_tcp_connected(const struct sockaddr* addr)
{
return trie_del_tcp_session(the_worker->tcp_connected, addr);
}
static struct session2* worker_find_tcp_connected(const struct sockaddr* addr)
{
return trie_find_tcp_session(the_worker->tcp_connected, addr);
}
static int worker_add_tcp_waiting(const struct sockaddr* addr,
struct session2 *session)
{
return trie_add_tcp_session(the_worker->tcp_waiting, addr, session);
}
static int worker_del_tcp_waiting(const struct sockaddr* addr)
{
return trie_del_tcp_session(the_worker->tcp_waiting, addr);
}
static struct session2* worker_find_tcp_waiting(const struct sockaddr* addr)
{
return trie_find_tcp_session(the_worker->tcp_waiting, addr);
}
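/* For orientation, a sketch of how the two tries are used by the functions in
 * this file (no new behaviour, just the typical lifecycle of an outgoing TCP
 * session):
 *
 *	tcp_task_make_connection()   -> worker_add_tcp_waiting(addr, session)
 *	pl_dns_stream_connected()    -> worker_del_tcp_waiting(peer)
 *	                                worker_add_tcp_connected(peer, session)
 *	pl_dns_stream_disconnected() -> worker_del_tcp_waiting(peer)
 *	                                worker_del_tcp_connected(peer)
 */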
knot_pkt_t *worker_resolve_mk_pkt_dname(knot_dname_t *qname, uint16_t qtype, uint16_t qclass,
const struct kr_qflags *options)
{
knot_pkt_t *pkt = knot_pkt_new(NULL, KNOT_EDNS_MAX_UDP_PAYLOAD, NULL);
if (!pkt)
return NULL;
knot_pkt_put_question(pkt, qname, qclass, qtype);
knot_wire_set_rd(pkt->wire);
knot_wire_set_ad(pkt->wire);
/* Add OPT RR, including wire format so modules can see both representations.
* knot_pkt_put() copies the outside; we need to duplicate the inside manually. */
knot_rrset_t *opt = knot_rrset_copy(the_resolver->downstream_opt_rr, NULL);
if (!opt) {
knot_pkt_free(pkt);
return NULL;
}
if (options->DNSSEC_WANT) {
knot_edns_set_do(opt);
}
knot_pkt_begin(pkt, KNOT_ADDITIONAL);
int ret = knot_pkt_put(pkt, KNOT_COMPR_HINT_NONE, opt, KNOT_PF_FREE);
if (ret == KNOT_EOK) {
free(opt); /* inside is owned by pkt now */
} else {
knot_rrset_free(opt, NULL);
knot_pkt_free(pkt);
return NULL;
}
if (options->DNSSEC_CD) {
knot_wire_set_cd(pkt->wire);
}
return pkt;
}
knot_pkt_t *worker_resolve_mk_pkt(const char *qname_str, uint16_t qtype, uint16_t qclass,
const struct kr_qflags *options)
{
uint8_t qname[KNOT_DNAME_MAXLEN];
if (!knot_dname_from_str(qname, qname_str, sizeof(qname)))
return NULL;
return worker_resolve_mk_pkt_dname(qname, qtype, qclass, options);
}
struct qr_task *worker_resolve_start(knot_pkt_t *query, struct kr_qflags options)
{
if (kr_fails_assert(the_worker && query))
return NULL;
struct request_ctx *ctx = request_create(NULL, NULL, the_worker->next_request_uid);
if (!ctx)
return NULL;
/* Create task */
struct qr_task *task = qr_task_create(ctx);
if (!task) {
request_free(ctx);
return NULL;
}
/* Start task */
int ret = request_start(ctx, query);
if (ret != 0) {
/* task is attached to request context,
* so dereference (and deallocate) it first */
ctx->task = NULL;
qr_task_unref(task);
request_free(ctx);
return NULL;
}
the_worker->next_request_uid += 1;
if (the_worker->next_request_uid == 0)
the_worker->next_request_uid = UINT16_MAX + 1;
/* Set options late, as qr_task_start() -> kr_resolve_begin() rewrite it. */
kr_qflags_set(&task->ctx->req.options, options);
return task;
}
int worker_resolve_exec(struct qr_task *task, knot_pkt_t *query)
{
if (!task)
return kr_error(EINVAL);
return qr_task_step(task, NULL, query);
}
int worker_task_numrefs(const struct qr_task *task)
{
return task->refs;
}
struct kr_request *worker_task_request(struct qr_task *task)
{
if (!task || !task->ctx)
return NULL;
return &task->ctx->req;
}
int worker_task_finalize(struct qr_task *task, int state)
{
return qr_task_finalize(task, state);
}
int worker_task_step(struct qr_task *task, const struct sockaddr *packet_source,
knot_pkt_t *packet)
{
return qr_task_step(task, packet_source, packet);
}
void worker_task_ref(struct qr_task *task)
{
qr_task_ref(task);
}
void worker_task_unref(struct qr_task *task)
{
qr_task_unref(task);
}
void worker_task_timeout_inc(struct qr_task *task)
{
task->timeouts += 1;
}
knot_pkt_t *worker_task_get_pktbuf(const struct qr_task *task)
{
return task->pktbuf;
}
struct kr_transport *worker_task_get_transport(struct qr_task *task)
{
return task->transport;
}
struct session2 *worker_request_get_source_session(const struct kr_request *req)
{
static_assert(offsetof(struct request_ctx, req) == 0,
"Bad struct request_ctx definition.");
return ((struct request_ctx *)req)->source.session;
}
uint16_t worker_task_pkt_get_msgid(struct qr_task *task)
{
knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
uint16_t msg_id = knot_wire_get_id(pktbuf->wire);
return msg_id;
}
void worker_task_pkt_set_msgid(struct qr_task *task, uint16_t msgid)
{
knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
knot_wire_set_id(pktbuf->wire, msgid);
struct kr_query *q = task_get_last_pending_query(task);
if (q)
q->id = msgid;
}
uint64_t worker_task_creation_time(struct qr_task *task)
{
return task->creation_time;
}
void worker_task_subreq_finalize(struct qr_task *task)
{
subreq_finalize(task, NULL, NULL);
}
bool worker_task_finished(struct qr_task *task)
{
return task->finished;
}
/** Reserve worker buffers. We assume worker's been zeroed. */
static int worker_reserve(void)
{
the_worker->tcp_connected = trie_create(NULL);
the_worker->tcp_waiting = trie_create(NULL);
the_worker->subreq_out = trie_create(NULL);
mm_ctx_mempool(&the_worker->pkt_pool, 4 * sizeof(knot_pkt_t));
return kr_ok();
}
void worker_deinit(void)
{
if (kr_fails_assert(the_worker))
return;
trie_free(the_worker->tcp_connected);
trie_free(the_worker->tcp_waiting);
trie_free(the_worker->subreq_out);
the_worker->subreq_out = NULL;
for (int i = 0; i < the_worker->doh_qry_headers.len; i++)
free((void *)the_worker->doh_qry_headers.at[i]);
array_clear(the_worker->doh_qry_headers);
mp_delete(the_worker->pkt_pool.ctx);
the_worker->pkt_pool.ctx = NULL;
the_worker = NULL;
}
static inline knot_pkt_t *produce_packet(uint8_t *buf, size_t buf_len)
{
return knot_pkt_new(buf, buf_len, &the_worker->pkt_pool);
}
static enum protolayer_event_cb_result pl_dns_dgram_event_unwrap(
enum protolayer_event_type event, void **baton,
struct session2 *session, void *sess_data)
{
if (event != PROTOLAYER_EVENT_GENERAL_TIMEOUT)
return PROTOLAYER_EVENT_PROPAGATE;
if (session2_tasklist_get_len(session) != 1 ||
!session2_waitinglist_is_empty(session))
return PROTOLAYER_EVENT_PROPAGATE;
session2_timer_stop(session);
struct qr_task *task = session2_tasklist_get_first(session);
if (!task)
return PROTOLAYER_EVENT_PROPAGATE;
if (task->leading && task->pending_count > 0) {
struct kr_query *qry = array_tail(task->ctx->req.rplan.pending);
qry->server_selection.error(qry, task->transport, KR_SELECTION_QUERY_TIMEOUT);
}
task->timeouts += 1;
the_worker->stats.timeout += 1;
qr_task_step(task, NULL, NULL);
return PROTOLAYER_EVENT_PROPAGATE;
}
static size_t pl_dns_dgram_wire_buf_overhead(bool outgoing)
{
if (outgoing) {
if (the_resolver->upstream_opt_rr)
return knot_edns_get_payload(the_resolver->upstream_opt_rr);
} else {
if (the_resolver->downstream_opt_rr)
return knot_edns_get_payload(the_resolver->downstream_opt_rr);
}
return KNOT_WIRE_MAX_PKTSIZE;
}
static enum protolayer_iter_cb_result pl_dns_dgram_unwrap(
void *sess_data, void *iter_data, struct protolayer_iter_ctx *ctx)
{
struct session2 *session = ctx->session;
if (ctx->payload.type == PROTOLAYER_PAYLOAD_IOVEC) {
int ret = kr_ok();
for (int i = 0; i < ctx->payload.iovec.cnt; i++) {
const struct iovec *iov = &ctx->payload.iovec.iov[i];
if (iov->iov_len > MAX_DGRAM_LEN) {
session2_penalize(session);
ret = kr_error(EFBIG);
break;
}
knot_pkt_t *pkt = produce_packet(
iov->iov_base, iov->iov_len);
if (!pkt) {
ret = KNOT_EMALF;
break;
}
ret = worker_submit(session, ctx->comm, pkt);
if (ret)
break;
}
mp_flush(the_worker->pkt_pool.ctx);
return protolayer_break(ctx, ret);
} else if (ctx->payload.type == PROTOLAYER_PAYLOAD_BUFFER) {
if (ctx->payload.buffer.len > MAX_DGRAM_LEN) {
session2_penalize(session);
return protolayer_break(ctx, kr_error(EFBIG));
}
knot_pkt_t *pkt = produce_packet(
ctx->payload.buffer.buf,
ctx->payload.buffer.len);
if (!pkt)
return protolayer_break(ctx, KNOT_EMALF);
int ret = worker_submit(session, ctx->comm, pkt);
mp_flush(the_worker->pkt_pool.ctx);
return protolayer_break(ctx, ret);
} else if (ctx->payload.type == PROTOLAYER_PAYLOAD_WIRE_BUF) {
const size_t msg_len = wire_buf_data_length(ctx->payload.wire_buf);
if (msg_len > MAX_DGRAM_LEN) {
session2_penalize(session);
return protolayer_break(ctx, kr_error(EFBIG));
}
knot_pkt_t *pkt = produce_packet(
wire_buf_data(ctx->payload.wire_buf),
msg_len);
if (!pkt)
return protolayer_break(ctx, KNOT_EMALF);
int ret = worker_submit(session, ctx->comm, pkt);
wire_buf_reset(ctx->payload.wire_buf);
mp_flush(the_worker->pkt_pool.ctx);
return protolayer_break(ctx, ret);
} else {
kr_assert(false && "Invalid payload");
return protolayer_break(ctx, kr_error(EINVAL));
}
}
static int pl_dns_stream_sess_init(struct session2 *session,
void *sess_data, void *param)
{
/* _UNSIZED_STREAM and _MULTI_STREAM - don't forget to split if needed
* at some point */
session->stream = true;
return kr_ok();
}
static int pl_dns_single_stream_sess_init(struct session2 *session,
void *sess_data, void *param)
{
session->stream = true;
struct pl_dns_stream_sess_data *stream = sess_data;
stream->single = true;
return kr_ok();
}
static enum protolayer_event_cb_result pl_dns_stream_resolution_timeout(
struct session2 *s)
{
if (kr_fails_assert(!s->closing))
return PROTOLAYER_EVENT_PROPAGATE;
if (!session2_tasklist_is_empty(s)) {
int finalized = session2_tasklist_finalize_expired(s);
the_worker->stats.timeout += finalized;
/* session2_tasklist_finalize_expired() may call worker_task_finalize().
* If session is a source session and there were IO errors,
* worker_task_finalize() can finalize all tasks and close session. */
if (s->closing)
return PROTOLAYER_EVENT_PROPAGATE;
}
if (!session2_tasklist_is_empty(s)) {
session2_timer_stop(s);
session2_timer_start(s,
PROTOLAYER_EVENT_GENERAL_TIMEOUT,
KR_RESOLVE_TIME_LIMIT / 2,
KR_RESOLVE_TIME_LIMIT / 2);
} else {
		/* Normally this should not happen,
		 * but better to check whether there is anything in this list. */
if (!session2_waitinglist_is_empty(s)) {
defer_sample_state_t defer_prev_sample_state;
defer_sample_start(&defer_prev_sample_state);
do {
struct qr_task *t = session2_waitinglist_pop(s, false);
worker_task_finalize(t, KR_STATE_FAIL);
worker_task_unref(t);
the_worker->stats.timeout += 1;
if (s->closing)
return PROTOLAYER_EVENT_PROPAGATE;
defer_sample_restart();
} while (!session2_waitinglist_is_empty(s));
defer_sample_stop(&defer_prev_sample_state, true);
}
uint64_t idle_in_timeout = the_network->tcp.in_idle_timeout;
uint64_t idle_time = kr_now() - s->last_activity;
if (idle_time < idle_in_timeout) {
idle_in_timeout -= idle_time;
session2_timer_stop(s);
session2_timer_start(s, PROTOLAYER_EVENT_GENERAL_TIMEOUT,
idle_in_timeout, idle_in_timeout);
} else {
struct sockaddr *peer = session2_get_peer(s);
char *peer_str = kr_straddr(peer);
kr_log_debug(IO, "=> closing connection to '%s'\n",
peer_str ? peer_str : "");
worker_del_tcp_waiting(peer);
worker_del_tcp_connected(peer);
session2_close(s);
}
}
return PROTOLAYER_EVENT_PROPAGATE;
}
static enum protolayer_event_cb_result pl_dns_stream_connected(
struct session2 *session, struct pl_dns_stream_sess_data *stream)
{
if (kr_fails_assert(!stream->connected))
return PROTOLAYER_EVENT_PROPAGATE;
stream->connected = true;
struct sockaddr *peer = session2_get_peer(session);
if (session->outgoing && worker_del_tcp_waiting(peer) != 0) {
		/* Session isn't in the list of waiting queries;
		 * something went wrong. */
goto fail;
}
int err = worker_add_tcp_connected(peer, session);
if (err) {
/* Could not add session to the list of connected, something
* went wrong. */
goto fail;
}
send_waiting(session);
return PROTOLAYER_EVENT_PROPAGATE;
fail:
session2_waitinglist_finalize(session, KR_STATE_FAIL);
kr_assert(session2_tasklist_is_empty(session));
session2_close(session);
return PROTOLAYER_EVENT_CONSUME;
}
static enum protolayer_event_cb_result pl_dns_stream_connection_fail(
struct session2 *session, enum kr_selection_error sel_err)
{
session2_timer_stop(session);
kr_assert(session2_tasklist_is_empty(session));
struct sockaddr *peer = session2_get_peer(session);
worker_del_tcp_waiting(peer);
struct qr_task *task = session2_waitinglist_get(session);
if (!task) {
/* Normally shouldn't happen. */
const char *peer_str = kr_straddr(peer);
VERBOSE_MSG(NULL, "=> connection to '%s' failed, empty waitinglist\n",
peer_str ? peer_str : "");
return PROTOLAYER_EVENT_PROPAGATE;
}
struct kr_query *qry = task_get_last_pending_query(task);
if (kr_log_is_debug_qry(WORKER, qry)) {
const char *peer_str = kr_straddr(peer);
bool timeout = sel_err == KR_SELECTION_TCP_CONNECT_TIMEOUT;
VERBOSE_MSG(qry, "=> connection to '%s' failed (%s)\n",
peer_str ? peer_str : "",
timeout ? "timeout" : "error");
}
if (qry)
qry->server_selection.error(qry, task->transport, sel_err);
the_worker->stats.timeout += session2_waitinglist_get_len(session);
session2_waitinglist_retry(session, true);
kr_assert(session2_tasklist_is_empty(session));
	/* uv_cancel() doesn't support uv_connect_t requests,
	 * so we can't cancel it.
	 * There is still a possibility that the connection
	 * for this request succeeds later.
	 * The connection callback (on_connect()) must therefore check
	 * whether the connection is in the list of waiting connections.
	 * If not, it is most likely a timed-out connection,
	 * even if it was technically successful. */
return PROTOLAYER_EVENT_PROPAGATE;
}
static enum protolayer_event_cb_result pl_dns_stream_disconnected(
struct session2 *session, struct pl_dns_stream_sess_data *stream)
{
struct sockaddr *peer = session2_get_peer(session);
worker_del_tcp_waiting(peer);
worker_del_tcp_connected(peer);
if (!stream->connected)
return PROTOLAYER_EVENT_PROPAGATE;
stream->connected = false;
if (session2_is_empty(session))
return PROTOLAYER_EVENT_PROPAGATE;
defer_sample_state_t defer_prev_sample_state;
if (session->outgoing)
defer_sample_start(&defer_prev_sample_state);
while (!session2_waitinglist_is_empty(session)) {
struct qr_task *task = session2_waitinglist_pop(session, false);
kr_assert(task->refs > 1);
session2_tasklist_del(session, task);
if (session->outgoing) {
if (task->ctx->req.options.FORWARD) {
				/* We are in TCP_FORWARD mode.
				 * qry.flags.TCP must be cleared to prevent
				 * a failure in kr_resolve_consume().
				 * TODO - refactoring is needed. */
struct kr_request *req = &task->ctx->req;
struct kr_rplan *rplan = &req->rplan;
struct kr_query *qry = array_tail(rplan->pending);
qry->flags.TCP = false;
}
qr_task_step(task, NULL, NULL);
defer_sample_restart();
} else {
kr_assert(task->ctx->source.session == session);
task->ctx->source.session = NULL;
}
worker_task_unref(task);
}
while (!session2_tasklist_is_empty(session)) {
struct qr_task *task = session2_tasklist_del_first(session, false);
if (session->outgoing) {
if (task->ctx->req.options.FORWARD) {
struct kr_request *req = &task->ctx->req;
struct kr_rplan *rplan = &req->rplan;
struct kr_query *qry = array_tail(rplan->pending);
qry->flags.TCP = false;
}
qr_task_step(task, NULL, NULL);
defer_sample_restart();
} else {
kr_assert(task->ctx->source.session == session);
task->ctx->source.session = NULL;
}
worker_task_unref(task);
}
if (session->outgoing)
defer_sample_stop(&defer_prev_sample_state, true);
return PROTOLAYER_EVENT_PROPAGATE;
}
static enum protolayer_event_cb_result pl_dns_stream_eof(
struct session2 *session, struct pl_dns_stream_sess_data *stream)
{
if (!session2_is_empty(session)) {
stream->half_closed = true;
return PROTOLAYER_EVENT_CONSUME;
}
return PROTOLAYER_EVENT_PROPAGATE;
}
static enum protolayer_event_cb_result pl_dns_stream_event_unwrap(
enum protolayer_event_type event, void **baton,
struct session2 *session, void *sess_data)
{
if (session->closing)
return PROTOLAYER_EVENT_PROPAGATE;
struct pl_dns_stream_sess_data *stream = sess_data;
switch (event) {
case PROTOLAYER_EVENT_GENERAL_TIMEOUT:
return pl_dns_stream_resolution_timeout(session);
case PROTOLAYER_EVENT_CONNECT_TIMEOUT:
return pl_dns_stream_connection_fail(session,
KR_SELECTION_TCP_CONNECT_TIMEOUT);
case PROTOLAYER_EVENT_CONNECT:
return pl_dns_stream_connected(session, stream);
case PROTOLAYER_EVENT_CONNECT_FAIL:;
enum kr_selection_error err = (*baton)
? *(enum kr_selection_error *)baton
: KR_SELECTION_TCP_CONNECT_FAILED;
return pl_dns_stream_connection_fail(session, err);
case PROTOLAYER_EVENT_DISCONNECT:
case PROTOLAYER_EVENT_CLOSE:
case PROTOLAYER_EVENT_FORCE_CLOSE:
return pl_dns_stream_disconnected(session, stream);
case PROTOLAYER_EVENT_EOF:
return pl_dns_stream_eof(session, stream);
default:
return PROTOLAYER_EVENT_PROPAGATE;
}
}
static knot_pkt_t *stream_produce_packet(struct session2 *session,
struct wire_buf *wb,
bool *out_err)
{
*out_err = false;
if (wire_buf_data_length(wb) == 0) {
wire_buf_reset(wb);
return NULL;
}
if (wire_buf_data_length(wb) < sizeof(uint16_t)) {
return NULL;
}
uint16_t msg_len = knot_wire_read_u16(wire_buf_data(wb));
if (msg_len == 0) {
*out_err = true;
session2_penalize(session);
return NULL;
}
if (msg_len >= wb->size) {
*out_err = true;
session2_penalize(session);
return NULL;
}
if (wire_buf_data_length(wb) < msg_len + sizeof(uint16_t)) {
return NULL;
}
uint8_t *wire = (uint8_t *)wire_buf_data(wb) + sizeof(uint16_t);
session->was_useful = true;
knot_pkt_t *pkt = produce_packet(wire, msg_len);
*out_err = (pkt == NULL);
return pkt;
}
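/* For reference, the layout stream_produce_packet() expects in the session
 * wire buffer is the standard DNS-over-TCP framing (a two-byte length prefix
 * in network byte order before each message):
 *
 *	| len (2 B) | DNS message (len bytes) | len (2 B) | DNS message | ...
 *
 * The function above only peeks at the buffer; stream_discard_packet() below
 * is what actually consumes the prefix and the message. */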
static int stream_discard_packet(struct session2 *session,
struct wire_buf *wb,
const knot_pkt_t *pkt,
bool *out_err)
{
*out_err = true;
if (kr_fails_assert(wire_buf_data_length(wb) >= sizeof(uint16_t))) {
wire_buf_reset(wb);
return kr_error(EINVAL);
}
size_t msg_size = knot_wire_read_u16(wire_buf_data(wb));
uint8_t *wire = (uint8_t *)wire_buf_data(wb) + sizeof(uint16_t);
if (kr_fails_assert(msg_size + sizeof(uint16_t) <= wire_buf_data_length(wb))) {
		/* The TCP message length field is greater than the
		 * number of bytes in the buffer; this must not happen. */
wire_buf_reset(wb);
return kr_error(EINVAL);
}
if (kr_fails_assert(wire == pkt->wire)) {
		/* The packet wire data must be located at the beginning
		 * of the session wire buffer; anything else must not happen. */
wire_buf_reset(wb);
return kr_error(EINVAL);
}
if (kr_fails_assert(msg_size >= pkt->size)) {
wire_buf_reset(wb);
return kr_error(EINVAL);
}
wire_buf_trim(wb, msg_size + sizeof(uint16_t));
*out_err = false;
if (wire_buf_data_length(wb) == 0) {
wire_buf_reset(wb);
} else if (wire_buf_data_length(wb) < KNOT_WIRE_HEADER_SIZE) {
wire_buf_movestart(wb);
}
return kr_ok();
}
static enum protolayer_iter_cb_result pl_dns_stream_unwrap(
void *sess_data, void *iter_data, struct protolayer_iter_ctx *ctx)
{
if (kr_fails_assert(ctx->payload.type == PROTOLAYER_PAYLOAD_WIRE_BUF)) {
/* DNS stream only works with a wire buffer */
return protolayer_break(ctx, kr_error(EINVAL));
}
int status = kr_ok();
struct session2 *session = ctx->session;
struct pl_dns_stream_sess_data *stream_sess = sess_data;
struct wire_buf *wb = ctx->payload.wire_buf;
if (wire_buf_data_length(wb) == 0)
return protolayer_break(ctx, status);
const uint32_t max_iters = (wire_buf_data_length(wb) /
(KNOT_WIRE_HEADER_SIZE + KNOT_WIRE_QUESTION_MIN_SIZE)) + 1;
int iters = 0;
bool pkt_error = false;
knot_pkt_t *pkt = NULL;
while ((pkt = stream_produce_packet(session, wb, &pkt_error)) && iters < max_iters) {
if (kr_fails_assert(!pkt_error)) {
status = kr_error(EINVAL);
goto exit;
}
if (stream_sess->single && stream_sess->produced) {
if (kr_log_is_debug(WORKER, NULL)) {
kr_log_debug(WORKER, "Unexpected extra data from %s\n",
kr_straddr(ctx->comm->src_addr));
}
status = KNOT_EMALF;
goto exit;
}
stream_sess->produced = true;
int ret = worker_submit(session, ctx->comm, pkt);
/* Errors from worker_submit() are intentionally *not* handled
* in order to ensure the entire wire buffer is processed. */
if (ret == kr_ok()) {
iters += 1;
}
if (stream_discard_packet(session, wb, pkt, &pkt_error) < 0) {
			/* Packet data isn't stored in memory as expected;
			 * something went wrong. This normally should not happen. */
break;
}
}
/* worker_submit() may cause the session to close (e.g. due to IO
* write error when the packet triggers an immediate answer). This is
* an error state, as well as any wirebuf error. */
if (session->closing || pkt_error)
status = kr_error(EIO);
exit:
wire_buf_movestart(wb);
mp_flush(the_worker->pkt_pool.ctx);
if (status < 0)
session2_force_close(session);
return protolayer_break(ctx, status);
}
struct sized_iovs {
uint8_t nlen[2];
struct iovec iovs[];
};
static enum protolayer_iter_cb_result pl_dns_stream_wrap(
void *sess_data, void *iter_data, struct protolayer_iter_ctx *ctx)
{
if (ctx->payload.type == PROTOLAYER_PAYLOAD_BUFFER) {
if (kr_fails_assert(ctx->payload.buffer.len <= UINT16_MAX))
return protolayer_break(ctx, kr_error(EMSGSIZE));
const int iovcnt = 2;
struct sized_iovs *siov = mm_alloc(&ctx->pool,
sizeof(*siov) + iovcnt * sizeof(struct iovec));
kr_require(siov);
knot_wire_write_u16(siov->nlen, ctx->payload.buffer.len);
siov->iovs[0] = (struct iovec){
.iov_base = &siov->nlen,
.iov_len = sizeof(siov->nlen)
};
siov->iovs[1] = (struct iovec){
.iov_base = ctx->payload.buffer.buf,
.iov_len = ctx->payload.buffer.len
};
ctx->payload = protolayer_payload_iovec(siov->iovs, iovcnt, false);
return protolayer_continue(ctx);
} else if (ctx->payload.type == PROTOLAYER_PAYLOAD_IOVEC) {
const int iovcnt = 1 + ctx->payload.iovec.cnt;
struct sized_iovs *siov = mm_alloc(&ctx->pool,
sizeof(*siov) + iovcnt * sizeof(struct iovec));
kr_require(siov);
size_t total_len = 0;
for (int i = 0; i < ctx->payload.iovec.cnt; i++) {
const struct iovec *iov = &ctx->payload.iovec.iov[i];
total_len += iov->iov_len;
siov->iovs[i + 1] = *iov;
}
if (kr_fails_assert(total_len <= UINT16_MAX))
return protolayer_break(ctx, kr_error(EMSGSIZE));
knot_wire_write_u16(siov->nlen, total_len);
siov->iovs[0] = (struct iovec){
.iov_base = &siov->nlen,
.iov_len = sizeof(siov->nlen)
};
ctx->payload = protolayer_payload_iovec(siov->iovs, iovcnt, false);
return protolayer_continue(ctx);
} else {
kr_assert(false && "Invalid payload");
return protolayer_break(ctx, kr_error(EINVAL));
}
}
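/* A small worked example of the wrap direction above (a sketch, not extra
 * logic): wrapping a 300-byte DNS message held in a buffer payload yields two
 * iovecs, the first pointing at siov->nlen = { 0x01, 0x2c } (300 in network
 * byte order) and the second at the message itself, so lower layers can emit
 * the whole frame with a single scatter-gather write. */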
static void pl_dns_stream_request_init(struct session2 *session,
struct kr_request *req,
void *sess_data)
{
req->qsource.comm_flags.tcp = true;
}
__attribute__((constructor))
static void worker_protolayers_init(void)
{
protolayer_globals[PROTOLAYER_TYPE_DNS_DGRAM] = (struct protolayer_globals){
.wire_buf_overhead_cb = pl_dns_dgram_wire_buf_overhead,
.wire_buf_max_overhead = KNOT_WIRE_MAX_PKTSIZE,
.unwrap = pl_dns_dgram_unwrap,
.event_unwrap = pl_dns_dgram_event_unwrap
};
protolayer_globals[PROTOLAYER_TYPE_DNS_UNSIZED_STREAM] = (struct protolayer_globals){
.sess_size = sizeof(struct pl_dns_stream_sess_data),
.wire_buf_overhead = KNOT_WIRE_MAX_PKTSIZE,
.sess_init = pl_dns_stream_sess_init,
.unwrap = pl_dns_dgram_unwrap,
.event_unwrap = pl_dns_stream_event_unwrap,
.request_init = pl_dns_stream_request_init
};
const struct protolayer_globals stream_common = {
.sess_size = sizeof(struct pl_dns_stream_sess_data),
.wire_buf_overhead = KNOT_WIRE_MAX_PKTSIZE,
.sess_init = NULL, /* replaced in specific layers below */
.unwrap = pl_dns_stream_unwrap,
.wrap = pl_dns_stream_wrap,
.event_unwrap = pl_dns_stream_event_unwrap,
.request_init = pl_dns_stream_request_init
};
protolayer_globals[PROTOLAYER_TYPE_DNS_MULTI_STREAM] = stream_common;
protolayer_globals[PROTOLAYER_TYPE_DNS_MULTI_STREAM].sess_init = pl_dns_stream_sess_init;
protolayer_globals[PROTOLAYER_TYPE_DNS_SINGLE_STREAM] = stream_common;
protolayer_globals[PROTOLAYER_TYPE_DNS_SINGLE_STREAM].sess_init = pl_dns_single_stream_sess_init;
}
int worker_init(void)
{
if (kr_fails_assert(the_worker == NULL))
return kr_error(EINVAL);
kr_bindings_register(the_engine->L); // TODO move
/* Create main worker. */
the_worker = &the_worker_value;
memset(the_worker, 0, sizeof(*the_worker));
uv_loop_t *loop = uv_default_loop();
the_worker->loop = loop;
/* Register table for worker per-request variables */
struct lua_State *L = the_engine->L;
lua_newtable(L);
lua_setfield(L, -2, "vars");
lua_getfield(L, -1, "vars");
the_worker->vars_table_ref = luaL_ref(L, LUA_REGISTRYINDEX);
lua_pop(L, 1);
the_worker->tcp_pipeline_max = MAX_PIPELINED;
the_worker->out_addr4.sin_family = AF_UNSPEC;
the_worker->out_addr6.sin6_family = AF_UNSPEC;
array_init(the_worker->doh_qry_headers);
int ret = worker_reserve();
if (ret) return ret;
the_worker->next_request_uid = UINT16_MAX + 1;
/* Set some worker.* fields in Lua */
lua_getglobal(L, "worker");
pid_t pid = getpid();
auto_free char *pid_str = NULL;
const char *inst_name = getenv("SYSTEMD_INSTANCE");
if (inst_name) {
lua_pushstring(L, inst_name);
} else {
ret = asprintf(&pid_str, "%ld", (long)pid);
kr_assert(ret > 0);
lua_pushstring(L, pid_str);
}
lua_setfield(L, -2, "id");
lua_pushnumber(L, pid);
lua_setfield(L, -2, "pid");
char cwd[PATH_MAX];
get_workdir(cwd, sizeof(cwd));
lua_pushstring(L, cwd);
lua_setfield(L, -2, "cwd");
loop->data = the_worker;
/* ^^^^ Now this shouldn't be used anymore, but it's hard to be 100% sure. */
return kr_ok();
}
#undef VERBOSE_MSG
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#pragma once
#include "daemon/engine.h"
#include "lib/generic/array.h"
#include "lib/generic/trie.h"
/** Query resolution task (opaque). */
struct qr_task;
/** Worker state. */
struct worker_ctx;
/** Transport session (opaque). */
struct session2;
/** Data about the communication (defined in io.h). */
struct comm_info;
/* @cond internal Freelist of available mempools. */
typedef array_t(void *) mp_freelist_t;
/* @endcond */
/** Pointer to the singleton worker. NULL if not initialized. */
KR_EXPORT extern struct worker_ctx *the_worker;
/** Create and initialize the worker.
* \return error code (ENOMEM) */
int worker_init(void);
/** Destroy the worker (free memory). */
void worker_deinit(void);
KR_EXPORT knot_pkt_t *worker_resolve_mk_pkt_dname(knot_dname_t *qname, uint16_t qtype, uint16_t qclass,
const struct kr_qflags *options);
/**
 * Create a packet suitable for worker_resolve_start(). All in malloc() memory.
 */
KR_EXPORT knot_pkt_t *
worker_resolve_mk_pkt(const char *qname_str, uint16_t qtype, uint16_t qclass,
const struct kr_qflags *options);
/**
 * Start query resolution with given query.
 *
 * @return task or NULL
 */
KR_EXPORT struct qr_task *
worker_resolve_start(knot_pkt_t *query, struct kr_qflags options);
/**
 * Execute a request with given query.
 * It expects the task to be created with worker_resolve_start().
 *
 * @return 0 or an error code
 */
KR_EXPORT int worker_resolve_exec(struct qr_task *task, knot_pkt_t *query);
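/* A minimal usage sketch of the three calls above, e.g. from a module that
 * needs an internal lookup (error handling omitted; the qname and record type
 * are made up for illustration):
 *
 *	struct kr_qflags options;
 *	memset(&options, 0, sizeof(options));
 *	options.DNSSEC_WANT = true;
 *	knot_pkt_t *pkt = worker_resolve_mk_pkt("example.com.", KNOT_RRTYPE_AAAA,
 *	                                        KNOT_CLASS_IN, &options);
 *	struct qr_task *task = worker_resolve_start(pkt, options);
 *	if (task)
 *		worker_resolve_exec(task, pkt);
 *	knot_pkt_free(pkt);
 */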
/** @return struct kr_request associated with opaque task */
struct kr_request *worker_task_request(struct qr_task *task);
int worker_task_step(struct qr_task *task, const struct sockaddr *packet_source,
knot_pkt_t *packet);
int worker_task_numrefs(const struct qr_task *task);
/** Finalize given task */
int worker_task_finalize(struct qr_task *task, int state);
void worker_task_complete(struct qr_task *task);
void worker_task_ref(struct qr_task *task);
void worker_task_unref(struct qr_task *task);
void worker_task_timeout_inc(struct qr_task *task);
knot_pkt_t *worker_task_get_pktbuf(const struct qr_task *task);
struct kr_transport *worker_task_get_transport(struct qr_task *task);
/** Note: source session is NULL in case the request hasn't come over network. */
KR_EXPORT struct session2 *worker_request_get_source_session(const struct kr_request *req);
uint16_t worker_task_pkt_get_msgid(struct qr_task *task);
void worker_task_pkt_set_msgid(struct qr_task *task, uint16_t msgid);
uint64_t worker_task_creation_time(struct qr_task *task);
void worker_task_subreq_finalize(struct qr_task *task);
bool worker_task_finished(struct qr_task *task);
/** To be called after sending a DNS message. It mainly deals with cleanups. */
int qr_task_on_send(struct qr_task *task, struct session2 *s, int status);
/** Various worker statistics. Sync with wrk_stats() */
struct worker_stats {
size_t queries; /**< Total number of requests (from clients and internal ones). */
size_t concurrent; /**< The number of requests currently in processing. */
size_t rconcurrent; /*< TODO: remove? I see no meaningful difference from .concurrent. */
size_t dropped; /**< The number of requests dropped due to being badly formed. See #471. */
size_t timeout; /**< Number of outbound queries that timed out. */
size_t udp; /**< Number of outbound queries over UDP. */
size_t tcp; /**< Number of outbound queries over TCP (excluding TLS). */
size_t tls; /**< Number of outbound queries over TLS. */
size_t ipv4; /**< Number of outbound queries over IPv4.*/
size_t ipv6; /**< Number of outbound queries over IPv6. */
size_t err_udp; /**< Total number of write errors for UDP transport. */
size_t err_tcp; /**< Total number of write errors for TCP transport. */
size_t err_tls; /**< Total number of write errors for TLS transport. */
size_t err_http; /**< Total number of write errors for HTTP(S) transport. */
};
/** @cond internal */
/** Number of request within timeout window. */
#define MAX_PENDING 4
/** Maximum response time from TCP upstream, milliseconds */
#define MAX_TCP_INACTIVITY (KR_RESOLVE_TIME_LIMIT + KR_CONN_RTT_MAX)
#ifndef RECVMMSG_BATCH /* see check_bufsize() */
#define RECVMMSG_BATCH 1
#endif
/** List of query resolution tasks. */
typedef array_t(struct qr_task *) qr_tasklist_t;
/** List of HTTP header names. */
typedef array_t(const char *) doh_headerlist_t;
/** \details Worker state is meant to persist during the whole life of daemon. */
struct worker_ctx {
uv_loop_t *loop;
int count; /** unreliable, does not count systemd instance, do not use */
int vars_table_ref;
unsigned tcp_pipeline_max;
/** Addresses to bind for outgoing connections or AF_UNSPEC. */
struct sockaddr_in out_addr4;
struct sockaddr_in6 out_addr6;
struct worker_stats stats;
bool too_many_open;
size_t rconcurrent_highwatermark;
/** List of active outbound TCP sessions */
trie_t *tcp_connected;
/** List of outbound TCP sessions waiting to be accepted */
trie_t *tcp_waiting;
/** Subrequest leaders (struct qr_task*), indexed by qname+qtype+qclass. */
trie_t *subreq_out;
knot_mm_t pkt_pool;
unsigned int next_request_uid;
/* HTTP Headers for DoH. */
doh_headerlist_t doh_qry_headers;
};
/** @endcond */
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
/* This module imports resource records from a file into the resolver's cache.
 * The file is expected to be a standard DNS zone file
 * containing text representations of resource records.
 * For now only root zone import is supported.
 *
 * The import process consists of two stages.
 * 1) Zone file parsing and (optionally) ZONEMD verification.
 * 2) DNSSEC validation and storage in cache.
 *
 * These stages are implemented as two separate functions
 * (zi_zone_import and zi_zone_process) which run sequentially with a
 * pause between them. This is done because the resolver is a single-threaded
 * application and cannot process user requests during the import.
 * Splitting the work into two stages reduces the continuous time interval
 * during which the resolver cannot serve user requests.
 * Since the root zone isn't large, it is imported as a single chunk.
 */
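/* Rough control flow implied by the description above (names are the ones
 * used in this file; the exact scheduling is a simplification, not the full
 * implementation):
 *
 *	zi_zone_import(...)           // stage 1: parse the zone file, verify ZONEMD
 *	  -> uv_timer_start(..., ZONE_IMPORT_PAUSE, ...)   // let pending queries run
 *	       -> zi_zone_process(...)  // stage 2: validate DNSSEC, store into cache
 */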
#include "daemon/zimport.h"
#include <inttypes.h> /* PRIu64 */
#include <limits.h>
#include <math.h>
#include <stdlib.h>
#include <sys/stat.h>
#include <uv.h>
#include <libknot/rrset.h>
#include <libzscanner/scanner.h>
#include <libdnssec/digest.h>
#include "daemon/worker.h"
#include "lib/dnssec/ta.h"
#include "lib/dnssec.h"
#include "lib/generic/trie.h"
#include "lib/utils.h"
/* Pause between parse and import stages, milliseconds. */
#define ZONE_IMPORT_PAUSE 100
// NAN normally comes from <math.h> but it's not guaranteed.
#ifndef NAN
#define NAN nan("")
#endif
struct zone_import_ctx {
knot_mm_t *pool; /// memory pool for all allocations (including struct itself)
knot_dname_t *origin;
uv_timer_t timer;
// from zi_config_t
zi_callback cb;
void *cb_param;
trie_t *rrsets; /// map: key_get() -> knot_rrset_t*, in ZONEMD order
uint32_t timestamp_rr; /// stamp of when RR data arrived (seconds since epoch)
struct kr_svldr_ctx *svldr; /// DNSSEC validator; NULL iff we don't validate
const knot_dname_t *last_cut; /// internal to zi_rrset_import()
uint8_t *digest_buf; /// temporary buffer for digest computation (on pool)
#define DIGEST_BUF_SIZE (64*1024 - 1)
#define DIGEST_ALG_COUNT 2
struct {
bool active; /// whether we want it computed
dnssec_digest_ctx_t *ctx;
const uint8_t *expected; /// expected digest (inside zonemd on pool)
} digests[DIGEST_ALG_COUNT]; /// we use indices 0 and 1 for SHA 384 and 512
};
typedef struct zone_import_ctx zone_import_ctx_t;
#define KEY_LEN (KNOT_DNAME_MAXLEN + 1 + 2 + 2)
/** Construct key for name, type and signed type (if type == RRSIG).
*
* Return negative error code in asserted cases.
*/
static int key_get(char buf[KEY_LEN], const knot_dname_t *name,
uint16_t type, uint16_t type_maysig, char **key_p)
{
char *lf = (char *)knot_dname_lf(name, (uint8_t *)buf);
if (kr_fails_assert(lf && key_p))
return kr_error(EINVAL);
int len = (unsigned char)lf[0];
lf++; // point to start of data
*key_p = lf;
// Check that LF is right-aligned to KNOT_DNAME_MAXLEN in buf.
if (kr_fails_assert(lf + len == buf + KNOT_DNAME_MAXLEN))
return kr_error(EINVAL);
buf[KNOT_DNAME_MAXLEN] = 0; // this ensures correct ZONEMD order
memcpy(buf + KNOT_DNAME_MAXLEN + 1, &type, sizeof(type));
len += 1 + sizeof(type);
if (type == KNOT_RRTYPE_RRSIG) {
memcpy(buf + KNOT_DNAME_MAXLEN + 1 + sizeof(type),
&type_maysig, sizeof(type_maysig));
len += sizeof(type_maysig);
}
return len;
}
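/* Sketch of the key layout produced by key_get() (derived from the code
 * above; *key_p points at the first used byte and the return value is the
 * key length):
 *
 *	buf: [ unused | dname LF, right-aligned to KNOT_DNAME_MAXLEN | 0x00 |
 *	       type (2 B) | type_maysig (2 B, only when type == RRSIG) ]
 *	               ^ *key_p
 *
 * The fixed zero byte after the name keeps the keys in correct ZONEMD order. */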
/** Simple helper to retrieve from zone_import_ctx_t::rrsets */
static knot_rrset_t * rrset_get(trie_t *rrsets, const knot_dname_t *name,
uint16_t type, uint16_t type_maysig)
{
char key_buf[KEY_LEN], *key;
const int len = key_get(key_buf, name, type, type_maysig, &key);
if (len < 0)
return NULL;
const trie_val_t *rrsig_p = trie_get_try(rrsets, key, len);
if (!rrsig_p)
return NULL;
kr_assert(*rrsig_p);
return *rrsig_p;
}
static int digest_rrset(trie_val_t *rr_p, void *z_import_v)
{
zone_import_ctx_t *z_import = z_import_v;
const knot_rrset_t *rr = *rr_p;
// ignore apex ZONEMD or its RRSIG, and also out-of-bailiwick records
const int origin_bailiwick = knot_dname_in_bailiwick(rr->owner, z_import->origin);
const bool is_apex = origin_bailiwick == 0;
if (is_apex && kr_rrset_type_maysig(rr) == KNOT_RRTYPE_ZONEMD)
return KNOT_EOK;
if (unlikely(origin_bailiwick < 0))
return KNOT_EOK;
const int len = knot_rrset_to_wire_extra(rr, z_import->digest_buf, DIGEST_BUF_SIZE,
0, NULL, KNOT_PF_ORIGTTL);
if (len < 0)
return kr_error(len);
// digest serialized RRSet
for (int i = 0; i < DIGEST_ALG_COUNT; ++i) {
if (!z_import->digests[i].active)
continue;
dnssec_binary_t bufbin = { len, z_import->digest_buf };
int ret = dnssec_digest(z_import->digests[i].ctx, &bufbin);
if (ret != KNOT_EOK)
return kr_error(ret);
}
return KNOT_EOK;
}
/** Verify ZONEMD in the stored zone, and return error code.
*
* ZONEMD signature is verified iff z_import->svldr != NULL
* https://www.rfc-editor.org/rfc/rfc8976.html#name-verifying-zone-digest
*/
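/* Outline of the steps below:
*  1. find the apex ZONEMD RRset; if absent, just compute a SHA-384 digest
*     so it can be logged later;
*  2. if a validator is set up, validate the ZONEMD RRSIG;
*  3. pick usable ZONEMD rdata: SIMPLE scheme, SHA-384/512 algorithm and
*     a SOA serial matching the zone's SOA, remembering the expected digests;
*  4. hash all in-bailiwick RRsets in canonical order via digest_rrset()
*     (which skips the apex ZONEMD and its RRSIG);
*  5. succeed iff some computed digest matches the expected one and, when
*     validation is enabled, the ZONEMD RRSIG validated as well.
*/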
static int zonemd_verify(zone_import_ctx_t *z_import)
{
bool zonemd_is_valid = false;
// Find ZONEMD RR + RRSIG
knot_rrset_t * const rr_zonemd
= rrset_get(z_import->rrsets, z_import->origin, KNOT_RRTYPE_ZONEMD, 0);
if (!rr_zonemd) {
// no zonemd; let's compute the shorter digest and print info later
z_import->digests[KNOT_ZONEMD_ALGORITHM_SHA384 - 1].active = true;
goto do_digest;
}
// Validate ZONEMD RRSIG, if desired
if (z_import->svldr) {
const knot_rrset_t *rrsig_zonemd
= rrset_get(z_import->rrsets, z_import->origin,
KNOT_RRTYPE_RRSIG, KNOT_RRTYPE_ZONEMD);
int ret = rrsig_zonemd
? kr_svldr_rrset(rr_zonemd, &rrsig_zonemd->rrs, z_import->svldr)
: kr_error(ENOENT);
zonemd_is_valid = (ret == kr_ok());
if (!rrsig_zonemd) {
kr_log_error(PREFILL, "ZONEMD signature missing\n");
} else if (!zonemd_is_valid) {
kr_log_error(PREFILL, "ZONEMD signature failed to validate\n");
}
}
// Get SOA serial
const knot_rrset_t *soa = rrset_get(z_import->rrsets, z_import->origin,
KNOT_RRTYPE_SOA, 0);
if (!soa) {
kr_log_error(PREFILL, "SOA record not found\n");
return kr_error(ENOENT);
}
if (soa->rrs.count != 1) {
kr_log_error(PREFILL, "the SOA RR set is weird\n");
return kr_error(EINVAL);
} // length is checked by parser already
const uint32_t soa_serial = knot_soa_serial(soa->rrs.rdata);
// Figure out SOA+ZONEMD RR contents.
bool some_active = false;
knot_rdata_t *rd = rr_zonemd->rrs.rdata;
for (int i = 0; i < rr_zonemd->rrs.count; ++i, rd = knot_rdataset_next(rd)) {
if (rd->len < 6 || knot_zonemd_scheme(rd) != KNOT_ZONEMD_SCHEME_SIMPLE
|| knot_zonemd_soa_serial(rd) != soa_serial)
continue;
const int algo = knot_zonemd_algorithm(rd);
if (algo != KNOT_ZONEMD_ALGORITHM_SHA384 && algo != KNOT_ZONEMD_ALGORITHM_SHA512)
continue;
if (rd->len != 6 + knot_zonemd_digest_size(rd)) {
kr_log_error(PREFILL, "ZONEMD record has incorrect digest length\n");
return kr_error(EINVAL);
}
if (z_import->digests[algo - 1].active) {
kr_log_error(PREFILL, "multiple clashing ZONEMD records found\n");
return kr_error(EINVAL);
}
some_active = true;
z_import->digests[algo - 1].active = true;
z_import->digests[algo - 1].expected = knot_zonemd_digest(rd);
}
if (!some_active) {
kr_log_error(PREFILL, "ZONEMD record(s) found but none were usable\n");
return kr_error(ENOENT);
}
do_digest:
// Init memory, etc.
if (!z_import->digest_buf) {
z_import->digest_buf = mm_alloc(z_import->pool, DIGEST_BUF_SIZE);
if (!z_import->digest_buf)
return kr_error(ENOMEM);
}
for (int i = 0; i < DIGEST_ALG_COUNT; ++i) {
const int algo = i + 1;
if (!z_import->digests[i].active)
continue;
int ret = dnssec_digest_init(algo, &z_import->digests[i].ctx);
if (ret != KNOT_EOK) {
// free previous successful _ctx, if applicable
dnssec_binary_t digest = { 0 };
while (--i >= 0) {
if (z_import->digests[i].active)
dnssec_digest_finish(z_import->digests[i].ctx,
&digest);
}
return kr_error(ENOMEM);
}
}
// Actually compute the digest(s).
int ret = trie_apply(z_import->rrsets, digest_rrset, z_import);
dnssec_binary_t digs[DIGEST_ALG_COUNT] = { { 0 } };
for (int i = 0; i < DIGEST_ALG_COUNT; ++i) {
if (!z_import->digests[i].active)
continue;
int ret2 = dnssec_digest_finish(z_import->digests[i].ctx, &digs[i]);
if (ret == DNSSEC_EOK)
ret = ret2;
// we need to keep going to free all digests[*].ctx
}
if (ret != DNSSEC_EOK) {
for (int i = 0; i < DIGEST_ALG_COUNT; ++i)
free(digs[i].data);
kr_log_error(PREFILL, "error when computing digest: %s\n",
kr_strerror(ret));
return kr_error(ret);
}
// Now just check whether at least one of the computed hashes matches.
bool has_match = false;
for (int i = 0; i < DIGEST_ALG_COUNT; ++i) {
if (!z_import->digests[i].active)
continue;
// hexdump the hash for logging
char hash_str[digs[i].size * 2 + 1];
for (ssize_t j = 0; j < digs[i].size; ++j)
(void)sprintf(hash_str + 2*j, "%02x", digs[i].data[j]);
if (!z_import->digests[i].expected) {
kr_log_error(PREFILL, "no ZONEMD found; computed hash: %s\n",
hash_str);
} else if (memcmp(z_import->digests[i].expected, digs[i].data,
digs[i].size) != 0) {
kr_log_error(PREFILL, "ZONEMD hash mismatch; computed hash: %s\n",
hash_str);
} else {
kr_log_debug(PREFILL, "ZONEMD hash matches\n");
has_match = true;
continue;
}
}
for (int i = 0; i < DIGEST_ALG_COUNT; ++i)
free(digs[i].data);
bool ok = has_match && (zonemd_is_valid || !z_import->svldr);
return ok ? kr_ok() : kr_error(ENOENT);
}
/**
* @internal Import given rrset to cache.
*
* @return error code; we could've chosen to keep importing even if some RRset fails,
* but it would be harder to ensure that we don't generate too many logs
* and that we pass an error to the finishing callback.
*/
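/* Note on the zone-cut tracking below: the trie iterates in canonical order,
* so once a delegation NS RRset below the origin is seen, its owner is kept
* in z_import->last_cut and every name under it is treated as
* non-authoritative, except DS/NSEC/NSEC3 which belong to the parent side
* of the cut. */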
static int zi_rrset_import(trie_val_t *rr_p, void *z_import_v)
{
zone_import_ctx_t *z_import = z_import_v;
knot_rrset_t *rr = *rr_p;
if (rr->type == KNOT_RRTYPE_RRSIG)
return 0; // we do RRSIGs at once with their types
const int origin_bailiwick = knot_dname_in_bailiwick(rr->owner, z_import->origin);
if (unlikely(origin_bailiwick < 0)) {
KR_DNAME_GET_STR(owner_str, rr->owner);
kr_log_warning(PREFILL, "ignoring out of bailiwick record(s) on %s\n",
owner_str);
return 0; // well, let's continue without error
}
// Determine if this RRset is authoritative.
// We utilize that iteration happens in canonical order.
bool is_auth;
const int kdib = knot_dname_in_bailiwick(rr->owner, z_import->last_cut);
if (kdib == 0 && (rr->type == KNOT_RRTYPE_DS || rr->type == KNOT_RRTYPE_NSEC
|| rr->type == KNOT_RRTYPE_NSEC3)) {
// parent side of the zone cut (well, presumably in case of NSEC*)
is_auth = true;
} else if (kdib >= 0) {
// inside non-auth subtree
is_auth = false;
} else if (rr->type == KNOT_RRTYPE_NS && origin_bailiwick > 0) {
// entering non-auth subtree
z_import->last_cut = rr->owner;
is_auth = false;
} else {
// outside non-auth subtree
is_auth = true;
z_import->last_cut = NULL; // so that the next _in_bailiwick() is faster
}
// Rare case: an `A` RRset exactly on a zone cut would be misdetected as
// authoritative and fail validation; A is the only type ordered before NS.
if (unlikely(is_auth && rr->type < KNOT_RRTYPE_NS)) {
if (rrset_get(z_import->rrsets, rr->owner, KNOT_RRTYPE_NS, 0))
is_auth = false;
}
// Get and validate the corresponding RRSIGs, if authoritative.
const knot_rrset_t *rrsig = NULL;
if (is_auth) {
rrsig = rrset_get(z_import->rrsets, rr->owner, KNOT_RRTYPE_RRSIG, rr->type);
if (unlikely(!rrsig && z_import->svldr)) {
KR_DNAME_GET_STR(owner_str, rr->owner);
KR_RRTYPE_GET_STR(type_str, rr->type);
kr_log_error(PREFILL, "no records found for %s RRSIG %s\n",
owner_str, type_str);
return kr_error(ENOENT);
}
}
if (is_auth && z_import->svldr) {
int ret = kr_svldr_rrset(rr, &rrsig->rrs, z_import->svldr);
if (unlikely(ret)) {
KR_DNAME_GET_STR(owner_str, rr->owner);
KR_RRTYPE_GET_STR(type_str, rr->type);
kr_log_error(PREFILL, "validation failed for %s %s: %s\n",
owner_str, type_str, kr_strerror(ret));
return kr_error(ret);
}
}
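// Choose the cache rank. Roughly (as these flags are used elsewhere in the
// resolver): OMIT keeps glue/non-authoritative data out of answers, while
// authoritative data is marked SECURE when validated here and INSECURE when
// imported without validation.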
uint8_t rank;
if (!is_auth) {
rank = KR_RANK_OMIT;
} else if (z_import->svldr) {
rank = KR_RANK_AUTH|KR_RANK_SECURE;
} else {
rank = KR_RANK_AUTH|KR_RANK_INSECURE;
}
int ret = kr_cache_insert_rr(&the_resolver->cache, rr, rrsig,
rank, z_import->timestamp_rr,
// Optim.: only stash NSEC* params at the apex.
origin_bailiwick == 0);
if (ret) {
kr_log_error(PREFILL, "caching an RRset failed: %s\n",
kr_strerror(ret));
return kr_error(ret);
}
return 0; // success
}
static void ctx_delete(zone_import_ctx_t *z_import)
{
if (kr_fails_assert(z_import)) return;
kr_svldr_free_ctx(z_import->svldr);
/* Free `z_import`'s pool, including `z_import` itself, because it is
* allocated inside said pool. */
mm_ctx_delete(z_import->pool);
}
static void timer_close(uv_handle_t *handle)
{
ctx_delete(handle->data);
}
/** @internal Iterate over parsed rrsets and try to import each of them. */
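/* Runs from the uv timer armed at the end of zi_zone_import(), i.e. roughly
* ZONE_IMPORT_PAUSE ms after parsing; it commits the open cache transaction,
* reports the overall result through the user callback and then tears the
* context down via timer_close(). */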
static void zi_zone_process(uv_timer_t *timer)
{
zone_import_ctx_t *z_import = timer->data;
kr_timer_t stopwatch;
kr_timer_start(&stopwatch);
int ret = trie_apply(z_import->rrsets, zi_rrset_import, z_import);
(void)kr_cache_commit(&the_resolver->cache); // RW transaction open
if (ret == 0) {
kr_log_info(PREFILL, "performance: validating and caching took %.3lf s\n",
kr_timer_elapsed(&stopwatch));
}
if (z_import->cb)
z_import->cb(kr_error(ret), z_import->cb_param);
uv_close((uv_handle_t *)timer, timer_close);
}
/** @internal Store a parsed RRset into the zone import context memory pool.
* @return -1 on failure; 0 on success. */
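/* Individual records are merged into RRsets keyed by (owner, type, and the
* covered type for RRSIGs), so the later stages always see whole RRsets no
* matter how the zone file interleaves its records. */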
static int zi_record_store(zs_scanner_t *s)
{
if (s->r_data_length > UINT16_MAX) {
/* Due to knot_rrset_add_rdata(..., const uint16_t size, ...); */
kr_log_error(PREFILL, "line %"PRIu64": rdata is too long\n",
s->line_counter);
return -1;
}
if (knot_dname_size(s->r_owner) != strlen((const char *)(s->r_owner)) + 1) {
kr_log_error(PREFILL, "line %"PRIu64
": owner name contains zero byte, skip\n",
s->line_counter);
return 0;
}
zone_import_ctx_t *z_import = (zone_import_ctx_t *)s->process.data;
knot_rrset_t *new_rr = knot_rrset_new(s->r_owner, s->r_type, s->r_class,
s->r_ttl, z_import->pool);
if (!new_rr) {
kr_log_error(PREFILL, "line %"PRIu64": error creating rrset\n",
s->line_counter);
return -1;
}
int res = knot_rrset_add_rdata(new_rr, s->r_data, s->r_data_length,
z_import->pool);
if (res != KNOT_EOK) {
kr_log_error(PREFILL, "line %"PRIu64": error adding rdata to rrset\n",
s->line_counter);
return -1;
}
/* zscanner itself does not canonicalize - neither the owner nor the rdata */
res = knot_rrset_rr_to_canonical(new_rr);
if (res != KNOT_EOK) {
kr_log_error(PREFILL, "line %"PRIu64": error when canonizing: %s\n",
s->line_counter, knot_strerror(res));
return -1;
}
/* Records in the zone file may not be grouped by name and RR type.
* Construct a search key for the trie map
* to avoid inefficient searches across all the imported records. */
char key_buf[KEY_LEN], *key;
const int len = key_get(key_buf, new_rr->owner, new_rr->type,
kr_rrset_type_maysig(new_rr), &key);
if (len < 0) {
kr_log_error(PREFILL, "line %"PRIu64": error constructing rrkey\n",
s->line_counter);
return -1;
}
trie_val_t *rr_p = trie_get_ins(z_import->rrsets, key, len);
if (!rr_p)
return -1; // ENOMEM
if (*rr_p) {
knot_rrset_t *rr = *rr_p;
res = knot_rdataset_merge(&rr->rrs, &new_rr->rrs, z_import->pool);
} else {
*rr_p = new_rr;
}
if (res != 0) {
kr_log_error(PREFILL, "line %"PRIu64": error saving parsed rrset\n",
s->line_counter);
return -1;
}
return 0;
}
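/** @internal Drive the zone scanner manually via zs_parse_record().
* Each data record is stored through zi_record_store(), the owner of the
* SOA record is remembered as the zone origin, and parsing fails on
* scanner errors, $INCLUDE directives or an empty zone file.
* @return 0 on success; -1 on failure. */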
static int zi_state_parsing(zs_scanner_t *s)
{
bool empty = true;
while (zs_parse_record(s) == 0) {
switch (s->state) {
case ZS_STATE_DATA:
if (zi_record_store(s) != 0) {
return -1;
}
zone_import_ctx_t *z_import = (zone_import_ctx_t *) s->process.data;
empty = false;
if (s->r_type == KNOT_RRTYPE_SOA) {
z_import->origin = knot_dname_copy(s->r_owner,
z_import->pool);
}
break;
case ZS_STATE_ERROR:
kr_log_error(PREFILL, "line: %"PRIu64
": parse error; code: %i ('%s')\n",
s->line_counter, s->error.code,
zs_strerror(s->error.code));
return -1;
case ZS_STATE_INCLUDE:
kr_log_error(PREFILL, "line: %"PRIu64
": INCLUDE is not supported\n",
s->line_counter);
return -1;
case ZS_STATE_EOF:
case ZS_STATE_STOP:
if (empty) {
kr_log_error(PREFILL, "empty zone file\n");
return -1;
}
if (!((zone_import_ctx_t *) s->process.data)->origin) {
kr_log_error(PREFILL, "zone file doesn't contain SOA record\n");
return -1;
}
return (s->error.counter == 0) ? 0 : -1;
default:
kr_log_error(PREFILL, "line: %"PRIu64
": unexpected parse state: %i\n",
s->line_counter, s->state);
return -1;
}
}
return -1;
}
int zi_zone_import(const zi_config_t config)
{
const zi_config_t *c = &config;
if (kr_fails_assert(c && c->zone_file))
return kr_error(EINVAL);
knot_mm_t *pool = mm_ctx_mempool2((size_t)1024 * 1024);
zone_import_ctx_t *z_import = mm_calloc(pool, 1, sizeof(*z_import));
if (!z_import) return kr_error(ENOMEM);
z_import->pool = pool;
z_import->cb = c->cb;
z_import->cb_param = c->cb_param;
z_import->rrsets = trie_create(z_import->pool);
kr_timer_t stopwatch;
kr_timer_start(&stopwatch);
//// Parse the whole zone file into z_import->rrsets.
zs_scanner_t s_storage, *s = &s_storage;
/* zs_init(), zs_set_input_file() and zs_set_processing() return -1 in case of error,
* so don't print the return value; s->error.code is reported instead. */
int ret = zs_init(s, c->origin, KNOT_CLASS_IN, c->ttl);
if (ret != 0) {
kr_log_error(PREFILL, "error initializing zone scanner instance, error: %i (%s)\n",
s->error.code, zs_strerror(s->error.code));
goto fail;
}
ret = zs_set_input_file(s, c->zone_file);
if (ret != 0) {
kr_log_error(PREFILL, "error opening zone file `%s`, error: %i (%s)\n",
c->zone_file, s->error.code, zs_strerror(s->error.code));
zs_deinit(s);
goto fail;
}
/* Don't set processing and error callbacks as we don't use automatic parsing.
* Parsing as well as error processing will be performed in zi_state_parsing().
* Store a pointer to the zone import context for later use. */
ret = zs_set_processing(s, NULL, NULL, (void *)z_import);
if (ret != 0) {
kr_log_error(PREFILL, "zs_set_processing() failed for zone file `%s`, "
"error: %i (%s)\n",
c->zone_file, s->error.code, zs_strerror(s->error.code));
zs_deinit(s);
goto fail;
}
ret = zi_state_parsing(s);
zs_deinit(s);
const double time_parse = kr_timer_elapsed(&stopwatch);
if (ret != 0) {
kr_log_error(PREFILL, "error parsing zone file `%s`\n", c->zone_file);
goto fail;
}
kr_log_debug(PREFILL, "import started for zone file `%s`\n", c->zone_file);
KR_DNAME_GET_STR(zone_name_str, z_import->origin);
//// Choose timestamp_rr, according to config.
struct timespec now;
if (clock_gettime(CLOCK_REALTIME, &now)) {
ret = kr_error(errno);
kr_log_error(PREFILL, "failed to get current time: %s\n", kr_strerror(ret));
goto fail;
}
if (config.time_src == ZI_STAMP_NOW) {
z_import->timestamp_rr = now.tv_sec;
} else if (config.time_src == ZI_STAMP_MTIM) {
struct stat st;
if (stat(c->zone_file, &st) != 0) {
ret = kr_error(errno);
kr_log_debug(PREFILL, "failed to stat file `%s`: %s\n",
c->zone_file, strerror(errno));
goto fail;
}
z_import->timestamp_rr = st.st_mtime;
} else {
ret = kr_error(EINVAL);
goto fail;
}
//// Some sanity checks
const knot_rrset_t *soa = rrset_get(z_import->rrsets, z_import->origin,
KNOT_RRTYPE_SOA, 0);
if (z_import->timestamp_rr > now.tv_sec) {
kr_log_warning(PREFILL, "zone file `%s` comes from future\n", c->zone_file);
} else if (!soa) {
kr_log_warning(PREFILL, "missing %s SOA\n", zone_name_str);
} else if ((int64_t)z_import->timestamp_rr + soa->ttl < now.tv_sec) {
kr_log_warning(PREFILL, "%s SOA already expired\n", zone_name_str);
}
//// Initialize validator context with the DNSKEY.
if (c->downgrade)
goto zonemd;
const knot_rrset_t * const ds = c->ds ? c->ds :
kr_ta_get(the_resolver->trust_anchors, z_import->origin);
if (!ds) {
if (!kr_ta_closest(the_resolver, z_import->origin, KNOT_RRTYPE_DNSKEY))
goto zonemd; // our TAs say we're insecure
kr_log_error(PREFILL, "no DS found for `%s`, fail\n", zone_name_str);
ret = kr_error(ENOENT);
goto fail;
}
if (!knot_dname_is_equal(ds->owner, z_import->origin)) {
kr_log_error(PREFILL, "mismatching DS owner, fail\n");
ret = kr_error(EINVAL);
goto fail;
}
knot_rrset_t * const dnskey = rrset_get(z_import->rrsets, z_import->origin,
KNOT_RRTYPE_DNSKEY, 0);
if (!dnskey) {
kr_log_error(PREFILL, "no DNSKEY found for `%s`, fail\n", zone_name_str);
ret = kr_error(ENOENT);
goto fail;
}
knot_rrset_t * const dnskey_sigs = rrset_get(z_import->rrsets, z_import->origin,
KNOT_RRTYPE_RRSIG, KNOT_RRTYPE_DNSKEY);
if (!dnskey_sigs) {
kr_log_error(PREFILL, "no RRSIGs for DNSKEY found for `%s`, fail\n",
zone_name_str);
ret = kr_error(ENOENT);
goto fail;
}
kr_rrset_validation_ctx_t err_ctx;
z_import->svldr = kr_svldr_new_ctx(ds, dnskey, &dnskey_sigs->rrs,
z_import->timestamp_rr, &err_ctx);
if (!z_import->svldr) {
// log RRSIG stats; very similar to log_bogus_rrsig()
kr_log_error(PREFILL, "failed to validate DNSKEY for `%s` "
"(%u matching RRSIGs, %u expired, %u not yet valid, "
"%u invalid signer, %u invalid label count, %u invalid key, "
"%u invalid crypto, %u invalid NSEC)\n",
zone_name_str,
err_ctx.rrs_counters.matching_name_type,
err_ctx.rrs_counters.expired, err_ctx.rrs_counters.notyet,
err_ctx.rrs_counters.signer_invalid,
err_ctx.rrs_counters.labels_invalid,
err_ctx.rrs_counters.key_invalid,
err_ctx.rrs_counters.crypto_invalid,
err_ctx.rrs_counters.nsec_invalid);
ret = kr_error(ENOENT);
goto fail;
}
//// Do all ZONEMD processing, if desired.
zonemd: (void)0; // C can't have a variable definition following a label
double time_zonemd = NAN;
if (c->zonemd) {
kr_timer_start(&stopwatch);
ret = zonemd_verify(z_import);
time_zonemd = kr_timer_elapsed(&stopwatch);
} else {
ret = kr_ok();
}
kr_log_info(PREFILL, "performance: parsing took %.3lf s, hashing took %.3lf s\n",
time_parse, time_zonemd);
if (ret) goto fail;
//// Phase two, after a pause. Validate and import all the remaining records.
ret = uv_timer_init(the_worker->loop, &z_import->timer);
if (ret) goto fail;
z_import->timer.data = z_import;
ret = uv_timer_start(&z_import->timer, zi_zone_process, ZONE_IMPORT_PAUSE, 0);
if (ret) goto fail;
return kr_ok();
fail:
if (z_import->cb)
z_import->cb(kr_error(ret), z_import->cb_param);
if (kr_fails_assert(ret))
ret = ENOENT;
ctx_delete(z_import);
return kr_error(ret);
}