/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include "daemon/network.h"
#include "daemon/worker.h"
#include "contrib/cleanup.h"
#include "daemon/bindings/impl.h"
#include "daemon/io.h"
#include "daemon/tls.h"
#include "daemon/worker.h"
#include "lib/utils.h"
#if ENABLE_XDP
#include <libknot/xdp/eth.h>
#endif
#include <libgen.h>
#include <net/if.h>
#include <sys/stat.h> /* mkdir() below needs this */
#include <sys/un.h>
#include <unistd.h>
/** Determines the type of `struct endpoint_key`. */
enum endpoint_key_type
{
ENDPOINT_KEY_SOCKADDR = 1,
ENDPOINT_KEY_IFNAME = 2,
};
/** Used as a key in the `struct network::endpoints` trie. */
struct endpoint_key {
enum endpoint_key_type type;
char data[];
};
struct __attribute__((packed)) endpoint_key_sockaddr {
enum endpoint_key_type type;
struct kr_sockaddr_key_storage sa_key;
};
struct __attribute__((packed)) endpoint_key_ifname {
enum endpoint_key_type type;
char ifname[128];
};
/** Used for reserving enough storage for `endpoint_key`. */
struct endpoint_key_storage {
union {
enum endpoint_key_type type;
struct endpoint_key_sockaddr sa;
struct endpoint_key_ifname ifname;
char bytes[1]; /* for easier casting */
};
};
static_assert(_Alignof(struct endpoint_key) <= 4, "endpoint_key must be aligned to <=4");
static_assert(_Alignof(struct endpoint_key_sockaddr) <= 4, "endpoint_key must be aligned to <=4");
static_assert(_Alignof(struct endpoint_key_ifname) <= 4, "endpoint_key must be aligned to <=4");
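/* Illustrative sketch (editor's addition, not in the original source):
* a listener on a plain address, e.g. 127.0.0.1#53, is keyed as
* ENDPOINT_KEY_SOCKADDR with the payload produced by kr_sockaddr_key(),
* while an XDP listener bound to an interface is keyed as
* ENDPOINT_KEY_IFNAME with the NUL-terminated name, e.g.:
*
*	struct endpoint_key_storage key = { 0 };
*	key.ifname.type = ENDPOINT_KEY_IFNAME;
*	strncpy(key.ifname.ifname, "eth0", sizeof(key.ifname.ifname) - 1);
*
* The static_asserts above ensure that `key.bytes` can be handed to the
* trie as a plain byte string without alignment trouble. */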
static struct network the_network_value = {0};
struct network *the_network = NULL;
void network_init(uv_loop_t *loop, int tcp_backlog)
{
the_network = &the_network_value;
the_network->loop = loop;
the_network->endpoints = trie_create(NULL);
the_network->endpoint_kinds = trie_create(NULL);
the_network->proxy_all4 = false;
the_network->proxy_all6 = false;
the_network->proxy_addrs4 = trie_create(NULL);
the_network->proxy_addrs6 = trie_create(NULL);
the_network->tls_client_params = NULL;
the_network->tls_session_ticket_ctx = /* unsync. random, by default */
tls_session_ticket_ctx_create(loop, NULL, 0);
the_network->tcp.in_idle_timeout = 10000;
the_network->tcp.tls_handshake_timeout = TLS_MAX_HANDSHAKE_TIME;
the_network->tcp.user_timeout = 1000; // 1s should be more than enough
the_network->tcp_backlog = tcp_backlog;
the_network->enable_connect_udp = true;
// On Linux, unset means some auto-tuning mechanism also depending on RAM,
// which might be OK default (together with the user_timeout above)
//the_network->listen_{tcp,udp}_buflens.{snd,rcv}
}
/** Notify the registered function about an endpoint being opened.
* `log_addr` (plus `ep->port` unless AF_UNIX) is used for logging. */
static int endpoint_open_lua_cb(struct endpoint *ep,
const char *log_addr)
{
const bool ok = ep->flags.kind && !ep->handle && !ep->engaged && ep->fd != -1;
if (kr_fails_assert(ok))
return kr_error(EINVAL);
/* First find callback in the endpoint registry. */
lua_State *L = the_engine->L;
void **pp = trie_get_try(the_network->endpoint_kinds, ep->flags.kind,
strlen(ep->flags.kind));
if (!pp && the_network->missing_kind_is_error) {
kr_log_error(NETWORK, "error: network socket kind '%s' not handled when opening '%s",
ep->flags.kind, log_addr);
if (ep->family != AF_UNIX)
kr_log_error(NETWORK, "#%d", ep->port);
kr_log_error(NETWORK, "'\n");
return kr_error(ENOENT);
}
if (!pp) return kr_ok();
/* Now execute the callback. */
const int fun_id = (intptr_t)*pp;
lua_rawgeti(L, LUA_REGISTRYINDEX, fun_id);
lua_pushboolean(L, true /* open */);
lua_pushpointer(L, ep);
if (ep->family == AF_UNIX) {
lua_pushstring(L, log_addr);
} else {
lua_pushfstring(L, "%s#%d", log_addr, ep->port);
}
if (lua_pcall(L, 3, 0, 0)) {
kr_log_error(NETWORK, "error opening %s: %s\n", log_addr, lua_tostring(L, -1));
return kr_error(ENOSYS); /* TODO: better value? */
}
ep->engaged = true;
return kr_ok();
}
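/* Sketch of the callback contract implied by the pushes above (editor's
* illustration; the Lua-side registration itself is defined elsewhere):
* the function stored under the kind is called with three arguments,
*
*	kind_cb(true, endpoint, "127.0.0.1#53")    -- non-AF_UNIX open
*	kind_cb(true, endpoint, "/run/kresd.sock") -- AF_UNIX open
*
* and endpoint_close_lua_cb() below invokes it with `false` on close. */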
static int engage_endpoint_array(const char *b_key, uint32_t key_len, trie_val_t *val, void *net)
{
const char *log_addr = network_endpoint_key_str((struct endpoint_key *) b_key);
if (!log_addr)
log_addr = "[unknown]";
endpoint_array_t *eps = *val;
for (int i = 0; i < eps->len; ++i) {
struct endpoint *ep = &eps->at[i];
const bool match = !ep->engaged && ep->flags.kind;
if (!match) continue;
int ret = endpoint_open_lua_cb(ep, log_addr);
if (ret) return ret;
}
return 0;
}
int network_engage_endpoints(void)
{
if (the_network->missing_kind_is_error)
return kr_ok(); /* maybe weird, but let's make it idempotent */
the_network->missing_kind_is_error = true;
int ret = trie_apply_with_key(the_network->endpoints, engage_endpoint_array, the_network);
if (ret) {
the_network->missing_kind_is_error = false; /* avoid the same errors when closing */
return ret;
}
return kr_ok();
}
const char *network_endpoint_key_str(const struct endpoint_key *key)
{
switch (key->type)
{
case ENDPOINT_KEY_SOCKADDR:;
const struct endpoint_key_sockaddr *sa_key =
(struct endpoint_key_sockaddr *) key;
struct sockaddr_storage sa_storage;
struct sockaddr *sa = kr_sockaddr_from_key(&sa_storage, (const char *) &sa_key->sa_key);
return kr_straddr(sa);
case ENDPOINT_KEY_IFNAME:;
const struct endpoint_key_ifname *if_key =
(struct endpoint_key_ifname *) key;
return if_key->ifname;
default:
kr_assert(false);
return NULL;
}
}
/** Notify the registered function about endpoint about to be closed. */
static void endpoint_close_lua_cb(struct endpoint *ep)
{
lua_State *L = the_engine->L;
void **pp = trie_get_try(the_network->endpoint_kinds, ep->flags.kind,
strlen(ep->flags.kind));
if (!pp && the_network->missing_kind_is_error) {
kr_log_error(NETWORK, "internal error: missing kind '%s' in endpoint registry\n",
ep->flags.kind);
return;
}
if (!pp) return;
const int fun_id = (intptr_t)*pp;
lua_rawgeti(L, LUA_REGISTRYINDEX, fun_id);
lua_pushboolean(L, false /* close */);
lua_pushpointer(L, ep);
lua_pushstring(L, "FIXME:endpoint-identifier");
if (lua_pcall(L, 3, 0, 0)) {
kr_log_error(NETWORK, "failed to close FIXME:endpoint-identifier: %s\n",
lua_tostring(L, -1));
}
}
static void endpoint_close(struct endpoint *ep, bool force)
{
const bool is_control = ep->flags.kind && strcmp(ep->flags.kind, "control") == 0;
const bool is_xdp = ep->family == AF_XDP;
if (ep->family == AF_UNIX) { /* The FS name would be left behind. */
/* Extract local address for this socket. */
struct sockaddr_un sa;
sa.sun_path[0] = '\0'; /*< probably only for lint:scan-build */
socklen_t addr_len = sizeof(sa);
if (getsockname(ep->fd, (struct sockaddr *)&sa, &addr_len)
|| unlink(sa.sun_path)) {
kr_log_error(NETWORK, "error (ignored) when closing unix socket (fd = %d): %s\n",
ep->fd, strerror(errno));
}
}
if (ep->flags.kind && !is_control && !is_xdp) {
kr_assert(!ep->handle);
/* Special lua-handled endpoint. */
if (ep->engaged) {
endpoint_close_lua_cb(ep);
}
if (ep->fd > 0) {
close(ep->fd); /* nothing to do with errors */
}
free_const(ep->flags.kind);
return;
}
free_const(ep->flags.kind); /* needed if (is_control) */
kr_require(ep->handle);
if (force) { /* Force close if event loop isn't running. */
if (ep->fd >= 0) {
close(ep->fd);
}
if (ep->handle) {
ep->handle->loop = NULL;
struct session2 *s = ep->handle->data;
if (s)
session2_close(s);
}
} else { /* Asynchronous close */
struct session2 *s = ep->handle->data;
session2_close(s);
}
}
/** Endpoint visitor (see @file trie.h) */
static int close_key(trie_val_t *val, void* net)
{
endpoint_array_t *ep_array = *val;
for (int i = 0; i < ep_array->len; ++i) {
endpoint_close(&ep_array->at[i], true);
}
return 0;
}
static int free_key(trie_val_t *val, void* ext)
{
endpoint_array_t *ep_array = *val;
array_clear(*ep_array);
free(ep_array);
return kr_ok();
}
int kind_unregister(trie_val_t *tv, void *L)
{
int fun_id = (intptr_t)*tv;
luaL_unref(L, LUA_REGISTRYINDEX, fun_id);
return 0;
}
void network_close_force(void)
{
if (the_network != NULL) {
trie_apply(the_network->endpoints, close_key, the_network);
trie_apply(the_network->endpoints, free_key, NULL);
trie_clear(the_network->endpoints);
}
}
/** Frees all the `struct net_proxy_data` in the specified trie. */
void network_proxy_free_addr_data(trie_t* trie)
{
trie_it_t *it;
for (it = trie_it_begin(trie); !trie_it_finished(it); trie_it_next(it)) {
struct net_proxy_data *data = *trie_it_val(it);
free(data);
}
trie_it_free(it);
}
void network_unregister(void)
{
network_close_force();
trie_apply(the_network->endpoint_kinds, kind_unregister, the_engine->L);
}
void network_deinit(void)
{
trie_free(the_network->endpoint_kinds);
trie_free(the_network->endpoints);
network_proxy_free_addr_data(the_network->proxy_addrs4);
trie_free(the_network->proxy_addrs4);
network_proxy_free_addr_data(the_network->proxy_addrs6);
trie_free(the_network->proxy_addrs6);
tls_credentials_free(the_network->tls_credentials);
tls_client_params_free(the_network->tls_client_params);
tls_session_ticket_ctx_destroy(the_network->tls_session_ticket_ctx);
#ifndef NDEBUG
memset(the_network, 0, sizeof(*the_network));
#endif
the_network = NULL;
}
/** Creates an endpoint key for use with a `trie_t` and stores it into `dst`.
* Returns the actual length of the generated key. */
static ssize_t endpoint_key_create(struct endpoint_key_storage *dst,
const char *addr_str,
const struct sockaddr *sa)
{
memset(dst, 0, sizeof(*dst));
if (sa) {
struct endpoint_key_sockaddr *key = &dst->sa;
key->type = ENDPOINT_KEY_SOCKADDR;
ssize_t keylen = kr_sockaddr_key(&key->sa_key, sa);
if (keylen < 0)
return keylen;
return sizeof(struct endpoint_key) + keylen;
} else {
struct endpoint_key_ifname *key = &dst->ifname;
key->type = ENDPOINT_KEY_IFNAME;
/* The subtractions and additions of 1 are here to account for
* null-terminators. */
strncpy(key->ifname, addr_str, sizeof(key->ifname) - 1);
return sizeof(struct endpoint_key) + strlen(key->ifname) + 1;
}
}
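/* Example of the resulting length (editor's sketch, assuming a 4-byte
* enum as checked by the static_asserts above): for an interface name,
* endpoint_key_create(&dst, "eth0", NULL) returns
* sizeof(struct endpoint_key) + strlen("eth0") + 1 == 4 + 5 == 9,
* while for a sockaddr the length is driven by kr_sockaddr_key(). */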
/** Fetch or create endpoint array and insert endpoint (shallow memcpy). */
static int insert_endpoint(const char *addr_str,
const struct sockaddr *addr, struct endpoint *ep)
{
struct endpoint_key_storage key;
ssize_t keylen = endpoint_key_create(&key, addr_str, addr);
if (keylen < 0)
return keylen;
trie_val_t *val = trie_get_ins(the_network->endpoints, key.bytes, keylen);
endpoint_array_t *ep_array;
if (*val) {
ep_array = *val;
} else {
ep_array = malloc(sizeof(*ep_array));
kr_require(ep_array);
array_init(*ep_array);
*val = ep_array;
}
if (array_reserve(*ep_array, ep_array->len + 1)) {
return kr_error(ENOMEM);
}
memcpy(&ep_array->at[ep_array->len++], ep, sizeof(*ep));
return kr_ok();
}
/** Open endpoint protocols. ep->flags were pre-set.
* \p addr_str is only used for logging or for XDP "address". */
static int open_endpoint(const char *addr_str,
struct endpoint *ep, const struct sockaddr *sa)
{
const bool is_control = ep->flags.kind && strcmp(ep->flags.kind, "control") == 0;
const bool is_xdp = ep->family == AF_XDP;
bool ok = (!is_xdp)
|| (sa == NULL && ep->fd == -1 && ep->nic_queue >= 0
&& ep->flags.sock_type == SOCK_DGRAM && !ep->flags.tls);
if (kr_fails_assert(ok))
return kr_error(EINVAL);
if (ep->handle) {
return kr_error(EEXIST);
}
if (sa && ep->fd == -1) {
if (sa->sa_family == AF_UNIX) {
struct sockaddr_un *sun = (struct sockaddr_un*)sa;
char *dirc = strdup(sun->sun_path);
char *dname = dirname(dirc);
(void)unlink(sun->sun_path); /** Attempt to unlink if socket path exists. */
(void)mkdir(dname, S_IRWXU|S_IRWXG); /** Attempt to create dir. */
free(dirc);
}
ep->fd = io_bind(sa, ep->flags.sock_type, &ep->flags);
if (ep->fd < 0) return ep->fd;
}
if (ep->flags.kind && !is_control && !is_xdp) {
/* This EP isn't to be managed internally after binding. */
return endpoint_open_lua_cb(ep, addr_str);
} else {
ep->engaged = true;
/* .engaged seems not really meaningful in this case, but... */
}
int ret;
if (is_control) {
uv_pipe_t *ep_handle = malloc(sizeof(uv_pipe_t));
ep->handle = (uv_handle_t *)ep_handle;
ret = !ep->handle ? ENOMEM
: io_listen_pipe(the_network->loop, ep_handle, ep->fd);
goto finish_ret;
}
if (ep->family == AF_UNIX) {
/* Some parts of connection handling would need more work,
* so let's support AF_UNIX only with .kind != NULL for now. */
kr_log_error(NETWORK, "AF_UNIX only supported with set { kind = '...' }\n");
ret = EAFNOSUPPORT;
goto finish_ret;
/*
uv_pipe_t *ep_handle = malloc(sizeof(uv_pipe_t));
*/
}
if (is_xdp) {
#if ENABLE_XDP
uv_poll_t *ep_handle = malloc(sizeof(uv_poll_t));
ep->handle = (uv_handle_t *)ep_handle;
ret = !ep->handle ? ENOMEM
: io_listen_xdp(the_network->loop, ep, addr_str);
#else
ret = ESOCKTNOSUPPORT;
#endif
goto finish_ret;
} /* else */
if (ep->flags.sock_type == SOCK_DGRAM) {
if (kr_fails_assert(!ep->flags.tls))
return kr_error(EINVAL);
uv_udp_t *ep_handle = malloc(sizeof(uv_udp_t));
ep->handle = (uv_handle_t *)ep_handle;
ret = !ep->handle ? ENOMEM
: io_listen_udp(the_network->loop, ep_handle, ep->fd);
goto finish_ret;
} /* else */
if (ep->flags.sock_type == SOCK_STREAM) {
uv_tcp_t *ep_handle = malloc(sizeof(uv_tcp_t));
ep->handle = (uv_handle_t *)ep_handle;
ret = !ep->handle ? ENOMEM
: io_listen_tcp(the_network->loop, ep_handle, ep->fd,
the_network->tcp_backlog, ep->flags.tls, ep->flags.http);
goto finish_ret;
} /* else */
kr_assert(false);
return kr_error(EINVAL);
finish_ret:
if (!ret) return ret;
free(ep->handle);
ep->handle = NULL;
return kr_error(ret);
}
/** @internal Fetch a pointer to endpoint of given parameters (or NULL).
* Beware that there might be multiple matches, though that's not common.
* The matching isn't really precise in the sense that it might not find
* an endpoint that would *collide* with the passed one. */
static struct endpoint * endpoint_get(const char *addr_str,
const struct sockaddr *sa,
endpoint_flags_t flags)
{
struct endpoint_key_storage key;
ssize_t keylen = endpoint_key_create(&key, addr_str, sa);
if (keylen < 0)
return NULL;
trie_val_t *val = trie_get_try(the_network->endpoints, key.bytes, keylen);
if (!val)
return NULL;
endpoint_array_t *ep_array = *val;
uint16_t port = kr_inaddr_port(sa);
for (int i = 0; i < ep_array->len; ++i) {
struct endpoint *ep = &ep_array->at[i];
if ((flags.xdp || ep->port == port) && endpoint_flags_eq(ep->flags, flags)) {
return ep;
}
}
return NULL;
}
/** \note pass (either sa != NULL xor ep.fd != -1) or XDP case (neither sa nor ep.fd)
* \note in XDP case addr_str is interface name
* \note ownership of ep.flags.* is taken on success. */
static int create_endpoint(const char *addr_str,
struct endpoint *ep, const struct sockaddr *sa)
{
int ret = open_endpoint(addr_str, ep, sa);
if (ret == 0) {
ret = insert_endpoint(addr_str, sa, ep);
}
if (ret != 0 && ep->handle) {
endpoint_close(ep, false);
}
return ret;
}
int network_listen_fd(int fd, endpoint_flags_t flags)
{
if (kr_fails_assert(!flags.xdp))
return kr_error(EINVAL);
/* Extract fd's socket type. */
socklen_t len = sizeof(flags.sock_type);
int ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &flags.sock_type, &len);
if (ret != 0)
return kr_error(errno);
const bool is_dtls = flags.sock_type == SOCK_DGRAM && !flags.kind && flags.tls;
if (kr_fails_assert(!is_dtls))
return kr_error(EINVAL); /* Perhaps DTLS some day. */
if (flags.sock_type != SOCK_DGRAM && flags.sock_type != SOCK_STREAM)
return kr_error(EBADF);
/* Extract local address for this socket. */
struct sockaddr_storage ss = { .ss_family = AF_UNSPEC };
socklen_t addr_len = sizeof(ss);
ret = getsockname(fd, (struct sockaddr *)&ss, &addr_len);
if (ret != 0)
return kr_error(errno);
struct endpoint ep = {
.flags = flags,
.family = ss.ss_family,
.fd = fd,
};
/* Extract address string and port. */
char addr_buf[INET6_ADDRSTRLEN]; /* https://tools.ietf.org/html/rfc4291 */
const char *addr_str;
switch (ep.family) {
case AF_INET:
ret = uv_ip4_name((const struct sockaddr_in*)&ss, addr_buf, sizeof(addr_buf));
addr_str = addr_buf;
ep.port = ntohs(((struct sockaddr_in *)&ss)->sin_port);
break;
case AF_INET6:
ret = uv_ip6_name((const struct sockaddr_in6*)&ss, addr_buf, sizeof(addr_buf));
addr_str = addr_buf;
ep.port = ntohs(((struct sockaddr_in6 *)&ss)->sin6_port);
break;
case AF_UNIX:
/* No SOCK_DGRAM with AF_UNIX support, at least for now. */
ret = flags.sock_type == SOCK_STREAM ? kr_ok() : kr_error(EAFNOSUPPORT);
addr_str = ((struct sockaddr_un *)&ss)->sun_path;
break;
default:
ret = kr_error(EAFNOSUPPORT);
}
if (ret) return ret;
/* always create endpoint for supervisor supplied fd
* even if addr+port is not unique */
return create_endpoint(addr_str, &ep, (struct sockaddr *) &ss);
}
/** Try selecting XDP queue automatically. */
static int16_t nic_queue_auto(void)
{
const char *inst_str = getenv("SYSTEMD_INSTANCE");
if (!inst_str)
return 0; // should work OK for simple (single-kresd) deployments
char *endp;
errno = 0; // strtol() is special in this respect
long inst = strtol(inst_str, &endp, 10);
if (!errno && *endp == '\0' && inst > 0 && inst < UINT16_MAX)
return inst - 1; // 1-based vs. 0-based indexing conventions
return -1;
}
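/* Behaviour sketch (derived from the code above): with SYSTEMD_INSTANCE
* unset, queue 0 is used; SYSTEMD_INSTANCE="3" (i.e. kresd@3) maps to
* queue 2 because of the 1-based vs. 0-based conventions; a non-numeric
* or out-of-range value yields -1, so automatic selection fails. */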
int network_listen(const char *addr, uint16_t port,
int16_t nic_queue, endpoint_flags_t flags)
{
if (kr_fails_assert(the_network != NULL && addr != 0 && nic_queue >= -1))
return kr_error(EINVAL);
if (flags.xdp && nic_queue < 0) {
nic_queue = nic_queue_auto();
if (nic_queue < 0) {
return kr_error(EINVAL);
}
}
// Try parsing the address.
const struct sockaddr *sa = kr_straddr_socket(addr, port, NULL);
if (!sa && !flags.xdp) { // unusable address spec
return kr_error(EINVAL);
}
char ifname_buf[64] UNUSED;
if (sa && flags.xdp) { // auto-detection: address -> interface
#if ENABLE_XDP
int ret = knot_eth_name_from_addr((const struct sockaddr_storage *)sa,
ifname_buf, sizeof(ifname_buf));
// even on success we don't want to pass `sa` on
free_const(sa);
sa = NULL;
if (ret) {
return kr_error(ret);
}
addr = ifname_buf;
#else
return kr_error(ESOCKTNOSUPPORT);
#endif
}
// XDP: if addr failed to parse as address, we assume it's an interface name.
if (endpoint_get(addr, sa, flags)) {
return kr_error(EADDRINUSE); // Already listening
}
struct endpoint ep = { 0 };
ep.flags = flags;
ep.fd = -1;
ep.port = port;
ep.family = flags.xdp ? AF_XDP : sa->sa_family;
ep.nic_queue = nic_queue;
int ret = create_endpoint(addr, &ep, sa);
// Error reporting: more precision.
if (ret == KNOT_EINVAL && !sa && flags.xdp && ENABLE_XDP) {
if (!if_nametoindex(addr) && errno == ENODEV) {
ret = kr_error(ENODEV);
}
}
free_const(sa);
return ret;
}
int network_proxy_allow(const char* addr)
{
if (kr_fails_assert(the_network != NULL && addr != NULL))
return kr_error(EINVAL);
int family = kr_straddr_family(addr);
if (family < 0) {
kr_log_error(NETWORK, "Wrong address format for proxy_allowed: %s\n",
addr);
return kr_error(EINVAL);
} else if (family == AF_UNIX) {
kr_log_error(NETWORK, "Unix sockets not supported for proxy_allowed: %s\n",
addr);
return kr_error(EINVAL);
}
union kr_in_addr ia;
int netmask = kr_straddr_subnet(&ia, addr);
if (netmask < 0) {
kr_log_error(NETWORK, "Wrong netmask format for proxy_allowed: %s\n", addr);
return kr_error(EINVAL);
} else if (netmask == 0) {
/* Netmask is zero: allow all addresses to use PROXYv2 */
switch (family) {
case AF_INET:
the_network->proxy_all4 = true;
break;
case AF_INET6:
the_network->proxy_all6 = true;
break;
default:
kr_assert(false);
return kr_error(EINVAL);
}
return kr_ok();
}
size_t addr_length;
trie_t *trie;
switch (family) {
case AF_INET:
addr_length = sizeof(ia.ip4);
trie = the_network->proxy_addrs4;
break;
case AF_INET6:
addr_length = sizeof(ia.ip6);
trie = the_network->proxy_addrs6;
break;
default:
kr_assert(false);
return kr_error(EINVAL);
}
kr_bitmask((unsigned char *) &ia, addr_length, netmask);
trie_val_t *val = trie_get_ins(trie, (char *) &ia, addr_length);
if (!val)
return kr_error(ENOMEM);
struct net_proxy_data *data = *val;
if (!data) {
/* Allocate data if the entry is new in the trie */
*val = malloc(sizeof(struct net_proxy_data));
data = *val;
data->netmask = 0;
}
if (data->netmask == 0) {
memcpy(&data->addr, &ia, addr_length);
data->netmask = netmask;
} else if (data->netmask > netmask) {
/* A more relaxed netmask configured - replace it */
data->netmask = netmask;
}
return kr_ok();
}
void network_proxy_reset(void)
{
the_network->proxy_all4 = false;
network_proxy_free_addr_data(the_network->proxy_addrs4);
trie_clear(the_network->proxy_addrs4);
the_network->proxy_all6 = false;
network_proxy_free_addr_data(the_network->proxy_addrs6);
trie_clear(the_network->proxy_addrs6);
}
static int endpoints_close(struct endpoint_key_storage *key, ssize_t keylen,
endpoint_array_t *ep_array, int port)
{
size_t i = 0;
bool matched = false; /*< at least one match */
while (i < ep_array->len) {
struct endpoint *ep = &ep_array->at[i];
if (port < 0 || ep->port == port) {
endpoint_close(ep, false);
array_del(*ep_array, i);
matched = true;
/* do not advance i */
} else {
++i;
}
}
if (!matched) {
return kr_error(ENOENT);
}
return kr_ok();
}
static bool endpoint_key_addr_matches(struct endpoint_key_storage *key_a,
struct endpoint_key_storage *key_b)
{
if (key_a->type != key_b->type)
return false;
if (key_a->type == ENDPOINT_KEY_IFNAME)
return strncmp(key_a->ifname.ifname,
key_b->ifname.ifname,
sizeof(key_a->ifname.ifname)) == 0;
if (key_a->type == ENDPOINT_KEY_SOCKADDR) {
return kr_sockaddr_key_same_addr(
key_a->sa.sa_key.bytes, key_b->sa.sa_key.bytes);
}
kr_assert(false);
return false;
}
struct endpoint_key_with_len {
struct endpoint_key_storage key;
size_t keylen;
};
typedef array_t(struct endpoint_key_with_len) endpoint_key_array_t;
struct endpoint_close_wildcard_context {
struct endpoint_key_storage *match_key;
endpoint_key_array_t del;
int ret;
};
static int endpoints_close_wildcard(const char *s_key, uint32_t keylen, trie_val_t *val, void *baton)
{
struct endpoint_close_wildcard_context *ctx = baton;
struct endpoint_key_storage *key = (struct endpoint_key_storage *)s_key;
if (!endpoint_key_addr_matches(key, ctx->match_key))
return kr_ok();
endpoint_array_t *ep_array = *val;
int ret = endpoints_close(key, keylen, ep_array, -1);
if (ret)
ctx->ret = ret;
if (ep_array->len == 0) {
struct endpoint_key_with_len to_del = {
.key = *key,
.keylen = keylen
};
array_push(ctx->del, to_del);
}
return kr_ok();
}
int network_close(const char *addr_str, int port)
{
auto_free struct sockaddr *addr = kr_straddr_socket(addr_str, port, NULL);
struct endpoint_key_storage key;
ssize_t keylen = endpoint_key_create(&key, addr_str, addr);
if (keylen < 0)
return keylen;
if (port < 0) {
struct endpoint_close_wildcard_context ctx = {
.match_key = &key
};
array_init(ctx.del);
trie_apply_with_key(the_network->endpoints,
endpoints_close_wildcard, &ctx);
for (size_t i = 0; i < ctx.del.len; i++) {
trie_val_t val;
trie_del(the_network->endpoints,
ctx.del.at[i].key.bytes, ctx.del.at[i].keylen,
&val);
if (val) {
array_clear(*(endpoint_array_t *) val);
free(val);
}
}
return ctx.ret;
}
trie_val_t *val = trie_get_try(the_network->endpoints, key.bytes, keylen);
if (!val)
return kr_error(ENOENT);
endpoint_array_t *ep_array = *val;
int ret = endpoints_close(&key, keylen, ep_array, port);
/* Collapse key if it has no endpoint. */
if (ep_array->len == 0) {
array_clear(*ep_array);
free(ep_array);
trie_del(the_network->endpoints, key.bytes, keylen, NULL);
}
return ret;
}
void network_new_hostname(void)
{
if (the_network->tls_credentials &&
the_network->tls_credentials->ephemeral_servicename) {
struct tls_credentials *newcreds;
newcreds = tls_get_ephemeral_credentials();
if (newcreds) {
tls_credentials_release(the_network->tls_credentials);
the_network->tls_credentials = newcreds;
kr_log_info(TLS, "Updated ephemeral X.509 cert with new hostname\n");
} else {
kr_log_error(TLS, "Failed to update ephemeral X.509 cert with new hostname, using existing one\n");
}
}
}
#ifdef SO_ATTACH_BPF
static int set_bpf_cb(trie_val_t *val, void *ctx)
{
endpoint_array_t *endpoints = *val;
int *bpffd = (int *)ctx;
if (kr_fails_assert(endpoints && bpffd))
return kr_error(EINVAL);
for (size_t i = 0; i < endpoints->len; i++) {
struct endpoint *endpoint = &endpoints->at[i];
uv_os_fd_t sockfd = -1;
if (endpoint->handle != NULL)
uv_fileno(endpoint->handle, &sockfd);
kr_require(sockfd != -1);
if (setsockopt(sockfd, SOL_SOCKET, SO_ATTACH_BPF, bpffd, sizeof(int)) != 0) {
return 1; /* return error (and stop iterating over net->endpoints) */
}
}
return 0; /* OK */
}
#endif
int network_set_bpf(int bpf_fd)
{
#ifdef SO_ATTACH_BPF
if (trie_apply(the_network->endpoints, set_bpf_cb, &bpf_fd) != 0) {
/* set_bpf_cb() has returned error. */
network_clear_bpf();
return 0;
}
#else
kr_log_error(NETWORK, "SO_ATTACH_BPF socket option doesn't supported\n");
(void)bpf_fd;
return 0;
#endif
return 1;
}
#ifdef SO_DETACH_BPF
static int clear_bpf_cb(trie_val_t *val, void *ctx)
{
endpoint_array_t *endpoints = *val;
if (kr_fails_assert(endpoints))
return kr_error(EINVAL);
for (size_t i = 0; i < endpoints->len; i++) {
struct endpoint *endpoint = &endpoints->at[i];
uv_os_fd_t sockfd = -1;
if (endpoint->handle != NULL)
uv_fileno(endpoint->handle, &sockfd);
kr_require(sockfd != -1);
if (setsockopt(sockfd, SOL_SOCKET, SO_DETACH_BPF, NULL, 0) != 0) {
kr_log_error(NETWORK, "failed to clear SO_DETACH_BPF socket option\n");
}
/* Proceed even if setsockopt() failed,
* as we want to process all opened sockets. */
}
return 0;
}
#endif
void network_clear_bpf(void)
{
#ifdef SO_DETACH_BPF
trie_apply(the_network->endpoints, clear_bpf_cb, NULL);
#else
kr_log_error(NETWORK, "SO_DETACH_BPF socket option doesn't supported\n");
#endif
}
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#pragma once
#include "daemon/tls.h"
#include "lib/generic/array.h"
#include "lib/generic/trie.h"
#include <uv.h>
#include <stdbool.h>
#include "lib/generic/array.h"
#include "lib/generic/map.h"
#include <strings.h> /* strcasecmp() in endpoint_flags_eq() */
#include <sys/socket.h>
#ifndef AF_XDP
#define AF_XDP 44
#endif
/** Ways to listen on a socket (which may exist already). */
typedef struct {
int sock_type; /**< SOCK_DGRAM or SOCK_STREAM */
bool tls; /**< only used together with .kind == NULL and SOCK_STREAM */
bool http; /**< DoH2, implies .tls (in current implementation) */
bool xdp; /**< XDP is special (not a normal socket, in particular) */
bool freebind; /**< used for binding to non-local address */
const char *kind; /**< tag for other types: "control" or module-handled kinds */
} endpoint_flags_t;
struct endpoint_key;
static inline bool endpoint_flags_eq(endpoint_flags_t f1, endpoint_flags_t f2)
{
if (f1.sock_type != f2.sock_type)
return false;
if (f1.kind && f2.kind)
return strcasecmp(f1.kind, f2.kind) == 0;
else
return f1.tls == f2.tls && f1.kind == f2.kind;
}
/** Wrapper for a single socket to listen on.
* There are two types: normal have handle, special have flags.kind (and never both).
*
* LATER: .family might be unexpected for IPv4-in-IPv6 addresses.
* ATM AF_UNIX is only supported with flags.kind != NULL
*/
struct endpoint {
/** uv_{udp,tcp,poll}_t (poll for XDP);
* NULL in case of endpoints that are to be handled by modules. */
uv_handle_t *handle;
int fd; /**< POSIX file-descriptor; always used. */
int family; /**< AF_INET or AF_INET6 or AF_UNIX or AF_XDP */
uint16_t port; /**< TCP/UDP port. Meaningless with AF_UNIX. */
int16_t nic_queue; /**< -1 or queue number of the interface for AF_XDP use. */
bool engaged; /**< to some module or internally */
endpoint_flags_t flags;
};
/** @cond internal Array of endpoints */
typedef array_t(struct endpoint) endpoint_array_t;
/* @endcond */
struct net_tcp_param {
uint64_t in_idle_timeout;
uint64_t tls_handshake_timeout;
/** Milliseconds of unacknowledged data; see TCP_USER_TIMEOUT in man tcp.7
* Linux only, probably. */
unsigned int user_timeout;
};
/** Information about an address that is allowed to use PROXYv2. */
struct net_proxy_data {
union kr_in_addr addr;
uint8_t netmask; /**< Number of bits to be matched */
};
struct network {
uv_loop_t *loop;
/** Map: address string -> endpoint_array_t.
* \note even same address-port-flags tuples may appear. */
trie_t *endpoints;
/** Registry of callbacks for special endpoint kinds (for opening/closing).
* Map: kind (lowercased) -> lua function ID converted to void *
* The ID is the usual: raw int index in the LUA_REGISTRYINDEX table. */
trie_t *endpoint_kinds;
/** See network_engage_endpoints() */
bool missing_kind_is_error : 1;
/** True: All IPv4 addresses are allowed to use the PROXYv2 protocol */
bool proxy_all4 : 1;
/** True: All IPv6 addresses are allowed to use the PROXYv2 protocol */
bool proxy_all6 : 1;
/** IPv4 addresses and networks allowed to use the PROXYv2 protocol */
trie_t *proxy_addrs4;
/** IPv6 addresses and networks allowed to use the PROXYv2 protocol */
trie_t *proxy_addrs6;
struct tls_credentials *tls_credentials;
tls_client_params_t *tls_client_params; /**< Use tls_client_params_*() functions. */
struct tls_session_ticket_ctx *tls_session_ticket_ctx;
struct net_tcp_param tcp;
int tcp_backlog;
/** Kernel-side buffer sizes for sending and receiving. (in bytes)
* They are per socket, so in the TCP case they are per connection.
* See SO_SNDBUF and SO_RCVBUF in man socket.7 These are in POSIX. */
struct {
int snd, rcv;
} listen_udp_buflens, listen_tcp_buflens;
/** Use uv_udp_connect as the transport method for UDP.
* Enabling this increases the total number of syscalls, with a variable
* impact on the time spent processing them, sometimes resulting in
* a slight improvement in syscall processing efficiency.
* Note: This does not necessarily lead to overall performance gains. */
bool enable_connect_udp;
};
/** Pointer to the singleton network state. NULL if not initialized. */
KR_EXPORT extern struct network *the_network;
/** Initializes the network. */
void network_init(uv_loop_t *loop, int tcp_backlog);
/** Unregisters endpoints. Should be called before `network_deinit`
* and `engine_deinit`. */
void network_unregister(void);
/** Deinitializes the network. `network_unregister` should be called before
* this and before `engine_deinit`. */
void network_deinit(void);
/** Start listening on addr#port with flags.
* \note if we did listen on that combination already,
* nothing is done and kr_error(EADDRINUSE) is returned.
* \note there's no short-hand to listen both on UDP and TCP.
* \note ownership of flags.* is taken on success. TODO: non-success?
* \param nic_queue == -1 for auto-selection or non-XDP.
* \note In XDP mode, addr may also be an interface name, so kr_error(ENODEV)
* is returned if an unusable value is passed
*/
int network_listen(const char *addr, uint16_t port,
int16_t nic_queue, endpoint_flags_t flags);
/** Allow the specified address to send the PROXYv2 header.
* \note the address may be specified with a netmask
*/
int network_proxy_allow(const char* addr);
/** Reset all addresses allowed to send the PROXYv2 header. No addresses will
* be allowed to send PROXYv2 headers from the point of calling this function
* until re-allowed via network_proxy_allow again.
*/
void network_proxy_reset(void);
/** Start listening on an open file-descriptor.
* \note flags.sock_type isn't meaningful here.
* \note ownership of flags.* is taken on success. TODO: non-success?
*/
int network_listen_fd(int fd, endpoint_flags_t flags);
/** Stop listening on all endpoints with matching addr#port.
* port < 0 serves as a wild-card.
* \return kr_error(ENOENT) if nothing matched. */
int network_close(const char *addr, int port);
/** Close all endpoints immediately (no waiting for UV loop). */
void network_close_force(void);
/** Enforce that all endpoints are registered from now on.
* This only does anything with struct endpoint::flags.kind != NULL. */
int network_engage_endpoints(void);
/** Returns a string representation of the specified endpoint key.
*
* The result points into key or is on static storage like for kr_straddr() */
const char *network_endpoint_key_str(const struct endpoint_key *key);
int network_set_tls_cert(const char *cert);
int network_set_tls_key(const char *key);
void network_new_hostname(void);
int network_set_bpf(int bpf_fd);
void network_clear_bpf(void);
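/* Typical lifecycle of this API (editor's sketch; the flag values are
* illustrative, not the only valid combinations):
*
*	network_init(uv_default_loop(), 128);
*	endpoint_flags_t udp = { .sock_type = SOCK_DGRAM };
*	endpoint_flags_t dot = { .sock_type = SOCK_STREAM, .tls = true };
*	network_listen("127.0.0.1",  53, -1, udp);  // plain DNS over UDP
*	network_listen("127.0.0.1", 853, -1, dot);  // DNS-over-TLS
*	// ... the uv loop runs ...
*	network_unregister();
*	network_deinit();
*/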
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include "daemon/network.h"
#include "daemon/session2.h"
#include "daemon/worker.h"
#include "lib/generic/trie.h"
#include "daemon/proxyv2.h"
static const char PROXY2_SIGNATURE[12] = {
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A
};
#define PROXY2_MIN_SIZE 16
#define PROXY2_IP6_ADDR_SIZE 16
#define PROXY2_UNIX_ADDR_SIZE 108
#define TLV_TYPE_SSL 0x20
enum proxy2_family {
PROXY2_AF_UNSPEC = 0x0,
PROXY2_AF_INET = 0x1,
PROXY2_AF_INET6 = 0x2,
PROXY2_AF_UNIX = 0x3
};
enum proxy2_protocol {
PROXY2_PROTOCOL_UNSPEC = 0x0,
PROXY2_PROTOCOL_STREAM = 0x1,
PROXY2_PROTOCOL_DGRAM = 0x2
};
/** PROXYv2 protocol header section */
struct proxy2_header {
uint8_t signature[sizeof(PROXY2_SIGNATURE)];
uint8_t version_command;
uint8_t family_protocol;
uint16_t length; /**< Length of the address section */
};
/** PROXYv2 additional information in Type-Length-Value (TLV) format. */
struct proxy2_tlv {
uint8_t type;
uint8_t length_hi;
uint8_t length_lo;
uint8_t value[];
};
/** PROXYv2 protocol address section */
union proxy2_address {
struct {
uint32_t src_addr;
uint32_t dst_addr;
uint16_t src_port;
uint16_t dst_port;
} ipv4_addr;
struct {
uint8_t src_addr[PROXY2_IP6_ADDR_SIZE];
uint8_t dst_addr[PROXY2_IP6_ADDR_SIZE];
uint16_t src_port;
uint16_t dst_port;
} ipv6_addr;
struct {
uint8_t src_addr[PROXY2_UNIX_ADDR_SIZE];
uint8_t dst_addr[PROXY2_UNIX_ADDR_SIZE];
} unix_addr;
};
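/* Wire-format illustration (editor's sketch): a minimal IPv4/TCP PROXY
* header is the 12-byte signature, then 0x21 (version 2 in the high
* nibble, PROXY command in the low one), then 0x11 (PROXY2_AF_INET and
* PROXY2_PROTOCOL_STREAM), then the 16-bit network-order length of the
* address section (12 bytes for ipv4_addr), followed by the source and
* destination addresses and ports. The accessors below extract exactly
* these nibbles and fields. */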
/** Gets protocol version from the specified PROXYv2 header. */
static inline unsigned char proxy2_header_version(const struct proxy2_header* h)
{
return (h->version_command & 0xF0) >> 4;
}
/** Gets command from the specified PROXYv2 header. */
static inline enum proxy2_command proxy2_header_command(const struct proxy2_header *h)
{
return h->version_command & 0x0F;
}
/** Gets address family from the specified PROXYv2 header. */
static inline enum proxy2_family proxy2_header_family(const struct proxy2_header *h)
{
return (h->family_protocol & 0xF0) >> 4;
}
/** Gets transport protocol from the specified PROXYv2 header. */
static inline enum proxy2_protocol proxy2_header_protocol(const struct proxy2_header *h)
{
return h->family_protocol & 0x0F;
}
static inline union proxy2_address *proxy2_get_address(const struct proxy2_header *h)
{
return (union proxy2_address *)((uint8_t *)h + sizeof(struct proxy2_header));
}
static inline struct proxy2_tlv *get_tlvs(const struct proxy2_header *h, size_t addr_len)
{
return (struct proxy2_tlv *)((uint8_t *)proxy2_get_address(h) + addr_len);
}
/** Gets the length of the TLV's `value` attribute. */
static inline uint16_t proxy2_tlv_length(const struct proxy2_tlv *tlv)
{
return ((uint16_t) tlv->length_hi << 8) | tlv->length_lo;
}
static inline bool has_tlv(const struct proxy2_header *h,
const struct proxy2_tlv *tlv)
{
uint64_t addr_length = ntohs(h->length);
ptrdiff_t hdr_len = sizeof(struct proxy2_header) + addr_length;
uint8_t *tlv_hdr_end = (uint8_t *)tlv + sizeof(struct proxy2_tlv);
ptrdiff_t distance = tlv_hdr_end - (uint8_t *)h;
if (hdr_len < distance)
return false;
uint8_t *tlv_end = tlv_hdr_end + proxy2_tlv_length(tlv);
distance = tlv_end - (uint8_t *)h;
return hdr_len >= distance;
}
static inline void next_tlv(struct proxy2_tlv **tlv)
{
uint8_t *next = ((uint8_t *)*tlv + sizeof(struct proxy2_tlv) + proxy2_tlv_length(*tlv));
*tlv = (struct proxy2_tlv *)next;
}
bool proxy_allowed(const struct sockaddr *saddr)
{
union kr_in_addr addr;
trie_t *trie;
size_t addr_size;
switch (saddr->sa_family) {
case AF_INET:
if (the_network->proxy_all4)
return true;
trie = the_network->proxy_addrs4;
addr_size = sizeof(addr.ip4);
addr.ip4 = ((struct sockaddr_in *)saddr)->sin_addr;
break;
case AF_INET6:
if (the_network->proxy_all6)
return true;
trie = the_network->proxy_addrs6;
addr_size = sizeof(addr.ip6);
addr.ip6 = ((struct sockaddr_in6 *)saddr)->sin6_addr;
break;
default:
kr_assert(false); // Only IPv4 and IPv6 proxy addresses supported
return false;
}
trie_val_t *val;
int ret = trie_get_leq(trie, (char *)&addr, addr_size, &val);
if (ret != kr_ok() && ret != 1)
return false;
kr_assert(val);
const struct net_proxy_data *found = *val;
kr_assert(found);
return kr_bitcmp((char *)&addr, (char *)&found->addr, found->netmask) == 0;
}
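/* Example (editor's sketch): after network_proxy_allow("192.0.2.0/24"),
* the entry is stored under the masked key 192.0.2.0 with netmask 24, so
* for a connection from 192.0.2.5 trie_get_leq() finds that entry and
* kr_bitcmp() confirms the first 24 bits match; a peer like 198.51.100.1
* fails the comparison and is rejected unless proxy_all4 was set via a
* zero-length netmask. */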
/** Parses the PROXYv2 header from buf of size nread and writes the result into
* out. The function assumes that the PROXYv2 signature is present
* and has been already checked by the caller (like `udp_recv` or `tcp_recv`). */
static ssize_t proxy_process_header(struct proxy_result *out,
const void *buf, const ssize_t nread)
{
if (!buf)
return kr_error(EINVAL);
const struct proxy2_header *hdr = (struct proxy2_header *)buf;
uint64_t content_length = ntohs(hdr->length);
ssize_t hdr_len = sizeof(struct proxy2_header) + content_length;
/* PROXYv2 requires the header to be received all at once */
if (nread < hdr_len) {
return kr_error(KNOT_EMALF);
}
unsigned char version = proxy2_header_version(hdr);
if (version != 2) {
/* Version MUST be 2 for PROXYv2 protocol */
return kr_error(KNOT_EMALF);
}
enum proxy2_command command = proxy2_header_command(hdr);
if (command == PROXY2_CMD_LOCAL) {
/* Addresses for LOCAL are to be discarded */
*out = (struct proxy_result){ .command = PROXY2_CMD_LOCAL };
goto fill_wirebuf;
}
if (command != PROXY2_CMD_PROXY) {
/* PROXYv2 prohibits values other than LOCAL and PROXY */
return kr_error(KNOT_EMALF);
}
*out = (struct proxy_result){ .command = PROXY2_CMD_PROXY };
/* Parse flags */
enum proxy2_family family = proxy2_header_family(hdr);
switch(family) {
case PROXY2_AF_UNSPEC:
case PROXY2_AF_UNIX:
/* UNIX is unsupported, fall back to UNSPEC */
out->family = AF_UNSPEC;
break;
case PROXY2_AF_INET:
out->family = AF_INET;
break;
case PROXY2_AF_INET6:
out->family = AF_INET6;
break;
default:
/* PROXYv2 prohibits other values */
return kr_error(KNOT_EMALF);
}
enum proxy2_protocol protocol = proxy2_header_protocol(hdr);
switch (protocol) {
case PROXY2_PROTOCOL_DGRAM:
out->protocol = SOCK_DGRAM;
break;
case PROXY2_PROTOCOL_STREAM:
out->protocol = SOCK_STREAM;
break;
default:
/* PROXYv2 prohibits other values */
return kr_error(KNOT_EMALF);
}
/* Parse addresses */
union proxy2_address* addr = proxy2_get_address(hdr);
size_t addr_length = 0;
switch(out->family) {
case AF_INET:
addr_length = sizeof(addr->ipv4_addr);
if (content_length < addr_length)
return kr_error(KNOT_EMALF);
out->src_addr.ip4 = (struct sockaddr_in){
.sin_family = AF_INET,
.sin_addr = { .s_addr = addr->ipv4_addr.src_addr },
.sin_port = addr->ipv4_addr.src_port,
};
out->dst_addr.ip4 = (struct sockaddr_in){
.sin_family = AF_INET,
.sin_addr = { .s_addr = addr->ipv4_addr.dst_addr },
.sin_port = addr->ipv4_addr.dst_port,
};
break;
case AF_INET6:
addr_length = sizeof(addr->ipv6_addr);
if (content_length < addr_length)
return kr_error(KNOT_EMALF);
out->src_addr.ip6 = (struct sockaddr_in6){
.sin6_family = AF_INET6,
.sin6_port = addr->ipv6_addr.src_port
};
memcpy(
&out->src_addr.ip6.sin6_addr.s6_addr,
&addr->ipv6_addr.src_addr,
sizeof(out->src_addr.ip6.sin6_addr.s6_addr));
out->dst_addr.ip6 = (struct sockaddr_in6){
.sin6_family = AF_INET6,
.sin6_port = addr->ipv6_addr.dst_port
};
memcpy(
&out->dst_addr.ip6.sin6_addr.s6_addr,
&addr->ipv6_addr.dst_addr,
sizeof(out->dst_addr.ip6.sin6_addr.s6_addr));
break;
default:; /* Keep zero from initializer. */
}
/* Process additional information */
for (struct proxy2_tlv *tlv = get_tlvs(hdr, addr_length); has_tlv(hdr, tlv); next_tlv(&tlv)) {
switch (tlv->type) {
case TLV_TYPE_SSL:
out->has_tls = true;
break;
default:; /* Ignore others - add more if needed */
}
}
fill_wirebuf:
return hdr_len;
}
/** Checks for a PROXY protocol version 2 signature in the specified buffer. */
static inline bool proxy_header_present(const void* buf, const ssize_t nread)
{
return nread >= PROXY2_MIN_SIZE &&
memcmp(buf, PROXY2_SIGNATURE, sizeof(PROXY2_SIGNATURE)) == 0;
}
struct pl_proxyv2_state {
struct protolayer_data h;
/** Storage for data parsed from PROXY header. */
struct proxy_result proxy;
/** Stream/TCP: Some data has already arrived and we are not expecting
* PROXY header anymore. */
bool had_data : 1;
};
static enum protolayer_iter_cb_result pl_proxyv2_dgram_unwrap(
void *sess_data, void *iter_data, struct protolayer_iter_ctx *ctx)
{
ctx->payload = protolayer_payload_as_buffer(&ctx->payload);
if (kr_fails_assert(ctx->payload.type == PROTOLAYER_PAYLOAD_BUFFER)) {
/* unsupported payload */
return protolayer_break(ctx, kr_error(EINVAL));
}
struct session2 *s = ctx->session;
struct pl_proxyv2_state *proxy_state = iter_data;
char *data = ctx->payload.buffer.buf;
ssize_t data_len = ctx->payload.buffer.len;
struct comm_info *comm = ctx->comm;
if (!s->outgoing && proxy_header_present(data, data_len)) {
if (!proxy_allowed(comm->comm_addr)) {
kr_log_debug(IO, "<= ignoring PROXYv2 UDP from disallowed address '%s'\n",
kr_straddr(comm->comm_addr));
return protolayer_break(ctx, kr_error(EPERM));
}
ssize_t trimmed = proxy_process_header(&proxy_state->proxy, data, data_len);
if (trimmed == KNOT_EMALF) {
if (kr_log_is_debug(IO, NULL)) {
kr_log_debug(IO, "<= ignoring malformed PROXYv2 UDP "
"from address '%s'\n",
kr_straddr(comm->comm_addr));
}
return protolayer_break(ctx, kr_error(EINVAL));
} else if (trimmed < 0) {
if (kr_log_is_debug(IO, NULL)) {
kr_log_debug(IO, "<= error processing PROXYv2 UDP "
"from address '%s', ignoring\n",
kr_straddr(comm->comm_addr));
}
return protolayer_break(ctx, kr_error(EINVAL));
}
if (proxy_state->proxy.command == PROXY2_CMD_PROXY && proxy_state->proxy.family != AF_UNSPEC) {
comm->src_addr = &proxy_state->proxy.src_addr.ip;
comm->dst_addr = &proxy_state->proxy.dst_addr.ip;
comm->proxy = &proxy_state->proxy;
if (kr_log_is_debug(IO, NULL)) {
kr_log_debug(IO, "<= UDP query from '%s'\n",
kr_straddr(comm->src_addr));
kr_log_debug(IO, "<= proxied through '%s'\n",
kr_straddr(comm->comm_addr));
}
}
ctx->payload = protolayer_payload_buffer(
data + trimmed, data_len - trimmed, false);
}
return protolayer_continue(ctx);
}
static enum protolayer_iter_cb_result pl_proxyv2_stream_unwrap(
void *sess_data, void *iter_data, struct protolayer_iter_ctx *ctx)
{
struct session2 *s = ctx->session;
struct pl_proxyv2_state *proxy_state = sess_data;
struct sockaddr *peer = session2_get_peer(s);
if (kr_fails_assert(ctx->payload.type == PROTOLAYER_PAYLOAD_WIRE_BUF)) {
/* Only wire buffer is supported */
return protolayer_break(ctx, kr_error(EINVAL));
}
char *data = wire_buf_data(ctx->payload.wire_buf); /* layer's or session's wirebuf */
ssize_t data_len = wire_buf_data_length(ctx->payload.wire_buf);
struct comm_info *comm = ctx->comm;
if (!s->outgoing && !proxy_state->had_data && proxy_header_present(data, data_len)) {
if (!proxy_allowed(comm->src_addr)) {
if (kr_log_is_debug(IO, NULL)) {
kr_log_debug(IO, "<= connection to '%s': PROXYv2 not allowed "
"for this peer, close\n",
kr_straddr(peer));
}
session2_force_close(s);
return protolayer_break(ctx, kr_error(ECONNRESET));
}
ssize_t trimmed = proxy_process_header(&proxy_state->proxy, data, data_len);
if (trimmed < 0) {
if (kr_log_is_debug(IO, NULL)) {
if (trimmed == KNOT_EMALF) {
kr_log_debug(IO, "<= connection to '%s': "
"malformed PROXYv2 header, close\n",
kr_straddr(comm->src_addr));
} else {
kr_log_debug(IO, "<= connection to '%s': "
"error processing PROXYv2 header, close\n",
kr_straddr(comm->src_addr));
}
}
session2_force_close(s);
return protolayer_break(ctx, kr_error(ECONNRESET));
} else if (trimmed == 0) {
session2_close(s);
return protolayer_break(ctx, kr_error(ECONNRESET));
}
if (proxy_state->proxy.command != PROXY2_CMD_LOCAL && proxy_state->proxy.family != AF_UNSPEC) {
comm->src_addr = &proxy_state->proxy.src_addr.ip;
comm->dst_addr = &proxy_state->proxy.dst_addr.ip;
comm->proxy = &proxy_state->proxy;
if (kr_log_is_debug(IO, NULL)) {
kr_log_debug(IO, "<= TCP stream from '%s'\n",
kr_straddr(comm->src_addr));
kr_log_debug(IO, "<= proxied through '%s'\n",
kr_straddr(comm->comm_addr));
}
}
wire_buf_trim(ctx->payload.wire_buf, trimmed);
}
proxy_state->had_data = true;
return protolayer_continue(ctx);
}
__attribute__((constructor))
static void proxy_protolayers_init(void)
{
protolayer_globals[PROTOLAYER_TYPE_PROXYV2_DGRAM] = (struct protolayer_globals){
.iter_size = sizeof(struct pl_proxyv2_state),
.unwrap = pl_proxyv2_dgram_unwrap,
};
protolayer_globals[PROTOLAYER_TYPE_PROXYV2_STREAM] = (struct protolayer_globals){
.sess_size = sizeof(struct pl_proxyv2_state),
.unwrap = pl_proxyv2_stream_unwrap,
};
}
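/* Note the asymmetry in the registration above (editor's reading): the
* datagram layer keeps struct pl_proxyv2_state per iteration (iter_size),
* since each UDP datagram may carry its own PROXY header, whereas the
* stream layer keeps it per session (sess_size), because `had_data` must
* persist across reads so that only the very first bytes of a TCP
* connection are ever parsed as a PROXYv2 header. */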
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#pragma once
#include <stdint.h>
#include "lib/utils.h"
enum proxy2_command {
PROXY2_CMD_LOCAL = 0x0,
PROXY2_CMD_PROXY = 0x1
};
/** Parsed result of the PROXY protocol */
struct proxy_result {
/** Proxy command - PROXY or LOCAL. */
enum proxy2_command command;
/** Address family from netinet library (e.g. AF_INET6). */
int family;
/** Protocol type from socket library (e.g. SOCK_STREAM). */
int protocol;
/** Parsed source address and port. */
union kr_sockaddr src_addr;
/** Parsed destination address and port. */
union kr_sockaddr dst_addr;
/** `true` = client has used TLS with the proxy. If TLS padding is
* enabled, it will be used even if the communication between kresd and
* the proxy is unencrypted. */
bool has_tls : 1;
};
/** Checks whether the use of PROXYv2 protocol is allowed for the specified
* address. */
bool proxy_allowed(const struct sockaddr *saddr);
# SPDX-License-Identifier: GPL-3.0-or-later
#
programs:
- name: dnsdist
binary: dnsdist
additional:
- --verbose
- --supervised
- --config
- dnsdist.conf
ignore_exit_code: True
templates:
- daemon/proxyv2.test/dnsdist_config.j2
configs:
- dnsdist.conf
- name: kresd
binary: kresd
additional:
- --noninteractive
templates:
- daemon/proxyv2.test/kresd_config.j2
- tests/integration/hints_zone.j2
configs:
- config
- hints
-- vim:syntax=lua
setLocal('{{SELF_ADDR}}')
setVerboseHealthChecks(true)
setServerPolicy(firstAvailable)
local server = newServer({
address="{{PROGRAMS['kresd']['address']}}",
useProxyProtocol=true,
checkName="example.cz."
})
server:setUp()
-- SPDX-License-Identifier: GPL-3.0-or-later
{% raw %}
modules.load('view < policy')
view:addr("127.127.0.0", policy.suffix(policy.DENY_MSG("addr 127.127.0.0 matched com"),{"\3com\0"}))
-- make sure DNSSEC is turned off for tests
trust_anchors.remove('.')
-- Disable RFC5011 TA update
if ta_update then
modules.unload('ta_update')
end
-- Disable RFC8145 signaling, scenario doesn't provide expected answers
if ta_signal_query then
modules.unload('ta_signal_query')
end
-- Disable RFC8109 priming, scenario doesn't provide expected answers
if priming then
modules.unload('priming')
end
-- Disable this module because it makes one priming query
if detect_time_skew then
modules.unload('detect_time_skew')
end
_hint_root_file('hints')
cache.size = 2*MB
log_level('debug')
{% endraw %}
-- Allow PROXYv2 from dnsdist's address
--net.proxy_allowed("{{PROGRAMS['dnsdist']['address']}}")
net.proxy_allowed("127.127.0.0/16")
net = { '{{SELF_ADDR}}' }
{% if QMIN == "false" %}
option('NO_MINIMIZE', true)
{% else %}
option('NO_MINIMIZE', false)
{% endif %}
-- Self-checks on globals
assert(help() ~= nil)
assert(worker.id ~= nil)
-- Self-checks on facilities
assert(cache.count() == 0)
assert(cache.stats() ~= nil)
assert(cache.backends() ~= nil)
assert(worker.stats() ~= nil)
assert(net.interfaces() ~= nil)
-- Self-checks on loaded stuff
assert(net.list()[1].transport.ip == '{{SELF_ADDR}}')
assert(#modules.list() > 0)
-- Self-check timers
ev = event.recurrent(1 * sec, function (ev) return 1 end)
event.cancel(ev)
ev = event.after(0, function (ev) return 1 end)
; SPDX-License-Identifier: GPL-3.0-or-later
; config options
stub-addr: 1.2.3.4
query-minimization: off
CONFIG_END
SCENARIO_BEGIN proxyv2:valid test
RANGE_BEGIN 0 110
ADDRESS 1.2.3.4
ENTRY_BEGIN
MATCH opcode qtype qname
ADJUST copy_id
REPLY QR RD RA NOERROR
SECTION QUESTION
example.cz. IN A
SECTION ANSWER
example.cz. IN A 5.6.7.8
ENTRY_END
ENTRY_BEGIN
MATCH opcode qtype qname
ADJUST copy_id
REPLY QR RD RA NOERROR
SECTION QUESTION
k.root-servers.net. IN AAAA
SECTION ANSWER
k.root-servers.net. IN AAAA ::1
ENTRY_END
RANGE_END
; query with PROXYv2 header - not blocked
STEP 10 QUERY
ENTRY_BEGIN
ADJUST raw_id
REPLY RD
SECTION QUESTION
example.cz. IN A
ENTRY_END
STEP 20 CHECK_ANSWER
ENTRY_BEGIN
MATCH flags rcode question answer
REPLY QR RD RA NOERROR
SECTION QUESTION
example.cz. IN A
SECTION ANSWER
example.cz. IN A 5.6.7.8
ENTRY_END
; query with PROXYv2 header - blocked by view:addr
; NXDOMAIN expected
STEP 30 QUERY
ENTRY_BEGIN
REPLY RD
SECTION QUESTION
example.com. IN A
ENTRY_END
STEP 31 CHECK_ANSWER
ENTRY_BEGIN
MATCH opcode question rcode additional
REPLY QR RD RA AA NXDOMAIN
SECTION QUESTION
example.com. IN A
SECTION ADDITIONAL
explanation.invalid. 10800 IN TXT "addr 127.127.0.0 matched com"
ENTRY_END
SCENARIO_END
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include <stdatomic.h>
#include "daemon/ratelimiting.h"
#include "lib/kru.h"
#include "lib/mmapped.h"
#include "lib/utils.h"
#include "lib/resolve.h"
#define V4_PREFIXES (uint8_t[]) { 18, 20, 24, 32 }
#define V4_RATE_MULT (kru_price_t[]) { 768, 256, 32, 1 }
#define V6_PREFIXES (uint8_t[]) { 32, 48, 56, 64, 128 }
#define V6_RATE_MULT (kru_price_t[]) { 64, 4, 3, 2, 1 }
#define V4_PREFIXES_CNT (sizeof(V4_PREFIXES) / sizeof(*V4_PREFIXES))
#define V6_PREFIXES_CNT (sizeof(V6_PREFIXES) / sizeof(*V6_PREFIXES))
#define MAX_PREFIXES_CNT ((V4_PREFIXES_CNT > V6_PREFIXES_CNT) ? V4_PREFIXES_CNT : V6_PREFIXES_CNT)
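/* The per-query price charged to a /X prefix is the base price divided by the
 * corresponding *_RATE_MULT entry (see the v4_prices/v6_prices initialization
 * below), so e.g. a whole IPv4 /18 may pass 768 times the rate of a single /32. */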
struct ratelimiting {
size_t capacity;
uint32_t instant_limit;
uint32_t rate_limit;
uint32_t log_period;
uint16_t slip;
bool dry_run;
bool using_avx2;
_Atomic uint32_t log_time;
kru_price_t v4_prices[V4_PREFIXES_CNT];
kru_price_t v6_prices[V6_PREFIXES_CNT];
_Alignas(64) uint8_t kru[];
};
struct ratelimiting *ratelimiting = NULL;
struct mmapped ratelimiting_mmapped = {0};
/// return whether we're using optimized variant right now
static bool using_avx2(void)
{
bool result = (KRU.initialize == KRU_AVX2.initialize);
kr_require(result || KRU.initialize == KRU_GENERIC.initialize);
return result;
}
int ratelimiting_init(const char *mmap_file, size_t capacity, uint32_t instant_limit,
uint32_t rate_limit, uint16_t slip, uint32_t log_period, bool dry_run)
{
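	/* Compute capacity_log = ceil(log2(capacity)); KRU derives its table size from it. */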
size_t capacity_log = 0;
for (size_t c = capacity - 1; c > 0; c >>= 1) capacity_log++;
size_t size = offsetof(struct ratelimiting, kru) + KRU.get_size(capacity_log);
struct ratelimiting header = {
.capacity = capacity,
.instant_limit = instant_limit,
.rate_limit = rate_limit,
.log_period = log_period,
.slip = slip,
.dry_run = dry_run,
.using_avx2 = using_avx2()
};
size_t header_size = offsetof(struct ratelimiting, using_avx2) + sizeof(header.using_avx2);
static_assert( // no padding up to .using_avx2
offsetof(struct ratelimiting, using_avx2) ==
sizeof(header.capacity) +
sizeof(header.instant_limit) +
sizeof(header.rate_limit) +
sizeof(header.log_period) +
sizeof(header.slip) +
sizeof(header.dry_run),
"detected padding with undefined data inside mmapped header");
int ret = mmapped_init(&ratelimiting_mmapped, mmap_file, size, &header, header_size);
if (ret == MMAPPED_WAS_FIRST) {
kr_log_info(SYSTEM, "Initializing rate-limiting...\n");
ratelimiting = ratelimiting_mmapped.mem;
const kru_price_t base_price = KRU_LIMIT / instant_limit;
const kru_price_t max_decay = rate_limit > 1000ll * instant_limit ? base_price :
(uint64_t) base_price * rate_limit / 1000;
bool succ = KRU.initialize((struct kru *)ratelimiting->kru, capacity_log, max_decay);
if (!succ) {
ratelimiting = NULL;
ret = kr_error(EINVAL);
goto fail;
}
ratelimiting->log_time = kr_now() - log_period;
for (size_t i = 0; i < V4_PREFIXES_CNT; i++) {
ratelimiting->v4_prices[i] = base_price / V4_RATE_MULT[i];
}
for (size_t i = 0; i < V6_PREFIXES_CNT; i++) {
ratelimiting->v6_prices[i] = base_price / V6_RATE_MULT[i];
}
ret = mmapped_init_continue(&ratelimiting_mmapped);
if (ret != 0) goto fail;
kr_log_info(SYSTEM, "Rate-limiting initialized (%s).\n", (ratelimiting->using_avx2 ? "AVX2" : "generic"));
return 0;
} else if (ret == 0) {
ratelimiting = ratelimiting_mmapped.mem;
kr_log_info(SYSTEM, "Using existing rate-limiting data (%s).\n", (ratelimiting->using_avx2 ? "AVX2" : "generic"));
return 0;
} // else fail
fail:
kr_log_crit(SYSTEM, "Initialization of shared rate-limiting data failed.\n");
return ret;
}
void ratelimiting_deinit(void)
{
mmapped_deinit(&ratelimiting_mmapped);
ratelimiting = NULL;
}
bool ratelimiting_request_begin(struct kr_request *req)
{
if (!ratelimiting) return false;
if (!req->qsource.addr)
return false; // don't consider internal requests
if (req->qsource.price_factor16 == 0)
return false; // whitelisted
// We only do this on pure UDP. (also TODO if cookies get implemented)
const bool ip_validated = req->qsource.flags.tcp || req->qsource.flags.tls;
if (ip_validated) return false;
const uint32_t time_now = kr_now();
// classify
_Alignas(16) uint8_t key[16] = {0, };
uint8_t limited_prefix;
if (req->qsource.addr->sa_family == AF_INET6) {
struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)req->qsource.addr;
memcpy(key, &ipv6->sin6_addr, 16);
// compute adjusted prices, using standard rounding
kru_price_t prices[V6_PREFIXES_CNT];
for (int i = 0; i < V6_PREFIXES_CNT; ++i) {
prices[i] = (req->qsource.price_factor16
* (uint64_t)ratelimiting->v6_prices[i] + (1<<15)) >> 16;
}
limited_prefix = KRU.limited_multi_prefix_or((struct kru *)ratelimiting->kru, time_now,
1, key, V6_PREFIXES, prices, V6_PREFIXES_CNT, NULL);
} else {
struct sockaddr_in *ipv4 = (struct sockaddr_in *)req->qsource.addr;
memcpy(key, &ipv4->sin_addr, 4); // TODO append port?
// compute adjusted prices, using standard rounding
kru_price_t prices[V4_PREFIXES_CNT];
for (int i = 0; i < V4_PREFIXES_CNT; ++i) {
prices[i] = (req->qsource.price_factor16
* (uint64_t)ratelimiting->v4_prices[i] + (1<<15)) >> 16;
}
limited_prefix = KRU.limited_multi_prefix_or((struct kru *)ratelimiting->kru, time_now,
0, key, V4_PREFIXES, prices, V4_PREFIXES_CNT, NULL);
}
if (!limited_prefix) return false; // not limited
	// slip: 0 = always drop, 1 = always truncate (TC=1),
	// N>1 = truncate roughly each N-th limited response, drop the rest
	bool tc = (ratelimiting->slip > 1)
		? (kr_rand_bytes(1) % ratelimiting->slip == 0)
		: (ratelimiting->slip == 1);
// logging
uint32_t log_time_orig = atomic_load_explicit(&ratelimiting->log_time, memory_order_relaxed);
if (ratelimiting->log_period) {
while (time_now - log_time_orig + 1024 >= ratelimiting->log_period + 1024) {
if (atomic_compare_exchange_weak_explicit(&ratelimiting->log_time, &log_time_orig, time_now,
memory_order_relaxed, memory_order_relaxed)) {
kr_log_notice(SYSTEM, "address %s rate-limited on /%d (%s%s)\n",
kr_straddr(req->qsource.addr), limited_prefix,
ratelimiting->dry_run ? "dry-run, " : "",
tc ? "truncated" : "dropped");
break;
}
}
}
req->ratelimited = true; // we set this even on dry_run
if (ratelimiting->dry_run) return false;
// perform limiting
if (tc) { // TC=1: return truncated reply to force source IP validation
knot_pkt_t *answer = kr_request_ensure_answer(req);
if (!answer) { // something bad; TODO: perhaps improve recovery from this
kr_assert(false);
return true;
}
// at this point the packet should be pretty clear
// The TC=1 answer is not perfect, as the right RCODE might differ
// in some cases, but @vcunat thinks that NOERROR isn't really risky here.
knot_wire_set_tc(answer->wire);
knot_wire_clear_ad(answer->wire);
req->state = KR_STATE_DONE;
} else {
// no answer
req->options.NO_ANSWER = true;
req->state = KR_STATE_FAIL;
}
return true;
}
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include <stdbool.h>
#include "lib/defines.h"
#include "lib/utils.h"
struct kr_request;
/** Initialize rate-limiting with shared mmapped memory.
* The existing data are used if another instance is already using the file
* and it was initialized with the same parameters; it fails on mismatch. */
KR_EXPORT
int ratelimiting_init(const char *mmap_file, size_t capacity, uint32_t instant_limit,
uint32_t rate_limit, uint16_t slip, uint32_t log_period, bool dry_run);
/** Do rate-limiting, during knot_layer_api::begin. */
KR_EXPORT
bool ratelimiting_request_begin(struct kr_request *req);
/** Remove mmapped file data if not used by other processes. */
KR_EXPORT
void ratelimiting_deinit(void);
/* Copyright (C) 2024 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
static void the_tests(void **state);
#include "./tests.inc.c" // NOLINT(bugprone-suspicious-include)
#define THREADS 4
#define BATCH_QUERIES_LOG 3 // threads acquire queries in batches of 8
#define HOSTS_LOG 3 // at most 6 attackers + 2 wildcard addresses for normal users
#define TICK_QUERIES_LOG 13 // at most 1024 queries per host per tick
// Expected range of limits for parallel test.
#define RANGE_INST(Vx, prefix) INST(Vx, prefix) - 1, INST(Vx, prefix) + THREADS - 1
#define RANGE_RATEM(Vx, prefix) RATEM(Vx, prefix) - 1, RATEM(Vx, prefix)
#define RANGE_UNLIM(queries) queries, queries
struct host {
uint32_t queries_per_tick;
int addr_family;
char *addr_format;
uint32_t min_passed, max_passed;
_Atomic uint32_t passed;
};
struct stage {
uint32_t first_tick, last_tick;
struct host hosts[1 << HOSTS_LOG];
};
struct runnable_data {
int prime;
_Atomic uint32_t *queries_acquired, *queries_done;
struct stage *stages;
};
static void *runnable(void *arg)
{
struct runnable_data *d = (struct runnable_data *)arg;
size_t si = 0;
char addr_str[40];
struct sockaddr_storage addr;
uint8_t wire[KNOT_WIRE_MIN_PKTSIZE] = { 0 };
knot_pkt_t answer = { .wire = wire };
struct kr_request req = {
.qsource.addr = (struct sockaddr *) &addr,
.qsource.price_factor16 = 1 << 16,
.answer = &answer
};
while (true) {
uint32_t qi1 = atomic_fetch_add(d->queries_acquired, 1 << BATCH_QUERIES_LOG);
/* increment time if needed; sync on incrementing using spinlock */
uint32_t tick = qi1 >> TICK_QUERIES_LOG;
for (size_t i = 1; tick != fakeclock_tick; i++) {
if ((*d->queries_done >> TICK_QUERIES_LOG) >= tick) {
fakeclock_tick = tick;
}
if (i % (1<<14) == 0) sched_yield();
__sync_synchronize();
}
/* increment stage if needed */
while (tick > d->stages[si].last_tick) {
++si;
if (!d->stages[si].first_tick) return NULL;
}
if (tick >= d->stages[si].first_tick) {
uint32_t qi2 = 0;
do {
uint32_t qi = qi1 + qi2;
/* perform query qi */
uint32_t hi = qi % (1 << HOSTS_LOG);
if (!d->stages[si].hosts[hi].queries_per_tick) continue;
uint32_t hqi = (qi % (1 << TICK_QUERIES_LOG)) >> HOSTS_LOG; // host query index within tick
if (hqi >= d->stages[si].hosts[hi].queries_per_tick) continue;
hqi += (qi >> TICK_QUERIES_LOG) * d->stages[si].hosts[hi].queries_per_tick; // across ticks
(void)snprintf(addr_str, sizeof(addr_str), d->stages[si].hosts[hi].addr_format,
hqi % 0xff, (hqi >> 8) % 0xff, (hqi >> 16) % 0xff);
kr_straddr_socket_set((struct sockaddr *)&addr, addr_str, 0);
if (!ratelimiting_request_begin(&req)) {
atomic_fetch_add(&d->stages[si].hosts[hi].passed, 1);
}
} while ((qi2 = (qi2 + d->prime) % (1 << BATCH_QUERIES_LOG)));
}
atomic_fetch_add(d->queries_done, 1 << BATCH_QUERIES_LOG);
}
}
static void the_tests(void **state)
{
/* parallel tests */
struct stage stages[] = {
/* first tick, last tick, hosts */
{32, 32, {
/* queries per tick, family, address, min passed, max passed */
{1024, AF_INET, "%d.%d.%d.1", RANGE_UNLIM ( 1024 )},
{1024, AF_INET, "3.3.3.3", RANGE_INST ( V4, 32 )},
{ 512, AF_INET, "4.4.4.4", RANGE_INST ( V4, 32 )},
{1024, AF_INET6, "%x%x:%x00::1", RANGE_UNLIM ( 1024 )},
{1024, AF_INET6, "3333::3333", RANGE_INST ( V6, 128 )},
{ 512, AF_INET6, "4444::4444", RANGE_INST ( V6, 128 )}
}},
{33, 255, {
{1024, AF_INET, "%d.%d.%d.1", RANGE_UNLIM ( 1024 )},
{1024, AF_INET, "3.3.3.3", RANGE_RATEM ( V4, 32 )},
{ 512, AF_INET, "4.4.4.4", RANGE_RATEM ( V4, 32 )},
{1024, AF_INET6, "%x%x:%x00::1", RANGE_UNLIM ( 1024 )},
{1024, AF_INET6, "3333::3333", RANGE_RATEM ( V6, 128 )},
{ 512, AF_INET6, "4444::4444", RANGE_RATEM ( V6, 128 )},
}},
{256, 511, {
{1024, AF_INET, "3.3.3.3", RANGE_RATEM ( V4, 32 )},
{1024, AF_INET6, "3333::3333", RANGE_RATEM ( V6, 128 )}
}},
{512, 512, {
{1024, AF_INET, "%d.%d.%d.1", RANGE_UNLIM ( 1024 )},
{1024, AF_INET, "3.3.3.3", RANGE_RATEM ( V4, 32 )},
{ 512, AF_INET, "4.4.4.4", RANGE_INST ( V4, 32 )},
{1024, AF_INET6, "%x%x:%x00::1", RANGE_UNLIM ( 1024 )},
{1024, AF_INET6, "3333::3333", RANGE_RATEM ( V6, 128 )},
{ 512, AF_INET6, "4444::4444", RANGE_INST ( V6, 128 )}
}},
{0}
};
pthread_t thr[THREADS];
struct runnable_data rd[THREADS];
_Atomic uint32_t queries_acquired = 0, queries_done = 0;
int primes[] = {3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61};
assert(sizeof(primes)/sizeof(*primes) >= THREADS);
for (unsigned i = 0; i < THREADS; ++i) {
rd[i].queries_acquired = &queries_acquired;
rd[i].queries_done = &queries_done;
rd[i].prime = primes[i];
rd[i].stages = stages;
pthread_create(thr + i, NULL, &runnable, rd + i);
}
for (unsigned i = 0; i < THREADS; ++i) {
pthread_join(thr[i], NULL);
}
unsigned si = 0;
do {
struct host * const h = stages[si].hosts;
uint32_t ticks = stages[si].last_tick - stages[si].first_tick + 1;
for (size_t i = 0; h[i].queries_per_tick; i++) {
assert_int_between(h[i].passed, ticks * h[i].min_passed, ticks * h[i].max_passed,
"parallel stage %d, addr %-25s", si, h[i].addr_format);
}
} while (stages[++si].first_tick);
}
/* Copyright (C) 2024 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
static void the_tests(void **state);
#include "./tests.inc.c" // NOLINT(bugprone-suspicious-include)
// count_test is defined as a macro so that assertion failures report a usable line number
#define count_test(DESC, EXPECTED_PASSING, MARGIN_FRACT, ...) { \
int _max_diff = (EXPECTED_PASSING) * (MARGIN_FRACT); \
int cnt = _count_test(EXPECTED_PASSING, __VA_ARGS__); \
assert_int_between(cnt, (EXPECTED_PASSING) - _max_diff, (EXPECTED_PASSING) + _max_diff, DESC); }
uint32_t _count_test(int expected_passing, int addr_family, char *addr_format, uint32_t min_value, uint32_t max_value)
{
uint32_t max_queries = expected_passing > 0 ? 2 * expected_passing : -expected_passing;
struct sockaddr_storage addr;
uint8_t wire[KNOT_WIRE_MIN_PKTSIZE] = { 0 };
knot_pkt_t answer = { .wire = wire };
struct kr_request req = {
.qsource.addr = (struct sockaddr *) &addr,
.qsource.price_factor16 = 1 << 16,
.answer = &answer
};
char addr_str[40];
int cnt = -1;
for (size_t i = 0; i < max_queries; i++) {
(void)snprintf(addr_str, sizeof(addr_str), addr_format,
i % (max_value - min_value + 1) + min_value,
i / (max_value - min_value + 1) % 256);
kr_straddr_socket_set((struct sockaddr *) &addr, addr_str, 0);
if (ratelimiting_request_begin(&req)) {
cnt = i;
break;
}
}
return cnt;
}
static void the_tests(void **state)
{
/* IPv4 multi-prefix tests */
static_assert(V4_PREFIXES_CNT == 4,
"There are no more IPv4 limited prefixes (/32, /24, /20, /18 will be tested).");
count_test("IPv4 instant limit /32", INST(V4, 32), 0,
AF_INET, "128.0.0.0", 0, 0);
count_test("IPv4 instant limit /32 not applied on /31", -1, 0,
AF_INET, "128.0.0.1", 0, 0);
count_test("IPv4 instant limit /24", INST(V4, 24) - INST(V4, 32) - 1, 0,
AF_INET, "128.0.0.%d", 2, 255);
count_test("IPv4 instant limit /24 not applied on /23", -1, 0,
AF_INET, "128.0.1.0", 0, 0);
count_test("IPv4 instant limit /20", INST(V4, 20) - INST(V4, 24) - 1, 0.001,
AF_INET, "128.0.%d.%d", 2, 15);
count_test("IPv4 instant limit /20 not applied on /19", -1, 0,
AF_INET, "128.0.16.0", 0, 0);
count_test("IPv4 instant limit /18", INST(V4, 18) - INST(V4, 20) - 1, 0.01,
AF_INET, "128.0.%d.%d", 17, 63);
count_test("IPv4 instant limit /18 not applied on /17", -1, 0,
AF_INET, "128.0.64.0", 0, 0);
/* IPv6 multi-prefix tests */
static_assert(V6_PREFIXES_CNT == 5,
"There are no more IPv6 limited prefixes (/128, /64, /56, /48, /32 will be tested).");
count_test("IPv6 instant limit /128, independent to IPv4", INST(V6, 128), 0,
AF_INET6, "8000::", 0, 0);
count_test("IPv6 instant limit /128 not applied on /127", -1, 0,
AF_INET6, "8000::1", 0, 0);
count_test("IPv6 instant limit /64", INST(V6, 64) - INST(V6, 128) - 1, 0,
AF_INET6, "8000:0:0:0:%02x%02x::", 0x01, 0xff);
count_test("IPv6 instant limit /64 not applied on /63", -1, 0,
AF_INET6, "8000:0:0:1::", 0, 0);
count_test("IPv6 instant limit /56", INST(V6, 56) - INST(V6, 64) - 1, 0,
AF_INET6, "8000:0:0:00%02x:%02x00::", 0x02, 0xff);
count_test("IPv6 instant limit /56 not applied on /55", -1, 0,
AF_INET6, "8000:0:0:0100::", 0, 0);
count_test("IPv6 instant limit /48", INST(V6, 48) - INST(V6, 56) - 1, 0.01,
AF_INET6, "8000:0:0:%02x%02x::", 0x02, 0xff);
count_test("IPv6 instant limit /48 not applied on /47", -1, 0,
AF_INET6, "8000:0:1::", 0, 0);
count_test("IPv6 instant limit /32", INST(V6, 32) - INST(V6, 48) - 1, 0.001,
AF_INET6, "8000:0:%02x%02x::", 0x02, 0xff);
count_test("IPv6 instant limit /32 not applied on /31", -1, 0,
AF_INET6, "8000:1::", 0, 0);
/* limit after 1 msec */
fakeclock_tick++;
count_test("IPv4 rate limit /32 after 1 msec", RATEM(V4, 32), 0,
AF_INET, "128.0.0.0", 0, 0);
count_test("IPv6 rate limit /128 after 1 msec", RATEM(V6, 128), 0,
AF_INET6, "8000::", 0, 0);
}
/* Copyright (C) 2024 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#include <pthread.h>
#include <sched.h>
#include <stdio.h>
#include <stdatomic.h>
#include "tests/unit/test.h"
#include "libdnssec/crypto.h"
#include "libdnssec/random.h"
#include "libknot/libknot.h"
#include "contrib/openbsd/siphash.h"
#include "lib/resolve.h"
#include "lib/utils.h"
uint64_t fakeclock_now(void);
#define kr_now fakeclock_now
#include "daemon/ratelimiting.c"
#undef kr_now
#define RRL_TABLE_SIZE (1 << 20)
#define RRL_INSTANT_LIMIT (1 << 8)
#define RRL_RATE_LIMIT (1 << 17)
#define RRL_BASE_PRICE (KRU_LIMIT / RRL_INSTANT_LIMIT)
// Accessing RRL configuration of INSTANT/RATE limits for V4/V6 and specific prefix.
#define LIMIT(type, Vx, prefix) (RRL_MULT(Vx, prefix) * RRL_ ## type ## _LIMIT)
#define RRL_CONFIG(Vx, name) Vx ## _ ## name
#define RRL_MULT(Vx, prefix) get_mult(RRL_CONFIG(Vx, PREFIXES), RRL_CONFIG(Vx, RATE_MULT), RRL_CONFIG(Vx, PREFIXES_CNT), prefix)
static inline kru_price_t get_mult(uint8_t prefixes[], kru_price_t mults[], size_t cnt, uint8_t wanted_prefix) {
for (size_t i = 0; i < cnt; i++)
if (prefixes[i] == wanted_prefix)
return mults[i];
assert(0);
return 0;
}
// Instant limits and rate limits per msec.
#define INST(Vx, prefix) LIMIT(INSTANT, Vx, prefix)
#define RATEM(Vx, prefix) (LIMIT(RATE, Vx, prefix) / 1000)
/* Fixed seed (one byte) for randomness in the RRL module. Change it if improbable collisions arise. */
#define RRL_SEED_GENERIC 1
#define RRL_SEED_AVX2 1
#define assert_int_between(VAL, MIN, MAX, ...) \
if (((MIN) > (VAL)) || ((VAL) > (MAX))) { \
fprintf(stderr, __VA_ARGS__); fprintf(stderr, ": %d <= %d <= %d, ", MIN, VAL, MAX); \
assert_true(false); }
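/* Partial mirrors of the private KRU structs: just enough of the layout to
 * reach hash_key, which the tests overwrite with a fixed seed below. */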
struct kru_generic {
SIPHASH_KEY hash_key;
// ...
};
struct kru_avx2 {
_Alignas(32) char hash_key[48];
// ...
};
/* Override time. */
uint64_t fakeclock_tick = 0;
uint64_t fakeclock_start = 0;
void fakeclock_init(void)
{
fakeclock_start = kr_now();
fakeclock_tick = 0;
}
uint64_t fakeclock_now(void)
{
return fakeclock_start + fakeclock_tick;
}
static void test_rrl(void **state) {
dnssec_crypto_init();
fakeclock_init();
/* create rrl table */
const char *tmpdir = test_tmpdir_create();
char mmap_file[64];
stpcpy(stpcpy(mmap_file, tmpdir), "/ratelimiting");
ratelimiting_init(mmap_file, RRL_TABLE_SIZE, RRL_INSTANT_LIMIT, RRL_RATE_LIMIT, 0, 0, false);
if (KRU.initialize == KRU_GENERIC.initialize) {
struct kru_generic *kru = (struct kru_generic *) ratelimiting->kru;
memset(&kru->hash_key, RRL_SEED_GENERIC, sizeof(kru->hash_key));
} else if (KRU.initialize == KRU_AVX2.initialize) {
struct kru_avx2 *kru = (struct kru_avx2 *) ratelimiting->kru;
memset(&kru->hash_key, RRL_SEED_AVX2, sizeof(kru->hash_key));
} else {
assert(0);
}
the_tests(state);
ratelimiting_deinit();
test_tmpdir_remove(tmpdir);
dnssec_crypto_cleanup();
}
static void test_rrl_generic(void **state) {
KRU = KRU_GENERIC;
test_rrl(state);
}
static void test_rrl_avx2(void **state) {
KRU = KRU_AVX2;
test_rrl(state);
}
int main(int argc, char *argv[])
{
assert(KRU_GENERIC.initialize != KRU_AVX2.initialize);
if (KRU.initialize == KRU_AVX2.initialize) {
const UnitTest tests[] = {
unit_test(test_rrl_generic),
unit_test(test_rrl_avx2)
};
return run_tests(tests);
} else {
const UnitTest tests[] = {
unit_test(test_rrl_generic)
};
return run_tests(tests);
}
}
.. SPDX-License-Identifier: GPL-3.0-or-later
.. _runtime-cfg:
Run-time reconfiguration
========================
Knot Resolver offers several ways to modify its configuration at run-time:
- Using control socket driven by an external system
- Using Lua program embedded in Resolver's configuration file
Both approaches can also be combined: for example, the configuration file can
define a small Lua function which gathers statistics and returns them as a
JSON string. An external system can then use the control socket to call this
user-defined function and retrieve its results.
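For instance, the configuration file may define a helper like the following
(a minimal sketch, assuming the ``stats`` module is loaded; the name
``stats_json`` is just illustrative):

.. code-block:: lua

    modules.load('stats')
    -- Return all collected metrics as one JSON string; an external
    -- system can call stats_json() over the control socket.
    function stats_json()
        return tojson(stats.list())
    end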
.. _control-sockets:
Control sockets
---------------
The control socket acts like "an interactive configuration file": all actions
available in the configuration file can also be executed interactively over
the control socket. One possible use-case is reconfiguring resolver instances
from another program, e.g. a maintenance script.
.. note:: Each instance of Knot Resolver exposes its own control socket. Take
that into account when scripting deployments with
:ref:`systemd-multiple-instances`.
When Knot Resolver is started using systemd (see section
`Startup <../gettingstarted-startup.html>`_), it creates a control socket at
``/run/knot-resolver/control/$ID``. You can connect to the socket from the
command line using e.g. ``socat``:
.. code-block:: bash
$ socat - UNIX-CONNECT:/run/knot-resolver/control/1
Once successfully connected to a socket, the prompt should change to something
like ``>``. You can then interact with kresd to inspect or change its
configuration. Here are some basic commands to start with:
.. code-block:: lua
> help() -- shows help
> net.interfaces() -- lists available interfaces
> net.list() -- lists running network services
The *direct output* of commands sent over the socket is captured and sent
back, giving you an immediate response on the outcome of your command.
The commands and their output are also logged in the ``contrl`` group,
at the ``debug`` level if successful or the ``warning`` level if failed
(see around :func:`log_level`).
Control sockets are also a way to enumerate and test running instances: the
list of sockets corresponds to the list of processes, and you can test a
process for liveness by connecting to its UNIX socket.
.. function:: map(lua_snippet)
Executes the provided string as Lua code on every running resolver instance
and returns the results as a table.
Key ``n`` is always present in the returned table and specifies the total
number of instances the command was executed on. The table also contains
results from each instance accessible through keys ``1`` to ``n``
(inclusive). If any instance returns ``nil``, it is not explicitly part of
the table, but you can detect it by iterating through ``1`` to ``n``.
.. code-block:: lua
> map('worker.id') -- return an ID of every active instance
{
'2',
'1',
['n'] = 2,
}
> map('worker.id == "1" or nil') -- example of `nil` return value
{
[2] = true,
['n'] = 2,
}
The order of instances isn't guaranteed or stable. When you need to identify
the instances, you may use the ``kluautil.kr_table_pack()`` function to return
multiple values as a table. It uses the same semantics for ``n`` as described
above, to allow for ``nil`` values.
.. code-block:: lua
> map('require("kluautil").kr_table_pack(worker.id, stats.get("answer.total"))')
{
{
'2',
42,
['n'] = 2,
},
{
'1',
69,
['n'] = 2,
},
['n'] = 2,
}
If the command fails on any instance, an error is returned and the execution
is in an undefined state (the command might not have been executed on all
instances). When using the ``map()`` function to execute any code that might
fail, your code should be wrapped in `pcall()
<https://www.lua.org/manual/5.1/manual.html#pdf-pcall>`_ to avoid this
issue.
.. code-block:: lua
> map('require("kluautil").kr_table_pack(pcall(net.tls, "cert.pem", "key.pem"))')
{
{
true, -- function succeeded
true, -- function return value(s)
['n'] = 2,
},
{
false, -- function failed
'error occurred...', -- the returned error message
['n'] = 2,
},
['n'] = 2,
}
Lua scripts
-----------
As mentioned in section :ref:`config-lua-syntax`, the resolver's configuration
file is a program in the Lua programming language. This allows you to write
dynamic rules and helps you avoid the repetitive templating that is unavoidable
with static configuration. For example, parts of the configuration can depend
on the machine's :func:`hostname`:
.. code-block:: lua
if hostname() == 'hidden' then
net.listen(net.eth0, 5353)
else
net.listen('127.0.0.1')
net.listen(net.eth1.addr[1])
end
Another example shows how to bind to all interfaces, using iteration:
.. code-block:: lua
for name, addr_list in pairs(net.interfaces()) do
net.listen(addr_list)
end
.. tip:: Some users have observed a considerable (close to 100%) performance
   gain in Docker containers when they bound the daemon to a single
   interface/IP address pair. You may expand the previous example by filtering
   the available addresses:
.. code-block:: lua
addrpref = env.EXPECTED_ADDR_PREFIX
for k, v in pairs(addr_list["addr"]) do
if string.sub(v,1,string.len(addrpref)) == addrpref then
net.listen(v)
...
You can also use third-party Lua libraries (available for example through
LuaRocks_), as in this example which downloads the cache from a parent
resolver to avoid a cold-cache start:
.. code-block:: lua
local http = require('socket.http')
local ltn12 = require('ltn12')
local cache_size = 100*MB
local cache_path = '/var/cache/knot-resolver'
cache.open(cache_size, 'lmdb://' .. cache_path)
if cache.count() == 0 then
cache.close()
-- download cache from parent
http.request {
url = 'http://parent/data.mdb',
sink = ltn12.sink.file(io.open(cache_path .. '/data.mdb', 'w'))
}
-- reopen cache with 100M limit
cache.open(cache_size, 'lmdb://' .. cache_path)
end
Helper functions
^^^^^^^^^^^^^^^^
The following built-in functions are useful for scripting:
.. envvar:: env (table)
Retrieve environment variables.
Example:
.. code-block:: lua
env.USER -- equivalent to $USER in shell
.. function:: fromjson(JSONstring)
:return: Lua representation of the data in the given JSON string.
Example:
.. code-block:: lua
> fromjson('{"key1": "value1", "key2": {"subkey1": 1, "subkey2": 2}}')
[key1] => value1
[key2] => {
[subkey1] => 1
[subkey2] => 2
}
.. function:: hostname([fqdn])
:return: Machine hostname.
If called with a parameter, it will set kresd's internal
hostname. If called without a parameter, it will return kresd's
internal hostname, or the system's POSIX hostname (see
gethostname(2)) if kresd's internal hostname is unset.
This also affects ephemeral (self-signed) certificates generated by kresd
for DNS over TLS.
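    Example (the hostnames shown are placeholders):

    .. code-block:: lua

        > hostname()
        machine.example.com
        > hostname('other.example.com')  -- set kresd's internal hostname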
.. function:: package_version()
:return: Current package version as string.
Example:
.. code-block:: lua
> package_version()
2.1.1
.. function:: resolve(name, type[, class = kres.class.IN, options = {}, finish = nil, init = nil])
:param string name: Query name (e.g. 'com.')
:param number type: Query type (e.g. ``kres.type.NS``)
:param number class: Query class *(optional)* (e.g. ``kres.class.IN``)
:param strings options: Resolution options (see :c:type:`kr_qflags`)
:param function finish: Callback to be executed when resolution completes (e.g. `function cb (pkt, req) end`). The callback gets a packet containing the final answer and doesn't have to return anything.
:param function init: Callback to be executed with the :c:type:`kr_request` before resolution starts.
:return: boolean, ``true`` if resolution was started
The function can also be executed with a table of arguments instead. This is
useful if you'd like to skip some arguments, for example:
.. code-block:: lua
resolve {
name = 'example.com',
type = kres.type.AAAA,
init = function (req)
end,
}
Example:
.. code-block:: lua
-- Send query for root DNSKEY, ignore cache
resolve('.', kres.type.DNSKEY, kres.class.IN, 'NO_CACHE')
-- Query for AAAA record
resolve('example.com', kres.type.AAAA, kres.class.IN, 0,
function (pkt, req)
-- Check answer RCODE
if pkt:rcode() == kres.rcode.NOERROR then
-- Print matching records
local records = pkt:section(kres.section.ANSWER)
for i = 1, #records do
local rr = records[i]
if rr.type == kres.type.AAAA then
print ('record:', kres.rr2str(rr))
end
end
else
print ('rcode: ', pkt:rcode())
end
end)
.. function:: tojson(object)
:return: JSON text representation of `object`.
Example:
.. code-block:: lua
> testtable = { key1 = "value1", key2 = { subkey1 = 1, subkey2 = 2 } }
> tojson(testtable)
{"key1":"value1","key2":{"subkey1":1,"subkey2":2}}
.. _async-events:
Asynchronous events
-------------------
The Lua language used in the configuration file allows you to script actions
upon various events, for example publishing statistics each minute. The
following example uses the built-in function :func:`event.recurrent()`, which
calls a user-supplied anonymous function:
.. code-block:: lua
local ffi = require('ffi')
modules.load('stats')
-- log statistics every second
local stat_id = event.recurrent(1 * second, function(evid)
log_info(ffi.C.LOG_GRP_STATISTICS, table_print(stats.list()))
end)
-- stop printing statistics after first minute
event.after(1 * minute, function(evid)
event.cancel(stat_id)
end)
Note that each scheduled event is identified by a number valid for the duration
of the event; you may use it to cancel the event at any time.
To persist state between two invocations of a function, Lua uses a concept
called closures_. In the following example, ``speed_monitor()`` returns a
closure which keeps a persistent variable called ``previous``.
.. code-block:: lua
local ffi = require('ffi')
modules.load('stats')
-- make a closure, encapsulating counter
function speed_monitor()
local previous = stats.list()
-- monitoring function
return function(evid)
local now = stats.list()
local total_increment = now['answer.total'] - previous['answer.total']
local slow_increment = now['answer.slow'] - previous['answer.slow']
if slow_increment / total_increment > 0.05 then
log_warn(ffi.C.LOG_GRP_STATISTICS, 'WARNING! More than 5 %% of queries were slow!')
end
previous = now -- store current value in closure
end
end
-- monitor every minute
local monitor_id = event.recurrent(1 * minute, speed_monitor())
Another type of actionable event is activity on a file descriptor. This allows
you to embed other event loops or monitor open files and fire a callback when
activity is detected, which lets you build persistent services, such as
monitoring probes, that cooperate well with the daemon's internal operations.
See :func:`event.socket()`.
Filesystem watchers are possible with :func:`worker.coroutine()` and cqueues_;
see the cqueues documentation for more information. Here is a simple example:
.. code-block:: lua
local notify = require('cqueues.notify')
local watcher = notify.opendir('/etc')
watcher:add('hosts')
-- Watch changes to /etc/hosts
worker.coroutine(function ()
for flags, name in watcher:changes() do
for flag in notify.flags(flags) do
-- print information about the modified file
print(name, notify[flag])
end
end
end)
.. include:: ../../daemon/bindings/event.rst
.. include:: ../../modules/etcd/README.rst
.. _closures: https://www.lua.org/pil/6.1.html
.. _cqueues: https://25thandclement.com/~william/projects/cqueues.html
.. _LuaRocks: https://luarocks.org/
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include "kresconfig.h"
#include <ucw/lib.h>
#include <sys/socket.h>
#if ENABLE_XDP
#include <libknot/xdp/xdp.h>
#endif
#include "lib/log.h"
#include "lib/utils.h"
#include "daemon/io.h"
#include "daemon/udp_queue.h"
#include "daemon/worker.h"
#include "daemon/defer.h"
#include "daemon/proxyv2.h"
#include "daemon/session2.h"
#define VERBOSE_LOG(session, fmt, ...) do {\
if (kr_log_is_debug(PROTOLAYER, NULL)) {\
const char *sess_dir = (session)->outgoing ? "out" : "in";\
kr_log_debug(PROTOLAYER, "[%08X] (%s) " fmt, \
(session)->log_id, sess_dir, __VA_ARGS__);\
}\
} while (0)
static uint32_t next_log_id = 1;
struct protolayer_globals protolayer_globals[PROTOLAYER_TYPE_COUNT] = {{0}};
static const enum protolayer_type protolayer_grp_udp53[] = {
PROTOLAYER_TYPE_UDP,
PROTOLAYER_TYPE_PROXYV2_DGRAM,
PROTOLAYER_TYPE_DEFER,
PROTOLAYER_TYPE_DNS_DGRAM,
};
static const enum protolayer_type protolayer_grp_tcp53[] = {
PROTOLAYER_TYPE_TCP,
PROTOLAYER_TYPE_PROXYV2_STREAM,
PROTOLAYER_TYPE_DEFER,
PROTOLAYER_TYPE_DNS_MULTI_STREAM,
};
static const enum protolayer_type protolayer_grp_dot[] = {
PROTOLAYER_TYPE_TCP,
PROTOLAYER_TYPE_PROXYV2_STREAM,
PROTOLAYER_TYPE_DEFER,
PROTOLAYER_TYPE_TLS,
PROTOLAYER_TYPE_DNS_MULTI_STREAM,
};
static const enum protolayer_type protolayer_grp_doh[] = {
PROTOLAYER_TYPE_TCP,
PROTOLAYER_TYPE_PROXYV2_STREAM,
PROTOLAYER_TYPE_DEFER,
PROTOLAYER_TYPE_TLS,
PROTOLAYER_TYPE_HTTP,
PROTOLAYER_TYPE_DNS_UNSIZED_STREAM,
};
static const enum protolayer_type protolayer_grp_doq[] = {
// not yet used
PROTOLAYER_TYPE_NULL,
};
struct protolayer_grp {
const enum protolayer_type *layers;
size_t num_layers;
};
#define PROTOLAYER_GRP(p_array) { \
.layers = (p_array), \
.num_layers = sizeof((p_array)) / sizeof((p_array)[0]), \
}
/** Sequences of layers, or groups, mapped by `enum kr_proto`.
*
* Each group represents a sequence of layers in the unwrap direction (wrap
* direction being the opposite). The sequence dictates the order in which
* individual layers are processed. This macro is used to generate global data
* about groups.
*
* To define a new group, add a new entry in the `KR_PROTO_MAP()` macro and
* create a new static `protolayer_grp_*` array above, similarly to the already
 * existing ones. The number of layers is derived from the array's size via
 * `PROTOLAYER_GRP()`, so no explicit terminator is needed. The array name's suffix must
* be the one defined as *Variable name* (2nd parameter) in the
* `KR_PROTO_MAP` macro. */
static const struct protolayer_grp protolayer_grps[KR_PROTO_COUNT] = {
#define XX(cid, vid, name) [KR_PROTO_##cid] = PROTOLAYER_GRP(protolayer_grp_##vid),
KR_PROTO_MAP(XX)
#undef XX
};
const char *protolayer_layer_name(enum protolayer_type p)
{
switch (p) {
case PROTOLAYER_TYPE_NULL:
return "(null)";
#define XX(cid) case PROTOLAYER_TYPE_ ## cid: \
return #cid;
PROTOLAYER_TYPE_MAP(XX)
#undef XX
default:
return "(invalid)";
}
}
const char *protolayer_event_name(enum protolayer_event_type e)
{
switch (e) {
case PROTOLAYER_EVENT_NULL:
return "(null)";
#define XX(cid) case PROTOLAYER_EVENT_ ## cid: \
return #cid;
PROTOLAYER_EVENT_MAP(XX)
#undef XX
default:
return "(invalid)";
}
}
const char *protolayer_payload_name(enum protolayer_payload_type p)
{
switch (p) {
case PROTOLAYER_PAYLOAD_NULL:
return "(null)";
#define XX(cid, name) case PROTOLAYER_PAYLOAD_ ## cid: \
return (name);
PROTOLAYER_PAYLOAD_MAP(XX)
#undef XX
default:
return "(invalid)";
}
}
/* Forward decls. */
static int session2_transport_pushv(struct session2 *s,
struct iovec *iov, int iovcnt,
bool iov_short_lived,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton);
static inline int session2_transport_push(struct session2 *s,
char *buf, size_t buf_len,
bool buf_short_lived,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton);
static int session2_transport_event(struct session2 *s,
enum protolayer_event_type event,
void *baton);
static size_t iovecs_size(const struct iovec *iov, int cnt)
{
size_t sum = 0;
for (int i = 0; i < cnt; i++) {
sum += iov[i].iov_len;
}
return sum;
}
static size_t iovecs_copy(void *dest, const struct iovec *iov, int cnt,
size_t max_len)
{
const size_t pld_size = iovecs_size(iov, cnt);
const size_t copy_size = MIN(max_len, pld_size);
char *cur = dest;
size_t remaining = copy_size;
for (int i = 0; i < cnt && remaining; i++) {
size_t l = iov[i].iov_len;
size_t to_copy = MIN(l, remaining);
memcpy(cur, iov[i].iov_base, to_copy);
		remaining -= to_copy;
		cur += to_copy;
}
kr_assert(remaining == 0 && (cur - (char *)dest) == copy_size);
return copy_size;
}
size_t protolayer_payload_size(const struct protolayer_payload *payload)
{
switch (payload->type) {
case PROTOLAYER_PAYLOAD_BUFFER:
return payload->buffer.len;
case PROTOLAYER_PAYLOAD_IOVEC:
return iovecs_size(payload->iovec.iov, payload->iovec.cnt);
case PROTOLAYER_PAYLOAD_WIRE_BUF:
return wire_buf_data_length(payload->wire_buf);
case PROTOLAYER_PAYLOAD_NULL:
return 0;
default:
kr_assert(false && "Invalid payload type");
return 0;
}
}
size_t protolayer_payload_copy(void *dest,
const struct protolayer_payload *payload,
size_t max_len)
{
const size_t pld_size = protolayer_payload_size(payload);
const size_t copy_size = MIN(max_len, pld_size);
if (payload->type == PROTOLAYER_PAYLOAD_BUFFER) {
memcpy(dest, payload->buffer.buf, copy_size);
return copy_size;
} else if (payload->type == PROTOLAYER_PAYLOAD_IOVEC) {
char *cur = dest;
size_t remaining = copy_size;
for (int i = 0; i < payload->iovec.cnt && remaining; i++) {
size_t l = payload->iovec.iov[i].iov_len;
size_t to_copy = MIN(l, remaining);
memcpy(cur, payload->iovec.iov[i].iov_base, to_copy);
			remaining -= to_copy;
			cur += to_copy;
}
kr_assert(remaining == 0 && (cur - (char *)dest) == copy_size);
return copy_size;
} else if (payload->type == PROTOLAYER_PAYLOAD_WIRE_BUF) {
memcpy(dest, wire_buf_data(payload->wire_buf), copy_size);
return copy_size;
} else if(!payload->type) {
return 0;
} else {
kr_assert(false && "Invalid payload type");
return 0;
}
}
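/* Note: converting a WIRE_BUF payload resets the underlying wire buffer;
 * the returned BUFFER view takes over its data region. */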
struct protolayer_payload protolayer_payload_as_buffer(
const struct protolayer_payload *payload)
{
if (payload->type == PROTOLAYER_PAYLOAD_BUFFER)
return *payload;
if (payload->type == PROTOLAYER_PAYLOAD_WIRE_BUF) {
struct protolayer_payload new_payload = {
.type = PROTOLAYER_PAYLOAD_BUFFER,
.short_lived = payload->short_lived,
.ttl = payload->ttl,
.buffer = {
.buf = wire_buf_data(payload->wire_buf),
.len = wire_buf_data_length(payload->wire_buf)
}
};
wire_buf_reset(payload->wire_buf);
return new_payload;
}
kr_assert(false && "Unsupported payload type.");
return (struct protolayer_payload){
.type = PROTOLAYER_PAYLOAD_NULL
};
}
size_t protolayer_queue_count_payload(const protolayer_iter_ctx_queue_t *queue)
{
if (!queue || queue_len(*queue) == 0)
return 0;
size_t sum = 0;
/* We're only reading from the queue, but we need to discard the
* `const` so that `queue_it_begin()` accepts it. As long as
* `queue_it_` operations do not write into the queue (which they do
* not, checked at the time of writing), we should be safely in the
* defined behavior territory. */
queue_it_t(struct protolayer_iter_ctx *) it =
queue_it_begin(*(protolayer_iter_ctx_queue_t *)queue);
for (; !queue_it_finished(it); queue_it_next(it)) {
struct protolayer_iter_ctx *ctx = queue_it_val(it);
sum += protolayer_payload_size(&ctx->payload);
}
return sum;
}
bool protolayer_queue_has_payload(const protolayer_iter_ctx_queue_t *queue)
{
if (!queue || queue_len(*queue) == 0)
return false;
/* We're only reading from the queue, but we need to discard the
* `const` so that `queue_it_begin()` accepts it. As long as
* `queue_it_` operations do not write into the queue (which they do
* not, checked at the time of writing), we should be safely in the
* defined behavior territory. */
queue_it_t(struct protolayer_iter_ctx *) it =
queue_it_begin(*(protolayer_iter_ctx_queue_t *)queue);
for (; !queue_it_finished(it); queue_it_next(it)) {
struct protolayer_iter_ctx *ctx = queue_it_val(it);
if (protolayer_payload_size(&ctx->payload))
return true;
}
return false;
}
static inline ssize_t session2_get_protocol(
struct session2 *s, enum protolayer_type protocol)
{
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
for (ssize_t i = 0; i < grp->num_layers; i++) {
enum protolayer_type found = grp->layers[i];
if (protocol == found)
return i;
}
return -1;
}
/** Gets layer-specific session data for the layer with the specified index
* from the manager. */
static inline struct protolayer_data *protolayer_sess_data_get(
struct session2 *s, size_t layer_ix)
{
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
if (kr_fails_assert(layer_ix < grp->num_layers))
return NULL;
/* See doc comment of `struct session2::layer_data` */
const ssize_t *offsets = (ssize_t *)s->layer_data;
char *pl_data_beg = &s->layer_data[2 * grp->num_layers * sizeof(*offsets)];
ssize_t offset = offsets[layer_ix];
if (offset < 0) /* No session data for this layer */
return NULL;
return (struct protolayer_data *)(pl_data_beg + offset);
}
void *protolayer_sess_data_get_current(struct protolayer_iter_ctx *ctx)
{
return protolayer_sess_data_get(ctx->session, ctx->layer_ix);
}
void *protolayer_sess_data_get_proto(struct session2 *s, enum protolayer_type protocol) {
ssize_t layer_ix = session2_get_protocol(s, protocol);
if (layer_ix < 0)
return NULL;
return protolayer_sess_data_get(s, layer_ix);
}
/** Gets layer-specific iteration data for the layer with the specified index
* from the context. */
static inline struct protolayer_data *protolayer_iter_data_get(
struct protolayer_iter_ctx *ctx, size_t layer_ix)
{
struct session2 *s = ctx->session;
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
if (kr_fails_assert(layer_ix < grp->num_layers))
return NULL;
/* See doc comment of `struct session2::layer_data` */
const ssize_t *offsets = (ssize_t *)&s->layer_data[grp->num_layers * sizeof(*offsets)];
ssize_t offset = offsets[layer_ix];
if (offset < 0) /* No iteration data for this layer */
return NULL;
return (struct protolayer_data *)(ctx->data + offset);
}
void *protolayer_iter_data_get_current(struct protolayer_iter_ctx *ctx)
{
return protolayer_iter_data_get(ctx, ctx->layer_ix);
}
size_t protolayer_sess_size_est(struct session2 *s)
{
return s->session_size + s->wire_buf.size;
}
size_t protolayer_iter_size_est(struct protolayer_iter_ctx *ctx, bool incl_payload)
{
size_t size = ctx->session->iter_ctx_size;
if (incl_payload)
size += protolayer_payload_size(&ctx->payload);
return size;
}
static inline bool protolayer_iter_ctx_is_last(struct protolayer_iter_ctx *ctx)
{
unsigned int last_ix = (ctx->direction == PROTOLAYER_UNWRAP)
? protolayer_grps[ctx->session->proto].num_layers - 1
: 0;
return ctx->layer_ix == last_ix;
}
static inline void protolayer_iter_ctx_next(struct protolayer_iter_ctx *ctx)
{
if (ctx->direction == PROTOLAYER_UNWRAP)
ctx->layer_ix++;
else
ctx->layer_ix--;
}
static inline const char *layer_name(enum kr_proto grp, ssize_t layer_ix)
{
if (grp >= KR_PROTO_COUNT)
return "(invalid)";
enum protolayer_type p = protolayer_grps[grp].layers[layer_ix];
return protolayer_layer_name(p);
}
static inline const char *layer_name_ctx(struct protolayer_iter_ctx *ctx)
{
return layer_name(ctx->session->proto, ctx->layer_ix);
}
static int protolayer_iter_ctx_finish(struct protolayer_iter_ctx *ctx, int ret)
{
struct session2 *s = ctx->session;
const struct protolayer_globals *globals = &protolayer_globals[s->proto];
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
for (size_t i = 0; i < grp->num_layers; i++) {
struct protolayer_data *d = protolayer_iter_data_get(ctx, i);
if (globals->iter_deinit)
globals->iter_deinit(ctx, d);
}
if (ret) {
VERBOSE_LOG(s, "layer context of group '%s' (on %u: %s) ended with return code %d\n",
kr_proto_name(s->proto),
ctx->layer_ix, layer_name_ctx(ctx), ret);
}
if (ctx->status) {
VERBOSE_LOG(s, "iteration of group '%s' (on %u: %s) ended with status '%s (%d)'\n",
kr_proto_name(s->proto),
ctx->layer_ix, layer_name_ctx(ctx),
kr_strerror(ctx->status), ctx->status);
}
if (ctx->finished_cb)
ctx->finished_cb(ret, s, ctx->comm, ctx->finished_cb_baton);
mm_ctx_delete(&ctx->pool);
free(ctx);
session2_unhandle(s);
return ret;
}
static void protolayer_push_finished(int status, struct session2 *s, const struct comm_info *comm, void *baton)
{
struct protolayer_iter_ctx *ctx = baton;
ctx->status = status;
protolayer_iter_ctx_finish(ctx, PROTOLAYER_RET_NORMAL);
}
/** Pushes the specified protocol layer's payload to the session's transport. */
static int protolayer_push(struct protolayer_iter_ctx *ctx)
{
struct session2 *session = ctx->session;
if (ctx->payload.type == PROTOLAYER_PAYLOAD_WIRE_BUF) {
ctx->payload = protolayer_payload_as_buffer(&ctx->payload);
}
if (kr_log_is_debug(PROTOLAYER, NULL)) {
VERBOSE_LOG(session, "Pushing %s\n",
protolayer_payload_name(ctx->payload.type));
}
if (ctx->payload.type == PROTOLAYER_PAYLOAD_BUFFER) {
session2_transport_push(session,
ctx->payload.buffer.buf, ctx->payload.buffer.len,
ctx->payload.short_lived,
ctx->comm, protolayer_push_finished, ctx);
} else if (ctx->payload.type == PROTOLAYER_PAYLOAD_IOVEC) {
session2_transport_pushv(session,
ctx->payload.iovec.iov, ctx->payload.iovec.cnt,
ctx->payload.short_lived,
ctx->comm, protolayer_push_finished, ctx);
} else {
kr_assert(false && "Invalid payload type");
return kr_error(EINVAL);
}
return PROTOLAYER_RET_ASYNC;
}
static void protolayer_payload_ensure_long_lived(struct protolayer_iter_ctx *ctx)
{
if (!ctx->payload.short_lived)
return;
size_t buf_len = protolayer_payload_size(&ctx->payload);
if (kr_fails_assert(buf_len))
return;
void *buf = mm_alloc(&ctx->pool, buf_len);
kr_require(buf);
protolayer_payload_copy(buf, &ctx->payload, buf_len);
ctx->payload = protolayer_payload_buffer(buf, buf_len, false);
}
/** Processes as many layers as possible synchronously, returning when either
* a layer has gone asynchronous, or when the whole sequence has finished.
*
* May be called multiple times on the same `ctx` to continue processing after
* an asynchronous operation - user code will do this via *layer sequence return
* functions*. */
static int protolayer_step(struct protolayer_iter_ctx *ctx)
{
while (true) {
if (kr_fails_assert(ctx->session->proto < KR_PROTO_COUNT))
return kr_error(EFAULT);
enum protolayer_type protocol = protolayer_grps[ctx->session->proto].layers[ctx->layer_ix];
struct protolayer_globals *globals = &protolayer_globals[protocol];
bool was_async = ctx->async_mode;
ctx->async_mode = false;
/* Basically if we went asynchronous, we want to "resume" from
* underneath this `if` block. */
if (!was_async) {
ctx->status = 0;
ctx->action = PROTOLAYER_ITER_ACTION_NULL;
protolayer_iter_cb cb = (ctx->direction == PROTOLAYER_UNWRAP)
? globals->unwrap : globals->wrap;
if (ctx->session->closing) {
return protolayer_iter_ctx_finish(
ctx, kr_error(ECANCELED));
}
if (cb) {
struct protolayer_data *sess_data = protolayer_sess_data_get(
ctx->session, ctx->layer_ix);
struct protolayer_data *iter_data = protolayer_iter_data_get(
ctx, ctx->layer_ix);
enum protolayer_iter_cb_result result = cb(sess_data, iter_data, ctx);
if (kr_fails_assert(result == PROTOLAYER_ITER_CB_RESULT_MAGIC)) {
/* Callback did not use a *layer
* sequence return function* (see
* glossary). */
return protolayer_iter_ctx_finish(ctx, kr_error(EINVAL));
}
} else {
ctx->action = PROTOLAYER_ITER_ACTION_CONTINUE;
}
if (!ctx->action) {
/* We're going asynchronous - the next step is
* probably going to be from some sort of a
* callback and we will "resume" from underneath
* this `if` block. */
ctx->async_mode = true;
protolayer_payload_ensure_long_lived(ctx);
return PROTOLAYER_RET_ASYNC;
}
}
if (kr_fails_assert(ctx->action)) {
return protolayer_iter_ctx_finish(ctx, kr_error(EINVAL));
}
if (ctx->action == PROTOLAYER_ITER_ACTION_BREAK) {
return protolayer_iter_ctx_finish(
ctx, PROTOLAYER_RET_NORMAL);
}
if (kr_fails_assert(ctx->status == 0)) {
/* Status should be zero without a BREAK. */
return protolayer_iter_ctx_finish(ctx, kr_error(EINVAL));
}
if (ctx->action == PROTOLAYER_ITER_ACTION_CONTINUE) {
if (protolayer_iter_ctx_is_last(ctx)) {
if (ctx->direction == PROTOLAYER_WRAP)
return protolayer_push(ctx);
return protolayer_iter_ctx_finish(
ctx, PROTOLAYER_RET_NORMAL);
}
protolayer_iter_ctx_next(ctx);
continue;
}
/* Should never get here */
kr_assert(false && "Invalid layer callback action");
return protolayer_iter_ctx_finish(ctx, kr_error(EINVAL));
}
}
/** Submits the specified buffer to the sequence of layers represented by the
* specified protolayer manager. The sequence will be processed in the
* specified `direction`, starting by the layer specified by `layer_ix`.
*
* Returns PROTOLAYER_RET_NORMAL when all layers have finished,
* PROTOLAYER_RET_ASYNC when some layers are asynchronous and waiting for
* continuation, or a negative number for errors (kr_error). */
static int session2_submit(
struct session2 *session,
enum protolayer_direction direction, size_t layer_ix,
struct protolayer_payload payload, const struct comm_info *comm,
protolayer_finished_cb cb, void *baton)
{
if (session->closing)
return kr_error(ECANCELED);
if (session->ref_count >= INT_MAX - 1)
return kr_error(ETOOMANYREFS);
if (kr_fails_assert(session->proto < KR_PROTO_COUNT))
return kr_error(EFAULT);
bool had_comm_param = (comm != NULL);
if (!had_comm_param)
comm = &session->comm_storage;
// DEFER: at this point we might start doing nontrivial work,
// but we may not know the client's IP yet.
// Note two cases: incoming session (new request)
// vs. outgoing session (resuming work on some request)
if ((direction == PROTOLAYER_UNWRAP) && (layer_ix == 0))
defer_sample_start(NULL);
struct protolayer_iter_ctx *ctx = malloc(session->iter_ctx_size);
kr_require(ctx);
VERBOSE_LOG(session,
"%s submitted to grp '%s' in %s direction (%zu: %s)\n",
protolayer_payload_name(payload.type),
kr_proto_name(session->proto),
(direction == PROTOLAYER_UNWRAP) ? "unwrap" : "wrap",
layer_ix, layer_name(session->proto, layer_ix));
*ctx = (struct protolayer_iter_ctx) {
.payload = payload,
.direction = direction,
.layer_ix = layer_ix,
.session = session,
.finished_cb = cb,
.finished_cb_baton = baton
};
session->ref_count++;
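	/* Deep-copy the addresses into per-context storage so that `ctx->comm`
	 * stays valid even if the caller's buffers go away while processing
	 * continues asynchronously. */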
if (had_comm_param) {
struct comm_addr_storage *addrst = &ctx->comm_addr_storage;
if (comm->src_addr) {
int len = kr_sockaddr_len(comm->src_addr);
kr_require(len > 0 && len <= sizeof(union kr_sockaddr));
memcpy(&addrst->src_addr, comm->src_addr, len);
ctx->comm_storage.src_addr = &addrst->src_addr.ip;
}
if (comm->comm_addr) {
int len = kr_sockaddr_len(comm->comm_addr);
kr_require(len > 0 && len <= sizeof(union kr_sockaddr));
memcpy(&addrst->comm_addr, comm->comm_addr, len);
ctx->comm_storage.comm_addr = &addrst->comm_addr.ip;
}
if (comm->dst_addr) {
int len = kr_sockaddr_len(comm->dst_addr);
kr_require(len > 0 && len <= sizeof(union kr_sockaddr));
memcpy(&addrst->dst_addr, comm->dst_addr, len);
ctx->comm_storage.dst_addr = &addrst->dst_addr.ip;
}
ctx->comm = &ctx->comm_storage;
} else {
ctx->comm = &session->comm_storage;
}
mm_ctx_mempool(&ctx->pool, CPU_PAGE_SIZE);
const struct protolayer_grp *grp = &protolayer_grps[session->proto];
for (size_t i = 0; i < grp->num_layers; i++) {
struct protolayer_globals *globals = &protolayer_globals[grp->layers[i]];
struct protolayer_data *iter_data = protolayer_iter_data_get(ctx, i);
if (iter_data) {
memset(iter_data, 0, globals->iter_size);
iter_data->session = session;
}
if (globals->iter_init)
globals->iter_init(ctx, iter_data);
}
int ret = protolayer_step(ctx);
if ((direction == PROTOLAYER_UNWRAP) && (layer_ix == 0))
defer_sample_stop(NULL, false);
return ret;
}
static void *get_init_param(enum protolayer_type p,
struct protolayer_data_param *layer_param,
size_t layer_param_count)
{
if (!layer_param || !layer_param_count)
return NULL;
for (size_t i = 0; i < layer_param_count; i++) {
if (layer_param[i].protocol == p)
return layer_param[i].param;
}
return NULL;
}
/** Called by *layer sequence return functions* to proceed with protolayer
 * processing. If the context is in asynchronous mode, processing is resumed
 * by calling `protolayer_step()`; otherwise the synchronous loop inside
 * `protolayer_step()` picks up the chosen action on its own. */
static inline void maybe_async_do_step(struct protolayer_iter_ctx *ctx)
{
if (ctx->async_mode)
protolayer_step(ctx);
}
enum protolayer_iter_cb_result protolayer_continue(struct protolayer_iter_ctx *ctx)
{
ctx->action = PROTOLAYER_ITER_ACTION_CONTINUE;
maybe_async_do_step(ctx);
return PROTOLAYER_ITER_CB_RESULT_MAGIC;
}
enum protolayer_iter_cb_result protolayer_break(struct protolayer_iter_ctx *ctx, int status)
{
ctx->status = status;
ctx->action = PROTOLAYER_ITER_ACTION_BREAK;
maybe_async_do_step(ctx);
return PROTOLAYER_ITER_CB_RESULT_MAGIC;
}
int wire_buf_init(struct wire_buf *wb, size_t initial_size)
{
char *buf = malloc(initial_size);
kr_require(buf);
*wb = (struct wire_buf){
.buf = buf,
.size = initial_size
};
return kr_ok();
}
void wire_buf_deinit(struct wire_buf *wb)
{
free(wb->buf);
}
int wire_buf_reserve(struct wire_buf *wb, size_t size)
{
if (wb->buf && wb->size >= size)
return kr_ok();
char *newbuf = realloc(wb->buf, size);
kr_require(newbuf);
wb->buf = newbuf;
wb->size = size;
return kr_ok();
}
int wire_buf_consume(struct wire_buf *wb, size_t length)
{
size_t ne = wb->end + length;
if (kr_fails_assert(wb->buf && ne <= wb->size))
return kr_error(EINVAL);
wb->end = ne;
return kr_ok();
}
int wire_buf_trim(struct wire_buf *wb, size_t length)
{
size_t ns = wb->start + length;
if (kr_fails_assert(ns <= wb->end))
return kr_error(EINVAL);
wb->start = ns;
return kr_ok();
}
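/* Move the valid region back to the start of the buffer, reclaiming space
 * released by wire_buf_trim(); memmove is used when the regions may overlap. */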
int wire_buf_movestart(struct wire_buf *wb)
{
if (kr_fails_assert(wb->buf))
return kr_error(EINVAL);
if (wb->start == 0)
return kr_ok();
size_t len = wire_buf_data_length(wb);
if (len) {
if (wb->start < len)
memmove(wb->buf, wire_buf_data(wb), len);
else
memcpy(wb->buf, wire_buf_data(wb), len);
}
wb->start = 0;
wb->end = len;
return kr_ok();
}
int wire_buf_reset(struct wire_buf *wb)
{
wb->start = 0;
wb->end = 0;
return kr_ok();
}
struct session2 *session2_new(enum session2_transport_type transport_type,
enum kr_proto proto,
struct protolayer_data_param *layer_param,
size_t layer_param_count,
bool outgoing)
{
kr_require(transport_type && proto);
size_t session_size = sizeof(struct session2);
size_t iter_ctx_size = sizeof(struct protolayer_iter_ctx);
const struct protolayer_grp *grp = &protolayer_grps[proto];
if (kr_fails_assert(grp->num_layers))
return NULL;
size_t wire_buf_length = 0;
ssize_t offsets[2 * grp->num_layers];
session_size += sizeof(offsets);
ssize_t *sess_offsets = offsets;
ssize_t *iter_offsets = &offsets[grp->num_layers];
/* Space for layer-specific data, guaranteeing alignment */
size_t total_sess_data_size = 0;
size_t total_iter_data_size = 0;
for (size_t i = 0; i < grp->num_layers; i++) {
const struct protolayer_globals *g = &protolayer_globals[grp->layers[i]];
sess_offsets[i] = g->sess_size ? total_sess_data_size : -1;
total_sess_data_size += ALIGN_TO(g->sess_size, CPU_STRUCT_ALIGN);
iter_offsets[i] = g->iter_size ? total_iter_data_size : -1;
total_iter_data_size += ALIGN_TO(g->iter_size, CPU_STRUCT_ALIGN);
size_t wire_buf_overhead = (g->wire_buf_overhead_cb)
? g->wire_buf_overhead_cb(outgoing)
: g->wire_buf_overhead;
wire_buf_length += wire_buf_overhead;
}
session_size += total_sess_data_size;
iter_ctx_size += total_iter_data_size;
struct session2 *s = malloc(session_size);
kr_require(s);
*s = (struct session2) {
.transport = {
.type = transport_type,
},
.log_id = next_log_id++,
.outgoing = outgoing,
.tasks = trie_create(NULL),
.proto = proto,
.iter_ctx_size = iter_ctx_size,
.session_size = session_size,
};
memcpy(&s->layer_data, offsets, sizeof(offsets));
queue_init(s->waiting);
int ret = wire_buf_init(&s->wire_buf, wire_buf_length);
kr_require(!ret);
ret = uv_timer_init(uv_default_loop(), &s->timer);
kr_require(!ret);
s->timer.data = s;
s->ref_count++; /* Session owns the timer */
/* Initialize the layer's session data */
for (size_t i = 0; i < grp->num_layers; i++) {
struct protolayer_globals *globals = &protolayer_globals[grp->layers[i]];
struct protolayer_data *sess_data = protolayer_sess_data_get(s, i);
if (sess_data) {
memset(sess_data, 0, globals->sess_size);
sess_data->session = s;
}
void *param = get_init_param(grp->layers[i], layer_param, layer_param_count);
if (globals->sess_init)
globals->sess_init(s, sess_data, param);
}
session2_touch(s);
return s;
}
/** De-allocates the session. Must only be called once the underlying IO handle
* and timer are already closed, otherwise may leak resources. */
static void session2_free(struct session2 *s)
{
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
for (size_t i = 0; i < grp->num_layers; i++) {
struct protolayer_globals *globals = &protolayer_globals[grp->layers[i]];
if (globals->sess_deinit) {
struct protolayer_data *sess_data = protolayer_sess_data_get(s, i);
globals->sess_deinit(s, sess_data);
}
}
wire_buf_deinit(&s->wire_buf);
trie_free(s->tasks);
queue_deinit(s->waiting);
free(s);
}
void session2_unhandle(struct session2 *s)
{
if (kr_fails_assert(s->ref_count > 0)) {
session2_free(s);
return;
}
s->ref_count--;
if (s->ref_count <= 0)
session2_free(s);
}
int session2_start_read(struct session2 *session)
{
if (session->transport.type == SESSION2_TRANSPORT_IO)
return io_start_read(session->transport.io.handle);
/* TODO - probably just some event for this */
kr_assert(false && "Parent start_read unsupported");
return kr_error(EINVAL);
}
int session2_stop_read(struct session2 *session)
{
if (session->transport.type == SESSION2_TRANSPORT_IO)
return io_stop_read(session->transport.io.handle);
/* TODO - probably just some event for this */
kr_assert(false && "Parent stop_read unsupported");
return kr_error(EINVAL);
}
struct sockaddr *session2_get_peer(struct session2 *s)
{
while (s && s->transport.type == SESSION2_TRANSPORT_PARENT)
s = s->transport.parent;
return (s && s->transport.type == SESSION2_TRANSPORT_IO)
? &s->transport.io.peer.ip
: NULL;
}
struct sockaddr *session2_get_sockname(struct session2 *s)
{
while (s && s->transport.type == SESSION2_TRANSPORT_PARENT)
s = s->transport.parent;
return (s && s->transport.type == SESSION2_TRANSPORT_IO)
? &s->transport.io.sockname.ip
: NULL;
}
uv_handle_t *session2_get_handle(struct session2 *s)
{
while (s && s->transport.type == SESSION2_TRANSPORT_PARENT)
s = s->transport.parent;
return (s && s->transport.type == SESSION2_TRANSPORT_IO)
? s->transport.io.handle
: NULL;
}
static void session2_on_timeout(uv_timer_t *timer)
{
struct session2 *s = timer->data;
session2_event(s, s->timer_event, NULL);
}
int session2_timer_start(struct session2 *s, enum protolayer_event_type event, uint64_t timeout, uint64_t repeat)
{
s->timer_event = event;
return uv_timer_start(&s->timer, session2_on_timeout, timeout, repeat);
}
int session2_timer_restart(struct session2 *s)
{
return uv_timer_again(&s->timer);
}
int session2_timer_stop(struct session2 *s)
{
return uv_timer_stop(&s->timer);
}
int session2_tasklist_add(struct session2 *session, struct qr_task *task)
{
trie_t *t = session->tasks;
uint16_t task_msg_id = 0;
const char *key = NULL;
size_t key_len = 0;
if (session->outgoing) {
knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
task_msg_id = knot_wire_get_id(pktbuf->wire);
key = (const char *)&task_msg_id;
key_len = sizeof(task_msg_id);
} else {
key = (const char *)&task;
key_len = sizeof(char *);
}
trie_val_t *v = trie_get_ins(t, key, key_len);
if (kr_fails_assert(v))
return kr_error(ENOMEM);
if (*v == NULL) {
*v = task;
worker_task_ref(task);
} else if (kr_fails_assert(*v == task)) {
return kr_error(EINVAL);
}
return kr_ok();
}
int session2_tasklist_del(struct session2 *session, struct qr_task *task)
{
trie_t *t = session->tasks;
uint16_t task_msg_id = 0;
const char *key = NULL;
size_t key_len = 0;
trie_val_t val;
if (session->outgoing) {
knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
task_msg_id = knot_wire_get_id(pktbuf->wire);
key = (const char *)&task_msg_id;
key_len = sizeof(task_msg_id);
} else {
key = (const char *)&task;
key_len = sizeof(char *);
}
int ret = trie_del(t, key, key_len, &val);
if (ret == KNOT_EOK) {
kr_require(val == task);
worker_task_unref(val);
}
return ret;
}
struct qr_task *session2_tasklist_get_first(struct session2 *session)
{
trie_val_t *val = trie_get_first(session->tasks, NULL, NULL);
return val ? (struct qr_task *) *val : NULL;
}
struct qr_task *session2_tasklist_del_first(struct session2 *session, bool deref)
{
trie_val_t val = NULL;
int res = trie_del_first(session->tasks, NULL, NULL, &val);
if (res != KNOT_EOK) {
val = NULL;
} else if (deref) {
worker_task_unref(val);
}
return (struct qr_task *)val;
}
struct qr_task *session2_tasklist_find_msgid(const struct session2 *session, uint16_t msg_id)
{
if (kr_fails_assert(session->outgoing))
return NULL;
trie_t *t = session->tasks;
struct qr_task *ret = NULL;
trie_val_t *val = trie_get_try(t, (char *)&msg_id, sizeof(msg_id));
if (val) {
ret = *val;
}
return ret;
}
struct qr_task *session2_tasklist_del_msgid(const struct session2 *session, uint16_t msg_id)
{
if (kr_fails_assert(session->outgoing))
return NULL;
trie_t *t = session->tasks;
struct qr_task *ret = NULL;
const char *key = (const char *)&msg_id;
size_t key_len = sizeof(msg_id);
trie_val_t val;
int res = trie_del(t, key, key_len, &val);
if (res == KNOT_EOK) {
if (worker_task_numrefs(val) > 1) {
ret = val;
}
worker_task_unref(val);
}
return ret;
}
void session2_tasklist_finalize(struct session2 *session, int status)
{
if (session2_tasklist_get_len(session) > 0) {
defer_sample_state_t defer_prev_sample_state;
defer_sample_start(&defer_prev_sample_state);
do {
struct qr_task *t = session2_tasklist_del_first(session, false);
kr_require(worker_task_numrefs(t) > 0);
worker_task_finalize(t, status);
worker_task_unref(t);
defer_sample_restart();
} while (session2_tasklist_get_len(session) > 0);
defer_sample_stop(&defer_prev_sample_state, true);
}
}
int session2_tasklist_finalize_expired(struct session2 *session)
{
int ret = 0;
queue_t(struct qr_task *) q;
uint64_t now = kr_now();
trie_t *t = session->tasks;
trie_it_t *it;
queue_init(q);
for (it = trie_it_begin(t); !trie_it_finished(it); trie_it_next(it)) {
trie_val_t *v = trie_it_val(it);
struct qr_task *task = (struct qr_task *)*v;
if ((now - worker_task_creation_time(task)) >= KR_RESOLVE_TIME_LIMIT) {
struct kr_request *req = worker_task_request(task);
if (!kr_fails_assert(req))
kr_query_inform_timeout(req, req->current_query);
queue_push(q, task);
worker_task_ref(task);
}
}
trie_it_free(it);
struct qr_task *task = NULL;
uint16_t msg_id = 0;
char *key = (char *)&task;
int32_t keylen = sizeof(struct qr_task *);
if (session->outgoing) {
key = (char *)&msg_id;
keylen = sizeof(msg_id);
}
if (queue_len(q) > 0) {
defer_sample_state_t defer_prev_sample_state;
defer_sample_start(&defer_prev_sample_state);
do {
task = queue_head(q);
if (session->outgoing) {
knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
msg_id = knot_wire_get_id(pktbuf->wire);
}
int res = trie_del(t, key, keylen, NULL);
if (!worker_task_finished(task)) {
			/* task->pending_count must be zero,
			 * but there can be followers,
			 * so run worker_task_subreq_finalize() to ensure
			 * retrying for all the followers. */
worker_task_subreq_finalize(task);
worker_task_finalize(task, KR_STATE_FAIL);
}
if (res == KNOT_EOK) {
worker_task_unref(task);
}
queue_pop(q);
worker_task_unref(task);
++ret;
defer_sample_restart();
} while (queue_len(q) > 0);
defer_sample_stop(&defer_prev_sample_state, true);
}
queue_deinit(q);
return ret;
}
int session2_waitinglist_push(struct session2 *session, struct qr_task *task)
{
queue_push(session->waiting, task);
worker_task_ref(task);
return kr_ok();
}
struct qr_task *session2_waitinglist_get(const struct session2 *session)
{
return (queue_len(session->waiting) > 0) ? (queue_head(session->waiting)) : NULL;
}
struct qr_task *session2_waitinglist_pop(struct session2 *session, bool deref)
{
struct qr_task *t = session2_waitinglist_get(session);
queue_pop(session->waiting);
if (deref) {
worker_task_unref(t);
}
return t;
}
void session2_waitinglist_retry(struct session2 *session, bool increase_timeout_cnt)
{
if (!session2_waitinglist_is_empty(session)) {
defer_sample_state_t defer_prev_sample_state;
defer_sample_start(&defer_prev_sample_state);
do {
struct qr_task *task = session2_waitinglist_pop(session, false);
if (increase_timeout_cnt) {
worker_task_timeout_inc(task);
}
worker_task_step(task, session2_get_peer(session), NULL);
worker_task_unref(task);
defer_sample_restart();
} while (!session2_waitinglist_is_empty(session));
defer_sample_stop(&defer_prev_sample_state, true);
}
}
void session2_waitinglist_finalize(struct session2 *session, int status)
{
if (!session2_waitinglist_is_empty(session)) {
defer_sample_state_t defer_prev_sample_state;
defer_sample_start(&defer_prev_sample_state);
do {
struct qr_task *t = session2_waitinglist_pop(session, false);
worker_task_finalize(t, status);
worker_task_unref(t);
defer_sample_restart();
} while (!session2_waitinglist_is_empty(session));
defer_sample_stop(&defer_prev_sample_state, true);
}
}
void session2_penalize(struct session2 *session)
{
if (session->was_useful || !session->outgoing)
return;
	/* We want to penalize the IP address if a task is asking a query.
	 * It might not be the right task, but that doesn't matter much
	 * for attributing the useless session to the IP address. */
struct qr_task *t = session2_tasklist_get_first(session);
struct kr_query *qry = NULL;
if (t) {
struct kr_request *req = worker_task_request(t);
qry = array_tail(req->rplan.pending);
}
if (qry) /* We reuse the error for connection, as it's quite similar. */
qry->server_selection.error(qry, worker_task_get_transport(t),
KR_SELECTION_TCP_CONNECT_FAILED);
}
int session2_unwrap(struct session2 *s, struct protolayer_payload payload,
const struct comm_info *comm, protolayer_finished_cb cb,
void *baton)
{
return session2_submit(s, PROTOLAYER_UNWRAP,
0, payload, comm, cb, baton);
}
int session2_unwrap_after(struct session2 *s, enum protolayer_type protocol,
struct protolayer_payload payload,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton)
{
ssize_t layer_ix = session2_get_protocol(s, protocol);
bool ok = layer_ix >= 0 && layer_ix + 1 < protolayer_grps[s->proto].num_layers;
if (kr_fails_assert(ok)) // not found or "last layer"
return kr_error(EINVAL);
return session2_submit(s, PROTOLAYER_UNWRAP,
layer_ix + 1, payload, comm, cb, baton);
}
int session2_wrap(struct session2 *s, struct protolayer_payload payload,
const struct comm_info *comm, protolayer_finished_cb cb,
void *baton)
{
return session2_submit(s, PROTOLAYER_WRAP,
protolayer_grps[s->proto].num_layers - 1,
payload, comm, cb, baton);
}
int session2_wrap_after(struct session2 *s, enum protolayer_type protocol,
struct protolayer_payload payload,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton)
{
ssize_t layer_ix = session2_get_protocol(s, protocol);
if (kr_fails_assert(layer_ix > 0)) // not found or "last layer"
return kr_error(EINVAL);
return session2_submit(s, PROTOLAYER_WRAP, layer_ix - 1,
payload, comm, cb, baton);
}
static void session2_event_wrap(struct session2 *s, enum protolayer_event_type event, void *baton)
{
bool cont;
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
for (ssize_t i = grp->num_layers - 1; i >= 0; i--) {
struct protolayer_globals *globals = &protolayer_globals[grp->layers[i]];
if (globals->event_wrap) {
struct protolayer_data *sess_data = protolayer_sess_data_get(s, i);
cont = globals->event_wrap(event, &baton, s, sess_data);
} else {
cont = true;
}
if (!cont)
return;
}
session2_transport_event(s, event, baton);
}
static void session2_event_unwrap(struct session2 *s, ssize_t start_ix, enum protolayer_event_type event, void *baton)
{
bool cont;
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
for (ssize_t i = start_ix; i < grp->num_layers; i++) {
struct protolayer_globals *globals = &protolayer_globals[grp->layers[i]];
if (globals->event_unwrap) {
struct protolayer_data *sess_data = protolayer_sess_data_get(s, i);
cont = globals->event_unwrap(event, &baton, s, sess_data);
} else {
cont = true;
}
if (!cont)
return;
}
/* Immediately bounce back in the `wrap` direction.
*
* TODO: This might be undesirable for cases with sub-sessions - the
* current idea is for the layers managing sub-sessions to just return
* `PROTOLAYER_EVENT_CONSUME` on `event_unwrap`, but a more "automatic"
* mechanism may be added when this is relevant, to make it less
* error-prone. */
session2_event_wrap(s, event, baton);
}
void session2_event(struct session2 *s, enum protolayer_event_type event, void *baton)
{
/* Events may be sent from inside or outside of already measured code.
* From inside: close by us, statistics, ...
* From outside: timeout, EOF, close by external reasons, ... */
bool defer_accounting_here = false;
if (!defer_sample_is_accounting() && s->stream && !s->outgoing) {
defer_sample_start(NULL);
defer_accounting_here = true;
}
session2_event_unwrap(s, 0, event, baton);
if (defer_accounting_here)
defer_sample_stop(NULL, false);
}
void session2_event_after(struct session2 *s, enum protolayer_type protocol,
enum protolayer_event_type event, void *baton)
{
ssize_t start_ix = session2_get_protocol(s, protocol);
if (kr_fails_assert(start_ix >= 0))
return;
session2_event_unwrap(s, start_ix + 1, event, baton);
}
void session2_init_request(struct session2 *s, struct kr_request *req)
{
const struct protolayer_grp *grp = &protolayer_grps[s->proto];
for (ssize_t i = 0; i < grp->num_layers; i++) {
struct protolayer_globals *globals = &protolayer_globals[grp->layers[i]];
if (globals->request_init) {
struct protolayer_data *sess_data = protolayer_sess_data_get(s, i);
globals->request_init(s, req, sess_data);
}
}
}
struct session2_pushv_ctx {
struct session2 *session;
protolayer_finished_cb cb;
const struct comm_info *comm;
void *baton;
char *async_buf;
};
static void session2_transport_parent_pushv_finished(int status,
struct session2 *session,
const struct comm_info *comm,
void *baton)
{
struct session2_pushv_ctx *ctx = baton;
if (ctx->cb)
ctx->cb(status, ctx->session, comm, ctx->baton);
free(ctx->async_buf);
free(ctx);
}
static void session2_transport_pushv_finished(int status, struct session2_pushv_ctx *ctx)
{
if (ctx->cb)
ctx->cb(status, ctx->session, ctx->comm, ctx->baton);
free(ctx->async_buf);
free(ctx);
}
static void session2_transport_udp_queue_pushv_finished(int status, void *baton)
{
session2_transport_pushv_finished(status, baton);
}
static void session2_transport_udp_pushv_finished(uv_udp_send_t *req, int status)
{
session2_transport_pushv_finished(status, req->data);
free(req);
}
static void session2_transport_stream_pushv_finished(uv_write_t *req, int status)
{
session2_transport_pushv_finished(status, req->data);
free(req);
}
#if ENABLE_XDP
static void xdp_tx_waker(uv_idle_t *handle)
{
xdp_handle_data_t *xhd = handle->data;
int ret = knot_xdp_send_finish(xhd->socket);
if (ret != KNOT_EAGAIN && ret != KNOT_EOK)
kr_log_error(XDP, "check: ret = %d, %s\n", ret, knot_strerror(ret));
/* Apparently some drivers need many explicit wake-up calls
* even if we push no additional packets (in case they accumulated a lot) */
if (ret != KNOT_EAGAIN)
uv_idle_stop(handle);
knot_xdp_send_prepare(xhd->socket);
/* LATER(opt.): it _might_ be better for performance to do these two steps
* at different points in time */
while (queue_len(xhd->tx_waker_queue)) {
struct session2_pushv_ctx *ctx = queue_head(xhd->tx_waker_queue);
if (ctx->cb)
ctx->cb(kr_ok(), ctx->session, ctx->comm, ctx->baton);
free(ctx);
queue_pop(xhd->tx_waker_queue);
}
}
#endif
static void session2_transport_pushv_ensure_long_lived(
struct iovec **iov, int *iovcnt, bool iov_short_lived,
struct iovec *out_iovecmem, struct session2_pushv_ctx *ctx)
{
if (!iov_short_lived)
return;
size_t iovsize = iovecs_size(*iov, *iovcnt);
if (kr_fails_assert(iovsize))
return;
void *buf = malloc(iovsize);
kr_require(buf);
iovecs_copy(buf, *iov, *iovcnt, iovsize);
ctx->async_buf = buf;
out_iovecmem->iov_base = buf;
out_iovecmem->iov_len = iovsize;
*iov = out_iovecmem;
*iovcnt = 1;
}
/// Count the total size of an iovec[] in bytes.
static inline size_t iovec_sum(const struct iovec iov[], const int iovcnt)
{
size_t result = 0;
for (int i = 0; i < iovcnt; ++i)
result += iov[i].iov_len;
return result;
}
static int session2_transport_pushv(struct session2 *s,
struct iovec *iov, int iovcnt,
bool iov_short_lived,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton)
{
struct iovec iovecmem;
if (kr_fails_assert(s))
return kr_error(EINVAL);
struct session2_pushv_ctx *ctx = malloc(sizeof(*ctx));
kr_require(ctx);
*ctx = (struct session2_pushv_ctx){
.session = s,
.cb = cb,
.baton = baton,
.comm = comm
};
int err_ret = kr_ok();
switch (s->transport.type) {
case SESSION2_TRANSPORT_IO:;
uv_handle_t *handle = s->transport.io.handle;
if (kr_fails_assert(handle)) {
err_ret = kr_error(EINVAL);
goto exit_err;
}
if (handle->type == UV_UDP) {
if (ENABLE_SENDMMSG && !s->outgoing) {
int fd;
int ret = uv_fileno(handle, &fd);
if (kr_fails_assert(!ret)) {
err_ret = kr_error(EIO);
goto exit_err;
}
/* TODO: support multiple iovecs properly? */
if (kr_fails_assert(iovcnt == 1)) {
err_ret = kr_error(EINVAL);
goto exit_err;
}
session2_transport_pushv_ensure_long_lived(
&iov, &iovcnt, iov_short_lived,
&iovecmem, ctx);
udp_queue_push(fd, comm->comm_addr, iov->iov_base, iov->iov_len,
session2_transport_udp_queue_pushv_finished,
ctx);
return kr_ok();
} else {
int ret = uv_udp_try_send((uv_udp_t*)handle, (uv_buf_t *)iov, iovcnt,
the_network->enable_connect_udp ? NULL : comm->comm_addr);
if (ret > 0) // equals buffer size, only confuses us
ret = 0;
if (ret == UV_EAGAIN) {
ret = kr_error(ENOBUFS);
session2_event(s, PROTOLAYER_EVENT_OS_BUFFER_FULL, NULL);
}
if (false && ret == UV_EAGAIN) { // XXX: see uv_try_write() below
uv_udp_send_t *req = malloc(sizeof(*req));
req->data = ctx;
session2_transport_pushv_ensure_long_lived(
&iov, &iovcnt, iov_short_lived,
&iovecmem, ctx);
ret = uv_udp_send(req, (uv_udp_t *)handle,
(uv_buf_t *)iov, iovcnt, comm->comm_addr,
session2_transport_udp_pushv_finished);
if (ret)
session2_transport_udp_pushv_finished(req, ret);
return ret;
}
session2_transport_pushv_finished(ret, ctx);
return ret;
}
} else if (handle->type == UV_TCP) {
int ret = uv_try_write((uv_stream_t *)handle, (uv_buf_t *)iov, iovcnt);
// XXX: queueing disabled for now if the OS can't accept the data.
// Typically that happens when OS buffers are full.
// We were missing any handling of partial write success, too.
if (ret == UV_EAGAIN || (ret >= 0 && ret != iovec_sum(iov, iovcnt))) {
ret = kr_error(ENOBUFS);
session2_event(s, PROTOLAYER_EVENT_OS_BUFFER_FULL, NULL);
}
else if (ret > 0) // iovec_sum was checked, let's not get confused anymore
ret = 0;
if (false && ret == UV_EAGAIN) {
uv_write_t *req = malloc(sizeof(*req));
req->data = ctx;
session2_transport_pushv_ensure_long_lived(
&iov, &iovcnt, iov_short_lived,
&iovecmem, ctx);
ret = uv_write(req, (uv_stream_t *)handle, (uv_buf_t *)iov, iovcnt,
session2_transport_stream_pushv_finished);
if (ret)
session2_transport_stream_pushv_finished(req, ret);
return ret;
}
session2_transport_pushv_finished(ret, ctx);
return ret;
#if ENABLE_XDP
} else if (handle->type == UV_POLL) {
xdp_handle_data_t *xhd = handle->data;
if (kr_fails_assert(xhd && xhd->socket)) {
err_ret = kr_error(EIO);
goto exit_err;
}
/* TODO: support multiple iovecs properly? */
if (kr_fails_assert(iovcnt == 1)) {
err_ret = kr_error(EINVAL);
goto exit_err;
}
session2_transport_pushv_ensure_long_lived(
&iov, &iovcnt, iov_short_lived,
&iovecmem, ctx);
knot_xdp_msg_t msg;
/* We don't have a nice way of preserving the _msg_t from frame allocation,
* so we manually redo all other parts of knot_xdp_send_alloc() */
memset(&msg, 0, sizeof(msg));
bool ipv6 = comm->comm_addr->sa_family == AF_INET6;
msg.flags = ipv6 ? KNOT_XDP_MSG_IPV6 : 0;
memcpy(msg.eth_from, comm->eth_from, sizeof(comm->eth_from));
memcpy(msg.eth_to, comm->eth_to, sizeof(comm->eth_to));
const struct sockaddr *ip_from = comm->dst_addr;
const struct sockaddr *ip_to = comm->comm_addr;
memcpy(&msg.ip_from, ip_from, kr_sockaddr_len(ip_from));
memcpy(&msg.ip_to, ip_to, kr_sockaddr_len(ip_to));
msg.payload = *iov;
uint32_t sent;
int ret = knot_xdp_send(xhd->socket, &msg, 1, &sent);
queue_push(xhd->tx_waker_queue, ctx);
uv_idle_start(&xhd->tx_waker, xdp_tx_waker);
kr_log_debug(XDP, "pushed a packet, ret = %d\n", ret);
return kr_ok();
#endif
} else {
kr_assert(false && "Unsupported handle");
err_ret = kr_error(EINVAL);
goto exit_err;
}
case SESSION2_TRANSPORT_PARENT:;
struct session2 *parent = s->transport.parent;
if (kr_fails_assert(parent)) {
err_ret = kr_error(EINVAL);
goto exit_err;
}
int ret = session2_wrap(parent,
protolayer_payload_iovec(iov, iovcnt, iov_short_lived),
comm, session2_transport_parent_pushv_finished,
ctx);
return (ret < 0) ? ret : kr_ok();
default:
kr_assert(false && "Invalid transport");
err_ret = kr_error(EINVAL);
goto exit_err;
}
exit_err:
session2_transport_pushv_finished(err_ret, ctx);
return err_ret;
}
struct push_ctx {
struct iovec iov;
protolayer_finished_cb cb;
void *baton;
};
static void session2_transport_single_push_finished(int status,
struct session2 *s,
const struct comm_info *comm,
void *baton)
{
struct push_ctx *ctx = baton;
if (ctx->cb)
ctx->cb(status, s, comm, ctx->baton);
free(ctx);
}
static inline int session2_transport_push(struct session2 *s,
char *buf, size_t buf_len,
bool buf_short_lived,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton)
{
struct push_ctx *ctx = malloc(sizeof(*ctx));
kr_require(ctx);
*ctx = (struct push_ctx){
.iov = {
.iov_base = buf,
.iov_len = buf_len
},
.cb = cb,
.baton = baton
};
return session2_transport_pushv(s, &ctx->iov, 1, buf_short_lived, comm,
session2_transport_single_push_finished, ctx);
}
static void on_session2_handle_close(uv_handle_t *handle)
{
struct session2 *session = handle->data;
kr_require(session->transport.type == SESSION2_TRANSPORT_IO &&
session->transport.io.handle == handle);
io_free(handle);
}
static void on_session2_timer_close(uv_handle_t *handle)
{
session2_unhandle(handle->data);
}
static int session2_handle_close(struct session2 *s)
{
if (kr_fails_assert(s->transport.type == SESSION2_TRANSPORT_IO))
return kr_error(EINVAL);
uv_handle_t *handle = s->transport.io.handle;
if (!handle->loop) {
/* This happens when kresd is stopping and the libUV loop has
* been ended. We do not `uv_close` the handles, we just free
* up the memory. */
session2_unhandle(s); /* For timer handle */
io_free(handle); /* This will unhandle the transport handle */
return kr_ok();
}
io_stop_read(handle);
uv_close((uv_handle_t *)&s->timer, on_session2_timer_close);
uv_close(handle, on_session2_handle_close);
return kr_ok();
}
static int session2_transport_event(struct session2 *s,
enum protolayer_event_type event,
void *baton)
{
if (s->closing)
return kr_ok();
if (event == PROTOLAYER_EVENT_EOF) {
// no layer wanted to retain TCP half-closed state
session2_force_close(s);
return kr_ok();
}
bool is_close_event = (event == PROTOLAYER_EVENT_CLOSE ||
event == PROTOLAYER_EVENT_FORCE_CLOSE);
if (is_close_event) {
kr_require(session2_is_empty(s));
session2_timer_stop(s);
s->closing = true;
}
switch (s->transport.type) {
case SESSION2_TRANSPORT_IO:;
if (kr_fails_assert(s->transport.io.handle)) {
return kr_error(EINVAL);
}
if (is_close_event)
return session2_handle_close(s);
return kr_ok();
case SESSION2_TRANSPORT_PARENT:;
session2_event_wrap(s, event, baton);
return kr_ok();
default:
kr_assert(false && "Invalid transport");
return kr_error(EINVAL);
}
}
void session2_kill_ioreq(struct session2 *session, struct qr_task *task)
{
if (!session || session->closing)
return;
if (kr_fails_assert(session->outgoing
&& session->transport.type == SESSION2_TRANSPORT_IO
&& session->transport.io.handle))
return;
session2_tasklist_del(session, task);
if (session->transport.io.handle->type == UV_UDP)
session2_close(session);
}
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
/* High-level explanation of layered protocols: ./layered-protocols.rst */
/* HINT: If you are looking to implement support for a new transport protocol,
* start with the doc comment of the `PROTOLAYER_TYPE_MAP` macro and
* continue from there. */
/* GLOSSARY:
*
* Event:
* - An Event may be processed by the protocol layer sequence much like a
* Payload, but with a special callback. Events may be used to notify layers
* that e.g. a connection has been established; a timeout has occurred; a
* malformed packet has been received, etc. Events are generally not sent
* through the transport - they may, however, trigger a new payload to be
* sent, e.g. a HTTP error status response.
*
* Iteration:
* - The processing of Payload data or an event using a particular sequence
* of Protocol layers, either in Wrap or Unwrap direction. For payload
* processing, it is also the lifetime of `struct protolayer_iter_ctx` and
* layer-specific data contained therein.
*
* Layer sequence return function:
* - One of `protolayer_break()`, `protolayer_continue()`, or
* `protolayer_async()` - a function that a protolayer's `_wrap` or `_unwrap`
* callback should call to get its return value. They may either be called
* synchronously directly in the callback to end/pause the processing, or, if
* the processing went asynchronous, called to resume the iteration of layers.
*
* Payload:
* - Data processed by protocol layers in a particular sequence. In the wrap
* direction, this data generally starts as a DNS packet, which is then
* wrapped in protocol ceremony data by each layer. In the unwrap direction,
* the opposite takes place - ceremony data is removed until a raw DNS packet
* is retrieved.
*
* Protocol layer:
* - Not to be confused with `struct kr_layer_api`. An implementation of a
* particular protocol. A protocol layer transforms payloads to conform to a
* particular protocol, e.g. UDP, TCP, TLS, HTTP, QUIC, etc. While
* transforming a payload, a layer may also modify metadata - e.g. the UDP and
* TCP layers in the Unwrap direction implement the PROXYv2 protocol, using
* which they retrieve the IP address of the actual originating client and
* store it in the appropriate struct.
*
* Protolayer:
* - Short for 'protocol layer'.
*
* Unwrap:
* - The direction of data transformation, which starts with the transport
* (e.g. bytes that came from the network) and ends with an internal subsystem
* (e.g. DNS query resolution).
*
* Wrap:
* - The direction of data transformation, which starts with an internal
* subsystem (e.g. an answer to a resolved DNS query) and ends with the
* transport (e.g. bytes that are going to be sent to the client). */
#pragma once
#include <stdalign.h>
#include <stdint.h>
#include <stdlib.h>
#include <uv.h>
#include "contrib/mempattern.h"
#include "lib/generic/queue.h"
#include "lib/generic/trie.h"
#include "lib/proto.h"
#include "lib/utils.h"
/* Forward declarations */
struct session2;
struct protolayer_iter_ctx;
/** Type of MAC addresses. */
typedef uint8_t ethaddr_t[6];
/** Information about the transport - addresses and proxy. */
struct comm_info {
	/** The original address the data came from. May be that of a proxied
	 * client, if they came through a proxy. May be `NULL` if
	 * the communication did not come from the network. */
	const struct sockaddr *src_addr;
	/** The actual address the resolver is communicating with. May be
	 * the address of a proxy if the communication came through one,
	 * otherwise it will be the same as `src_addr`. May be `NULL` if
	 * the communication did not come from the network. */
	const struct sockaddr *comm_addr;
	/** The original destination address. May be the resolver's address, or
	 * the address of a proxy if the communication came through one. May be
	 * `NULL` if the communication did not come from the network. */
	const struct sockaddr *dst_addr;
/** Data parsed from a PROXY header. May be `NULL` if the communication
* did not come through a proxy, or if the PROXYv2 protocol was not
* used. */
const struct proxy_result *proxy;
	/** Pointer to protolayer-specific data, e.g. a key to decide which
	 * sub-session to use. */
void *target;
/* XDP data */
ethaddr_t eth_from;
ethaddr_t eth_to;
bool xdp:1;
};
/** A simple struct able to hold three IPv6 or IPv4 addresses, providing
 * storage for the source, communication, and destination addresses. */
struct comm_addr_storage {
union kr_sockaddr src_addr;
union kr_sockaddr comm_addr;
union kr_sockaddr dst_addr;
};
/** A buffer control struct, with indices marking a chunk containing received
 * but as yet unprocessed data - the data in this chunk is called "valid
* data". The struct may be manipulated using `wire_buf_` functions, which
* contain bounds checks to ensure correct behaviour.
*
* The struct may be used to retrieve data piecewise (e.g. from a stream-based
* transport like TCP) by writing data to the buffer's free space, then
* "consuming" that space with `wire_buf_consume`. It can also be handy for
* processing message headers, then trimming the beginning of the buffer (using
* `wire_buf_trim`) so that the next part of the data may be processed by
* another part of a pipeline.
*
* May be initialized in two possible ways:
* - via `wire_buf_init`
* - to zero, then reserved via `wire_buf_reserve`. */
struct wire_buf {
char *buf; /**< Buffer memory. */
size_t size; /**< Current size of the buffer memory. */
size_t start; /**< Index at which the valid data of the buffer starts (inclusive). */
size_t end; /**< Index at which the valid data of the buffer ends (exclusive). */
};
/** Initializes the wire buffer with the specified `initial_size` and allocates
* the underlying memory. */
int wire_buf_init(struct wire_buf *wb, size_t initial_size);
/** De-allocates the wire buffer's underlying memory (the struct itself is left
* intact). */
void wire_buf_deinit(struct wire_buf *wb);
/** Ensures that the wire buffer's size is at least `size`. The memory at `wb`
* must be initialized, either to zero or via `wire_buf_init`. */
int wire_buf_reserve(struct wire_buf *wb, size_t size);
/** Adds `length` to the end index of the valid data, marking `length` more
* bytes as valid.
*
* Returns 0 on success.
* Assert-fails and/or returns `kr_error(EINVAL)` if the end index would exceed
* the buffer size. */
int wire_buf_consume(struct wire_buf *wb, size_t length);
/** Adds `length` to the start index of the valid data, marking `length` less
* bytes as valid.
*
* Returns 0 on success.
* Assert-fails and/or returns `kr_error(EINVAL)` if the start index would
* exceed the end index. */
int wire_buf_trim(struct wire_buf *wb, size_t length);
/** Moves the valid bytes of the buffer to the buffer's beginning. */
int wire_buf_movestart(struct wire_buf *wb);
/** Marks the wire buffer as empty. */
int wire_buf_reset(struct wire_buf *wb);
/** Gets a pointer to the data marked as valid in the wire buffer. */
static inline void *wire_buf_data(const struct wire_buf *wb)
{
return &wb->buf[wb->start];
}
/** Gets the length of the data marked as valid in the wire buffer. */
static inline size_t wire_buf_data_length(const struct wire_buf *wb)
{
return wb->end - wb->start;
}
/** Gets a pointer to the free space after the valid data of the wire buffer. */
static inline void *wire_buf_free_space(const struct wire_buf *wb)
{
return (wb->buf) ? &wb->buf[wb->end] : NULL;
}
/** Gets the length of the free space after the valid data of the wire buffer. */
static inline size_t wire_buf_free_space_length(const struct wire_buf *wb)
{
if (kr_fails_assert(wb->end <= wb->size))
return 0;
return (wb->buf) ? wb->size - wb->end : 0;
}
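/* A minimal usage sketch for `struct wire_buf` (illustrative only, not part of
 * the API): reading length-prefixed DNS messages from a stream piecewise.
 * `read_some()` and `process()` are hypothetical stand-ins for an input source
 * and a message consumer; error handling is omitted for brevity.
 *
 *     struct wire_buf wb = {0};
 *     wire_buf_reserve(&wb, KNOT_WIRE_MAX_PKTSIZE + 2);
 *     for (;;) {
 *             ssize_t n = read_some(wire_buf_free_space(&wb),
 *                             wire_buf_free_space_length(&wb));
 *             if (n <= 0)
 *                     break;
 *             wire_buf_consume(&wb, n); // mark received bytes as valid
 *             while (wire_buf_data_length(&wb) >= 2) {
 *                     uint16_t len = knot_wire_read_u16(wire_buf_data(&wb));
 *                     if (wire_buf_data_length(&wb) < 2 + (size_t)len)
 *                             break; // wait for the rest of the message
 *                     process((char *)wire_buf_data(&wb) + 2, len);
 *                     wire_buf_trim(&wb, 2 + len); // drop the parsed message
 *             }
 *             wire_buf_movestart(&wb); // reclaim space at the front
 *     }
 *     wire_buf_deinit(&wb);
 */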
/** Protocol layer types map - an enumeration of individual protocol layer
* implementations
*
* This macro is used to generate `enum protolayer_type` as well as other
* additional data on protocols, e.g. name string constants.
*
* To define a new protocol, add a new identifier to this macro, and, within
* some logical compilation unit (e.g. `daemon/worker.c` for DNS layers),
* initialize the protocol's `protolayer_globals[]`, ideally in a function
* called at the start of the program (e.g. `worker_init()`). See the docs of
* `struct protolayer_globals` for details on what data this structure should
* contain.
*
* To use protocols within sessions, protocol layer groups also need to be
* defined, to indicate the order in which individual protocols are to be
* processed. See `KR_PROTO_MAP` below for more details. */
#define PROTOLAYER_TYPE_MAP(XX) \
/* General transport protocols */\
XX(UDP)\
XX(TCP)\
XX(TLS)\
XX(HTTP)\
\
/* PROXYv2 */\
XX(PROXYV2_DGRAM)\
XX(PROXYV2_STREAM)\
\
/* DNS (`worker`) */\
XX(DNS_DGRAM) /**< Packets WITHOUT prepended size, one per (un)wrap,
* limited to UDP sizes, multiple sources (single
* session for multiple clients). */\
XX(DNS_UNSIZED_STREAM) /**< Singular packet WITHOUT prepended size, one
* per (un)wrap, no UDP limits, single source. */\
XX(DNS_MULTI_STREAM) /**< Multiple packets WITH prepended sizes in a
* stream (may span multiple (un)wraps). */\
XX(DNS_SINGLE_STREAM) /**< Singular packet WITH prepended size in a
* stream (may span multiple (un)wraps). */\
/* Prioritization of requests */\
XX(DEFER)
/** The identifiers of protocol layer types. */
enum protolayer_type {
PROTOLAYER_TYPE_NULL = 0,
#define XX(cid) PROTOLAYER_TYPE_ ## cid,
PROTOLAYER_TYPE_MAP(XX)
#undef XX
PROTOLAYER_TYPE_COUNT /* must be the last! */
};
/** Gets the constant string name of the specified protocol. */
const char *protolayer_layer_name(enum protolayer_type p);
/** Flow control indicators for protocol layer `wrap` and `unwrap` callbacks.
 * Use via the `protolayer_continue` and `protolayer_break` functions. */
enum protolayer_iter_action {
PROTOLAYER_ITER_ACTION_NULL = 0,
PROTOLAYER_ITER_ACTION_CONTINUE,
PROTOLAYER_ITER_ACTION_BREAK,
};
/** Direction of layer sequence processing. */
enum protolayer_direction {
/** Processes buffers in order of layers as defined in the layer group.
* In this direction, protocol ceremony data should be removed from the
* buffer, parsing additional data provided by the protocol. */
PROTOLAYER_UNWRAP,
/** Processes buffers in reverse order of layers as defined in the
* layer group. In this direction, protocol ceremony data should be
* added. */
PROTOLAYER_WRAP,
};
/** Returned by a successful call to `session2_wrap()` or `session2_unwrap()`
* functions. */
enum protolayer_ret {
/** Returned when a protolayer context iteration has finished
* processing, i.e. with `protolayer_break()`. */
PROTOLAYER_RET_NORMAL = 0,
/** Returned when a protolayer context iteration is waiting for an
* asynchronous callback to a continuation function. This will never be
* passed to `protolayer_finished_cb`, only returned by
* `session2_unwrap` or `session2_wrap`. */
PROTOLAYER_RET_ASYNC,
};
/** Called when a payload iteration (started by `session2_unwrap` or
* `session2_wrap`) has ended - i.e. the input buffer will not be processed any
* further.
*
* `status` may be one of `enum protolayer_ret` or a negative number indicating
* an error.
 * `comm` points to the communication information passed to the
 * `session2_(un)wrap` function.
* `baton` is the `baton` parameter passed to the `session2_(un)wrap` function. */
typedef void (*protolayer_finished_cb)(int status, struct session2 *session,
const struct comm_info *comm, void *baton);
/** Protocol layer event type map
*
* This macro is used to generate `enum protolayer_event_type` as well as the
* relevant name string constants for each event type.
*
* Event types are used to distinguish different events that can be passed to
* sessions using `session2_event()`. */
#define PROTOLAYER_EVENT_MAP(XX) \
/** Closes the session gracefully - i.e. layers add their standard
* disconnection ceremony (e.g. `gnutls_bye()`). */\
XX(CLOSE) \
/** Closes the session forcefully - i.e. layers SHOULD NOT add any
* disconnection ceremony, if avoidable. */\
XX(FORCE_CLOSE) \
/** Signal that a connection could not be established due to a timeout. */\
XX(CONNECT_TIMEOUT) \
/** Signal that a general application-defined timeout has occurred. */\
XX(GENERAL_TIMEOUT) \
/** Signal that a connection has been established. */\
XX(CONNECT) \
	/** Signal that a connection could not be established. */\
XX(CONNECT_FAIL) \
/** Signal that a malformed request has been received. */\
XX(MALFORMED) \
/** Signal that a connection has ended. */\
XX(DISCONNECT) \
/** Signal EOF from peer (e.g. half-closed TCP connection). */\
XX(EOF) \
/** Failed task send - update stats. */\
XX(STATS_SEND_ERR) \
/** Outgoing query submission - update stats. */\
XX(STATS_QRY_OUT) \
/** OS buffers are full, so not sending any more data. */\
XX(OS_BUFFER_FULL) \
//
/** Event type, to be interpreted by a layer. */
enum protolayer_event_type {
PROTOLAYER_EVENT_NULL = 0,
#define XX(cid) PROTOLAYER_EVENT_ ## cid,
PROTOLAYER_EVENT_MAP(XX)
#undef XX
PROTOLAYER_EVENT_COUNT
};
/** Gets the constant string name of the specified event. */
const char *protolayer_event_name(enum protolayer_event_type e);
/** Payload types.
*
* Parameters are:
* 1. Constant name
* 2. Human-readable name for logging */
#define PROTOLAYER_PAYLOAD_MAP(XX) \
XX(BUFFER, "Buffer") \
XX(IOVEC, "IOVec") \
XX(WIRE_BUF, "Wire buffer")
/** Determines which union member of `struct protolayer_payload` is currently
* valid. */
enum protolayer_payload_type {
PROTOLAYER_PAYLOAD_NULL = 0,
#define XX(cid, name) PROTOLAYER_PAYLOAD_##cid,
PROTOLAYER_PAYLOAD_MAP(XX)
#undef XX
PROTOLAYER_PAYLOAD_COUNT
};
/** Gets the constant string name of the specified payload type. */
const char *protolayer_payload_name(enum protolayer_payload_type p);
/** Data processed by the sequence of layers. All pointed-to memory is always
* owned by its creator. It is also the layer (group) implementor's
* responsibility to keep data compatible in between layers. No payload memory
* is ever (de-)allocated by the protolayer manager! */
struct protolayer_payload {
enum protolayer_payload_type type;
/** Time-to-live hint (e.g. for HTTP Cache-Control) */
unsigned int ttl;
/** If `true`, signifies that the memory this payload points to may
* become invalid when we return from one of the functions in the
* current stack. That is fine as long as all the protocol layer
* processing for this payload takes place in a single `session2_wrap()`
* or `session2_unwrap()` call, but may become a problem, when a layer
* goes asynchronous (via `protolayer_async()`).
*
* Setting this to `true` will ensure that the payload will get copied
* into a separate memory buffer if and only if a layer goes
* asynchronous. It makes sure that if all processing for the payload is
* synchronous, no copies or reallocations for the payload are done. */
bool short_lived;
union {
/** Only valid if `type` is `_BUFFER`. */
struct {
void *buf;
size_t len;
} buffer;
/** Only valid if `type` is `_IOVEC`. */
struct {
struct iovec *iov;
int cnt;
} iovec;
/** Only valid if `type` is `_WIRE_BUF`. */
struct wire_buf *wire_buf;
};
};
/** Context for protocol layer iterations, containing payload data,
* layer-specific data, and internal information for the protocol layer
* manager. */
struct protolayer_iter_ctx {
/* read-write for layers: */
/** The payload */
struct protolayer_payload payload;
/** Pointer to communication information. For TCP, this will generally
* point to the storage in the session. For UDP, this will generally
* point to the storage in this context. */
struct comm_info *comm;
/** Communication information storage. This will generally be set by one
* of the first layers in the sequence, if used, e.g. UDP PROXYv2. */
struct comm_info comm_storage;
struct comm_addr_storage comm_addr_storage;
/** Per-iter memory pool. Has no `free` procedure, gets freed as a whole
* when the context is being destroyed. Initialized and destroyed
* automatically - layers may use it to allocate memory. */
knot_mm_t pool;
/* callback for when the layer iteration has ended - read-only for layers: */
protolayer_finished_cb finished_cb;
void *finished_cb_baton;
/* internal information for the manager - should only be used by the protolayer
* system, never by layers: */
enum protolayer_direction direction;
/** If `true`, the processing of the layer sequence has been paused and
* is waiting to be resumed (`protolayer_continue()`) or cancelled
* (`protolayer_break()`). */
bool async_mode;
/** The index of the layer that is currently being (or has just been)
* processed. */
unsigned int layer_ix;
struct session2 *session;
/** Status passed to the finish callback. */
int status;
enum protolayer_iter_action action;
/** Contains a sequence of variably-sized CPU-aligned layer-specific
* structs. See `struct session2::layer_data` for details. */
alignas(CPU_STRUCT_ALIGN) char data[];
};
/** Gets the total size of the data in the specified payload. */
size_t protolayer_payload_size(const struct protolayer_payload *payload);
/** Copies the specified payload to `dest`. Only `max_len` or the size of the
* payload is written, whichever is less.
*
* Returns the actual length of copied data. */
size_t protolayer_payload_copy(void *dest,
const struct protolayer_payload *payload,
size_t max_len);
/** Convenience function to get a buffer-type payload. */
static inline struct protolayer_payload protolayer_payload_buffer(
void *buf, size_t len, bool short_lived)
{
return (struct protolayer_payload){
.type = PROTOLAYER_PAYLOAD_BUFFER,
.short_lived = short_lived,
.buffer = {
.buf = buf,
.len = len
}
};
}
/** Convenience function to get an iovec-type payload. */
static inline struct protolayer_payload protolayer_payload_iovec(
struct iovec *iov, int iovcnt, bool short_lived)
{
return (struct protolayer_payload){
.type = PROTOLAYER_PAYLOAD_IOVEC,
.short_lived = short_lived,
.iovec = {
.iov = iov,
.cnt = iovcnt
}
};
}
/** Convenience function to get a wire-buf-type payload. */
static inline struct protolayer_payload protolayer_payload_wire_buf(
struct wire_buf *wire_buf, bool short_lived)
{
return (struct protolayer_payload){
.type = PROTOLAYER_PAYLOAD_WIRE_BUF,
.short_lived = short_lived,
.wire_buf = wire_buf
};
}
/** Convenience function to represent the specified payload as a buffer-type.
* Supports only `_BUFFER` and `_WIRE_BUF` on the input, otherwise returns
* `_NULL` type or aborts on assertion if allowed.
*
* If the input payload is `_WIRE_BUF`, the pointed-to wire buffer is reset to
* indicate that all of its contents have been used up, and the buffer is ready
* to be reused. */
struct protolayer_payload protolayer_payload_as_buffer(
const struct protolayer_payload *payload);
/** A predefined queue type for iteration context. */
typedef queue_t(struct protolayer_iter_ctx *) protolayer_iter_ctx_queue_t;
/** Iterates through the specified `queue` and gets the sum of all payloads
* available in it. */
size_t protolayer_queue_count_payload(const protolayer_iter_ctx_queue_t *queue);
/** Checks if the specified `queue` has any payload data (i.e.
* `protolayer_queue_count_payload` would be non-zero). This optimizes calls to
* queue iterators, as it does not need to iterate through the whole queue. */
bool protolayer_queue_has_payload(const protolayer_iter_ctx_queue_t *queue);
/** Gets layer-specific session data for the specified protocol layer.
* Returns NULL if the layer is not present in the session. */
void *protolayer_sess_data_get_proto(struct session2 *s, enum protolayer_type protocol);
/** Gets layer-specific session data for the last processed layer.
 * For asynchronous continuation: to be used after the layer's callback has
 * returned, but before `protolayer_continue()` is called. */
void *protolayer_sess_data_get_current(struct protolayer_iter_ctx *ctx);
/** Gets layer-specific iteration data for the last processed layer.
 * For asynchronous continuation: to be used after the layer's callback has
 * returned, but before `protolayer_continue()` is called. */
void *protolayer_iter_data_get_current(struct protolayer_iter_ctx *ctx);
/** Gets rough memory footprint estimate of session/iteration for use in defer.
* Different, hopefully minor, allocations are not counted here;
* tasks and subsessions are also not counted;
* read the code before using elsewhere. */
size_t protolayer_sess_size_est(struct session2 *s);
size_t protolayer_iter_size_est(struct protolayer_iter_ctx *ctx, bool incl_payload);
/** Layer-specific data - the generic struct. To be added as the first member of
* each specific struct. */
struct protolayer_data {
	struct session2 *session; /**< Pointer to the owner session. */
};
/** Return value of `protolayer_iter_cb` callbacks. To be returned by *layer
* sequence return functions* (see glossary) as a sanity check. Not to be used
* directly by user code. */
enum protolayer_iter_cb_result {
PROTOLAYER_ITER_CB_RESULT_MAGIC = 0x364F392E,
};
/** Function type for `struct protolayer_globals::wrap` and `struct
* protolayer_globals::unwrap`. The function processes the provided
* `ctx->payload` and decides the next action for the currently processed
* sequence.
*
* The function (or another function, that the pointed-to function causes to be
* called, directly or through an asynchronous operation), must call one of the
* *layer sequence return functions* (see glossary) to advance (or end) the
* layer sequence. The function must return the result of such a return
* function. */
typedef enum protolayer_iter_cb_result (*protolayer_iter_cb)(
void *sess_data,
void *iter_data,
struct protolayer_iter_ctx *ctx);
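/* A minimal sketch of such a callback (illustrative; the layer and its checks
 * are hypothetical). It must end by returning the result of a layer sequence
 * return function:
 *
 *     static enum protolayer_iter_cb_result my_unwrap(void *sess_data,
 *                     void *iter_data, struct protolayer_iter_ctx *ctx)
 *     {
 *             if (ctx->payload.type != PROTOLAYER_PAYLOAD_WIRE_BUF)
 *                     return protolayer_break(ctx, kr_error(EINVAL));
 *             // ...strip this protocol's ceremony from the wire buffer...
 *             return protolayer_continue(ctx); // pass on to the next layer
 *     }
 */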
/** Return value of `protolayer_event_cb` callbacks. Controls the flow of
* events. See `protolayer_event_cb` for details. */
enum protolayer_event_cb_result {
PROTOLAYER_EVENT_CONSUME = 0,
PROTOLAYER_EVENT_PROPAGATE = 1
};
/** Function type for `struct protolayer_globals::event_wrap` and `struct
* protolayer_globals::event_unwrap` callbacks of layers. The `baton` parameter
* points to the mutable, iteration-specific baton pointer, initialized by the
* `baton` parameter of one of the `session2_event` functions. The pointed-to
* value of `baton` may be modified to accommodate for the next layer in the
* sequence.
*
* When `PROTOLAYER_EVENT_PROPAGATE` is returned, iteration over the sequence
* of layers continues. When `PROTOLAYER_EVENT_CONSUME` is returned, iteration
* stops.
*
* **IMPORTANT:** A well-behaved layer will **ALWAYS** propagate events it knows
* nothing about. Only ever consume events you actually have good reason to
* consume (like TLS consumes `CONNECT` from TCP, because it needs to perform
* its own handshake first). */
typedef enum protolayer_event_cb_result (*protolayer_event_cb)(
enum protolayer_event_type event, void **baton,
struct session2 *session, void *sess_data);
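/* A sketch of a well-behaved event callback (the layer is hypothetical): it
 * consumes only the events it actually handles and propagates everything
 * else, as the contract above requires.
 *
 *     static enum protolayer_event_cb_result my_event_unwrap(
 *                     enum protolayer_event_type event, void **baton,
 *                     struct session2 *session, void *sess_data)
 *     {
 *             if (event == PROTOLAYER_EVENT_CONNECT) {
 *                     // ...e.g. start this layer's own handshake first...
 *                     return PROTOLAYER_EVENT_CONSUME;
 *             }
 *             return PROTOLAYER_EVENT_PROPAGATE; // unknown events pass through
 *     }
 */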
/** Function type for initialization callbacks of layer session data.
*
* The `param` value is the one associated with the currently initialized
* layer, from the `layer_param` array of `session2_new()` - may be NULL if
* none is provided for the current layer.
*
* `data` points to the layer-specific data struct.
*
* Returning 0 means success, other return values mean error and halt the
* initialization. */
typedef int (*protolayer_data_sess_init_cb)(struct session2 *session,
void *data, void *param);
/** Function type for determining the size of a layer's wire buffer overhead. */
typedef size_t (*protolayer_wire_buf_overhead_cb)(bool outgoing);
/** Function type for (de)initialization callback of layer iteration data.
*
* `ctx` points to the iteration context that `data` belongs to.
*
* `data` points to the layer-specific data struct.
*
* Returning 0 means success, other return values mean error and halt the
* initialization. */
typedef int (*protolayer_iter_data_cb)(struct protolayer_iter_ctx *ctx,
void *data);
/** Function type for (de)initialization callbacks of layers.
*
* `data` points to the layer-specific data struct.
*
* Returning 0 means success, other return values mean error and halt the
* initialization. */
typedef int (*protolayer_data_cb)(struct session2 *session, void *data);
/** Function type for (de)initialization callbacks of DNS requests.
*
* `req` points to the request for initialization.
* `sess_data` points to layer-specific session data struct. */
typedef void (*protolayer_request_cb)(struct session2 *session,
struct kr_request *req,
void *sess_data);
/** Initialization parameters for protocol layer session data. */
struct protolayer_data_param {
enum protolayer_type protocol; /**< Which protocol these parameters
* are meant for. */
void *param; /**< Pointer to protolayer-related initialization
* parameters. Only needs to be valid during session
* initialization. */
};
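/* A sketch of passing such parameters to a new session (illustrative only;
 * `my_tls_param` and the use of the TLS layer here are assumptions - each
 * layer defines what its `param` is expected to point to, and `KR_PROTO_DOT`
 * is assumed from `lib/proto.h` to name the DNS-over-TLS layer group):
 *
 *     struct protolayer_data_param params[] = {
 *             { .protocol = PROTOLAYER_TYPE_TLS, .param = &my_tls_param },
 *     };
 *     struct session2 *s = session2_new_io(handle, KR_PROTO_DOT,
 *                     params, 1, false);
 */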
/** Global data for a specific layered protocol. This is to be initialized in
 * the `protolayer_globals` global array (below) during the resolver's
* startup. It contains pointers to functions implementing a particular
* protocol, as well as other important data.
*
* Every member of this struct is allowed to be zero/NULL if a particular
* protocol has no use for it. */
struct protolayer_globals {
/** Size of the layer-specific data struct, valid per-session.
*
* The struct MUST begin with a `struct protolayer_data` member. If
* no session struct is used by the layer, the value may be zero. */
size_t sess_size;
/** Size of the layer-specific data struct, valid per-iteration. It
* gets created and destroyed together with a `struct
* protolayer_iter_ctx`.
*
* The struct MUST begin with a `struct protolayer_data` member. If
* no iteration struct is used by the layer, the value may be zero. */
size_t iter_size;
/** Number of bytes that this layer adds onto the session's wire buffer
* by default. All overheads in a group are summed together to form the
* resulting default wire buffer length.
*
* Ignored when `wire_buf_overhead_cb` is non-NULL. */
size_t wire_buf_overhead;
/** Called during session initialization to determine the number of
* bytes that this layer adds onto the session's wire buffer.
*
* It is the dynamic version of `wire_buf_overhead`, which is ignored
* when this is non-NULL. */
protolayer_wire_buf_overhead_cb wire_buf_overhead_cb;
	/** The maximum number of bytes that this layer may add onto the
	 * session's wire buffer. All maximum overheads in a group are summed
	 * together to form the resulting maximum wire buffer length.
*
* If this is less than the default overhead, the default is used
* instead. */
size_t wire_buf_max_overhead;
/** Called during session creation to initialize
* layer-specific session data. The data is always provided
* zero-initialized to this function. */
protolayer_data_sess_init_cb sess_init;
/** Called during session destruction to deinitialize
* layer-specific session data. */
protolayer_data_cb sess_deinit;
/** Called at the beginning of a non-event layer sequence to initialize
* layer-specific iteration data. The data is always zero-initialized
* during iteration context initialization. */
protolayer_iter_data_cb iter_init;
/** Called at the end of a non-event layer sequence to deinitialize
* layer-specific iteration data. */
protolayer_iter_data_cb iter_deinit;
	/** Strips the buffer of protocol-specific data. E.g. an HTTP layer
* removes HTTP status and headers. Optional - iteration continues
* automatically if this is NULL. */
protolayer_iter_cb unwrap;
	/** Wraps the buffer into protocol-specific data. E.g. an HTTP layer
* adds HTTP status and headers. Optional - iteration continues
* automatically if this is NULL. */
protolayer_iter_cb wrap;
/** Processes events in the unwrap order (sent from the outside).
* Optional - iteration continues automatically if this is NULL. */
protolayer_event_cb event_unwrap;
/** Processes events in the wrap order (bounced back by the session).
* Optional - iteration continues automatically if this is NULL. */
protolayer_event_cb event_wrap;
/** Modifies the provided request for use with the layer. Mostly for
* setting `struct kr_request::qsource.comm_flags`. */
protolayer_request_cb request_init;
};
/** Global data about layered protocols. Mapped by `enum protolayer_type`.
* Individual protocols are to be initialized during resolver startup. */
extern struct protolayer_globals protolayer_globals[PROTOLAYER_TYPE_COUNT];
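/* A sketch of such an initialization for a hypothetical layer (the `my_*`
 * structs and callbacks are assumptions; the doc comment of
 * `PROTOLAYER_TYPE_MAP` points to e.g. `daemon/worker.c` for real examples):
 *
 *     protolayer_globals[PROTOLAYER_TYPE_TLS] = (struct protolayer_globals){
 *             .sess_size = sizeof(struct my_sess_data),
 *             .iter_size = sizeof(struct my_iter_data),
 *             .sess_init = my_sess_init,
 *             .sess_deinit = my_sess_deinit,
 *             .unwrap = my_unwrap,
 *             .wrap = my_wrap,
 *             .event_unwrap = my_event_unwrap,
 *     };
 */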
/** *Layer sequence return function* (see glossary) - signals the protolayer
 * manager to continue processing with the next layer. */
enum protolayer_iter_cb_result protolayer_continue(struct protolayer_iter_ctx *ctx);
/** *Layer sequence return function* (see glossary) - signals that the layer
 * wants to stop processing the buffer and clean up, possibly due to an error
 * (indicated by a non-zero `status`). */
enum protolayer_iter_cb_result protolayer_break(struct protolayer_iter_ctx *ctx, int status);
/** *Layer sequence return function* (see glossary) - signals that the
 * current sequence will continue asynchronously. The layer should store the
 * context and call another sequence return function at a later point.
* This may be used in layers that work through libraries whose operation is
* asynchronous, like GnuTLS.
*
 * Note that this one is basically just a readability hint - another return
 * function may actually be called before it (generally during a call to an
 * external library function, e.g. GnuTLS or nghttp2). This is completely legal
* and the sequence will continue correctly. */
static inline enum protolayer_iter_cb_result protolayer_async(void)
{
return PROTOLAYER_ITER_CB_RESULT_MAGIC;
}
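/* A sketch of the asynchronous flow (illustrative; `start_background_io()`
 * and `on_io_done()` are hypothetical). The callback hands `ctx` to some
 * pending operation, goes asynchronous, and a completion callback later
 * resumes the sequence:
 *
 *     static enum protolayer_iter_cb_result my_async_unwrap(void *sess_data,
 *                     void *iter_data, struct protolayer_iter_ctx *ctx)
 *     {
 *             start_background_io(ctx); // retains `ctx` until completion
 *             return protolayer_async();
 *     }
 *
 *     static void on_io_done(struct protolayer_iter_ctx *ctx, int status)
 *     {
 *             if (status)
 *                     protolayer_break(ctx, status);
 *             else
 *                     protolayer_continue(ctx);
 *     }
 */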
/** Indicates how a session sends data in the `wrap` direction and receives
* data in the `unwrap` direction. */
enum session2_transport_type {
SESSION2_TRANSPORT_NULL = 0,
SESSION2_TRANSPORT_IO,
SESSION2_TRANSPORT_PARENT,
};
/** A data unit for a single sequential data source. The data may be organized
* as a stream or a sequence of datagrams - this is up to the actual individual
* protocols used by the session - see `enum kr_proto` and
* `protolayer_`-prefixed types and functions for more information.
*
* A session processes data in two directions:
*
* - `_UNWRAP` deals with raw data received from the session's transport. It
* strips the ceremony of individual protocols from the buffers, retaining any
* required metadata in an iteration context (`struct protolayer_iter_ctx`).
* The last layer (as defined by a `protolayer_grp_*` array in `session2.c`) in
* a sequence is generally responsible for submitting the unwrapped data to be
* processed by an internal system, e.g. to be resolved as a DNS query.
*
* - `_WRAP` deals with data generated by an internal system. It adds the
* required protocol ceremony to it (e.g. encryption). The first layer (as
* defined by a `protolayer_grp_*` array in `session2.c`) is responsible for
* preparing the data to be sent through the session's transport. */
struct session2 {
/** Data for sending data out in the `wrap` direction and receiving new
* data in the `unwrap` direction. */
struct {
enum session2_transport_type type; /**< See `enum session2_transport_type` */
union {
/** For `_IO` type transport. Contains a libuv handle
* and session-related address storage. */
struct {
uv_handle_t *handle;
union kr_sockaddr peer;
union kr_sockaddr sockname;
} io;
/** For `_PARENT` type transport. */
struct session2 *parent;
};
} transport;
uv_timer_t timer; /**< For session-wide timeout events. */
enum protolayer_event_type timer_event; /**< The event fired on timeout. */
trie_t *tasks; /**< List of tasks associated with given session. */
	queue_t(struct qr_task *) waiting; /**< List of tasks waiting to be
	                                    * sent to the upstream. */
struct wire_buf wire_buf;
uint32_t log_id; /**< Session ID for logging. */
int ref_count; /**< Number of unclosed libUV handles owned by this
* session + iteration contexts referencing the session. */
/** Communication information. Typically written into by one of the
* first layers facilitating transport protocol processing.
* Zero-initialized by default. */
struct comm_info comm_storage;
/** Time of last IO activity (if any occurs). Otherwise session
* creation time. */
uint64_t last_activity;
/** If true, the session's transport is towards an upstream server.
* Otherwise, it is towards a client. */
bool outgoing : 1;
/** If true, the session is at the end of its lifecycle and is about
* to close. */
bool closing : 1;
/** If true, the session has done something useful,
* e.g. it has produced a packet. */
bool was_useful : 1;
/** If true, encryption takes place in this session. Layers may use
* this to determine whether padding should be applied. A layer that
* provides security shall set this to `true` during session
* initialization. */
bool secure : 1;
/** If true, the session contains a stream-based protocol layer.
* Set during protocol layer initialization by the stream-based layer. */
bool stream : 1;
/** If true, the session contains a protocol layer with custom handling
* of malformed queries. This is used e.g. by the HTTP layer, which will
* return a Bad Request status on a malformed query. */
bool custom_emalf_handling : 1;
/** If true, session is being rate-limited. One of the protocol layers
* is going to be the writer for this flag. */
bool throttled : 1;
/* Protocol layers */
/** The set of protocol layers used by this session. */
enum kr_proto proto;
/** The size of a single iteration context
* (`struct protolayer_iter_ctx`), including layer-specific data. */
size_t iter_ctx_size;
/** The size of this session struct. */
size_t session_size;
/** The following flexible array has basically this structure:
*
* struct {
* size_t sess_offsets[num_layers];
* size_t iter_offsets[num_layers];
* variably-sized-data sess_data[num_layers];
* }
*
* It is done this way because different layer groups have different
* numbers of layers and differently-sized layer-specific data. C does
* not have a convenient way to define this in structs, so
* we do it via this flexible array.
*
* `sess_data` is a sequence of variably-sized CPU-aligned
* layer-specific structs.
*
* `sess_offsets` determines data offsets in `sess_data` for pointer
* retrieval.
*
* `iter_offsets` determines data offsets in `struct
* protolayer_iter_ctx::data` for pointer retrieval. */
alignas(CPU_STRUCT_ALIGN) char layer_data[];
};
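/* Illustrative sketch of the `layer_data` layout described above. This is a
 * hypothetical accessor, assuming the `sess_offsets` values are relative to
 * the start of the `sess_data` region; the real accessors live in
 * session2.c. */
#if 0
static void *example_sess_data(struct session2 *s, size_t layer_ix,
		size_t num_layers)
{
	const size_t *sess_offsets = (const size_t *)s->layer_data;
	/* `sess_data` begins right after the two offset arrays */
	char *sess_data = s->layer_data + 2 * num_layers * sizeof(size_t);
	return sess_data + sess_offsets[layer_ix];
}
#endif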
/** Allocates and initializes a new session with the specified protocol layer
* group, and the provided transport context.
*
* `layer_param` is a pointer to an array of size `layer_param_count`. The
* parameters are passed to the layer session initializers. The parameter array
* is only required to be valid before this function returns. It is up to the
* individual layer implementations to determine the lifetime of the data
* pointed to by the parameters. */
struct session2 *session2_new(enum session2_transport_type transport_type,
enum kr_proto proto,
struct protolayer_data_param *layer_param,
size_t layer_param_count,
bool outgoing);
/** Allocates and initializes a new session with the specified protocol layer
* group, using a *libuv handle* as its transport. */
static inline struct session2 *session2_new_io(uv_handle_t *handle,
enum kr_proto layer_grp,
struct protolayer_data_param *layer_param,
size_t layer_param_count,
bool outgoing)
{
struct session2 *s = session2_new(SESSION2_TRANSPORT_IO, layer_grp,
layer_param, layer_param_count, outgoing);
s->transport.io.handle = handle;
handle->data = s;
s->ref_count++; /* Session owns the handle */
return s;
}
/** Allocates and initializes a new session with the specified protocol layer
* group, using a *parent session* as its transport. */
static inline struct session2 *session2_new_child(struct session2 *parent,
enum kr_proto layer_grp,
struct protolayer_data_param *layer_param,
size_t layer_param_count,
bool outgoing)
{
struct session2 *s = session2_new(SESSION2_TRANSPORT_PARENT, layer_grp,
layer_param, layer_param_count, outgoing);
s->transport.parent = parent;
return s;
}
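/* Illustrative usage sketch (hypothetical variables): wrapping an accepted
 * libuv TCP stream in a DoT session - the TLS layer comes from the protocol
 * group selected by `KR_PROTO_DOT`: */
#if 0
uv_tcp_t *client = get_accepted_stream(); /* hypothetical */
struct session2 *s = session2_new_io((uv_handle_t *)client, KR_PROTO_DOT,
		NULL, 0, false /* towards a client */);
#endif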
/** Used when a libuv handle owned by the session is closed. Once all owned
 * handles are closed, the session is freed. */
void session2_unhandle(struct session2 *s);
/** Start reading from the underlying transport. */
int session2_start_read(struct session2 *session);
/** Stop reading from the underlying transport. */
int session2_stop_read(struct session2 *session);
/** Gets the peer address from the specified session, iterating through the
* session hierarchy (child-to-parent) until an `_IO` session is found if
* needed.
*
* May return `NULL` if no peer is set. */
struct sockaddr *session2_get_peer(struct session2 *s);
/** Gets the sockname from the specified session, iterating through the
* session hierarchy (child-to-parent) until an `_IO` session is found if
* needed.
*
* May return `NULL` if no sockname is set. */
struct sockaddr *session2_get_sockname(struct session2 *s);
/** Gets the libuv handle from the specified session, iterating through the
* session hierarchy (child-to-parent) until an `_IO` session is found if
* needed.
*
* May return `NULL` if the session has no underlying libuv handle. */
KR_EXPORT uv_handle_t *session2_get_handle(struct session2 *s);
/** Start the session timer. On timeout, the specified `event` is sent in the
* `_UNWRAP` direction. Only a single timeout can be active at a time. */
int session2_timer_start(struct session2 *s, enum protolayer_event_type event,
uint64_t timeout, uint64_t repeat);
/** Restart the session timer without changing any of its parameters. */
int session2_timer_restart(struct session2 *s);
/** Stop the session timer. */
int session2_timer_stop(struct session2 *s);
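/* Illustrative sketch: arm a 10-second one-shot idle timeout delivering
 * `PROTOLAYER_EVENT_GENERAL_TIMEOUT` to the layers (the event name is taken
 * from its use elsewhere in the daemon; the values are arbitrary): */
#if 0
session2_timer_start(s, PROTOLAYER_EVENT_GENERAL_TIMEOUT,
		10 * 1000 /* ms */, 0 /* no repeat */);
#endif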
int session2_tasklist_add(struct session2 *session, struct qr_task *task);
int session2_tasklist_del(struct session2 *session, struct qr_task *task);
struct qr_task *session2_tasklist_get_first(struct session2 *session);
struct qr_task *session2_tasklist_del_first(struct session2 *session, bool deref);
struct qr_task *session2_tasklist_find_msgid(const struct session2 *session, uint16_t msg_id);
struct qr_task *session2_tasklist_del_msgid(const struct session2 *session, uint16_t msg_id);
void session2_tasklist_finalize(struct session2 *session, int status);
int session2_tasklist_finalize_expired(struct session2 *session);
static inline size_t session2_tasklist_get_len(const struct session2 *session)
{
return trie_weight(session->tasks);
}
static inline bool session2_tasklist_is_empty(const struct session2 *session)
{
return session2_tasklist_get_len(session) == 0;
}
int session2_waitinglist_push(struct session2 *session, struct qr_task *task);
struct qr_task *session2_waitinglist_get(const struct session2 *session);
struct qr_task *session2_waitinglist_pop(struct session2 *session, bool deref);
void session2_waitinglist_retry(struct session2 *session, bool increase_timeout_cnt);
void session2_waitinglist_finalize(struct session2 *session, int status);
static inline size_t session2_waitinglist_get_len(const struct session2 *session)
{
return queue_len(session->waiting);
}
static inline bool session2_waitinglist_is_empty(const struct session2 *session)
{
return session2_waitinglist_get_len(session) == 0;
}
static inline bool session2_is_empty(const struct session2 *session)
{
return session2_tasklist_is_empty(session) &&
session2_waitinglist_is_empty(session);
}
/** Penalizes the server the specified `session` is connected to, if the session
* has not been useful (see `struct session2::was_useful`). Only applies to
* `outgoing` sessions, and the session should not be connection-less. */
void session2_penalize(struct session2 *session);
/** Sends the specified `payload` to be processed in the `_UNWRAP` direction by
* the session's protocol layers.
*
* The `comm` parameter may contain a pointer to communication data; e.g. for
* UDP, the comm data contains a pointer to the sender's `struct sockaddr_*`.
* If `comm` is `NULL`, session-wide data is used.
*
* Note that the payload data may be modified by any of the layers, to avoid
* making copies. Once the payload is passed to this function, the content of
* the referenced data is undefined to the caller.
*
* Once all layers are processed, `cb` is called with `baton` passed as one
* of its parameters. `cb` may also be `NULL`. See `protolayer_finished_cb` for
* more info.
*
* Returns one of `enum protolayer_ret` or a negative number
* indicating an error. */
int session2_unwrap(struct session2 *s, struct protolayer_payload payload,
const struct comm_info *comm, protolayer_finished_cb cb,
void *baton);
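/* Illustrative sketch (hypothetical names): feeding freshly received bytes
 * into the layer sequence. `protolayer_payload_buffer()` is assumed by
 * analogy with `protolayer_payload_wire_buf()`/`protolayer_payload_iovec()`
 * used elsewhere in the daemon. */
#if 0
static void on_unwrap_done(int status, struct session2 *s,
		const struct comm_info *comm, void *baton)
{
	/* all layers have finished, or one of them broke with `status` */
}

static void example_feed(struct session2 *s, char *buf, size_t len)
{
	session2_unwrap(s, protolayer_payload_buffer(buf, len, false),
			NULL /* use session-wide comm data */,
			on_unwrap_done, NULL);
}
#endif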
/** Same as `session2_unwrap`, but looks up the specified `protocol` in the
* session's assigned protocol group and sends the `payload` to the layer that
* is next in the sequence in the `_UNWRAP` direction.
*
* Layers may use this to generate their own data to send in the sequence, e.g.
* for protocol-specific ceremony. */
int session2_unwrap_after(struct session2 *s, enum protolayer_type protocol,
struct protolayer_payload payload,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton);
/** Sends the specified `payload` to be processed in the `_WRAP` direction by
* the session's protocol layers. The `comm` parameter may contain a pointer
* to some data specific to the bottommost layer of this session.
*
* Note that the payload data may be modified by any of the layers, to avoid
* making copies. Once the payload is passed to this function, the content of
* the referenced data is undefined to the caller.
*
* Once all layers are processed, `cb` is called with `baton` passed as one
* of its parameters. `cb` may also be `NULL`. See `protolayer_finished_cb` for
* more info.
*
* Returns one of `enum protolayer_ret` or a negative number
* indicating an error. */
int session2_wrap(struct session2 *s, struct protolayer_payload payload,
const struct comm_info *comm, protolayer_finished_cb cb,
void *baton);
/** Same as `session2_wrap`, but looks up the specified `protocol` in the
* session's assigned protocol group and sends the `payload` to the layer that
* is next in the sequence in the `_WRAP` direction.
*
* Layers may use this to generate their own data to send in the sequence, e.g.
* for protocol-specific ceremony. */
int session2_wrap_after(struct session2 *s, enum protolayer_type protocol,
struct protolayer_payload payload,
const struct comm_info *comm,
protolayer_finished_cb cb, void *baton);
/** Sends an event to be synchronously processed by the protocol layers of the
* specified session. The layers are first iterated through in the `_UNWRAP`
* direction, then bounced back in the `_WRAP` direction. */
void session2_event(struct session2 *s, enum protolayer_event_type event, void *baton);
/** Sends an event to be synchronously processed by the protocol layers of the
* specified session, starting from the specified `protocol` in the `_UNWRAP`
* direction. The layers are first iterated through in the `_UNWRAP` direction,
* then bounced back in the `_WRAP` direction.
*
* NOTE: The bounced iteration does not exclude any layers - the layer
* specified by `protocol` and those before it are only skipped in the
* `_UNWRAP` direction! */
void session2_event_after(struct session2 *s, enum protolayer_type protocol,
enum protolayer_event_type event, void *baton);
/** Sends a `PROTOLAYER_EVENT_CLOSE` event to be processed by the protocol
* layers of the specified session. This function exists for readability
* reasons, to signal the intent that sending this event is used to actually
* close the session. */
static inline void session2_close(struct session2 *s)
{
session2_event(s, PROTOLAYER_EVENT_CLOSE, NULL);
}
/** Sends a `PROTOLAYER_EVENT_FORCE_CLOSE` event to be processed by the
* protocol layers of the specified session. This function exists for
* readability reasons, to signal the intent that sending this event is used to
* actually close the session. */
static inline void session2_force_close(struct session2 *s)
{
session2_event(s, PROTOLAYER_EVENT_FORCE_CLOSE, NULL);
}
/** Performs initial setup of the specified `req`, using the session's protocol
* layers. Layers are processed in the `_UNWRAP` direction. */
void session2_init_request(struct session2 *s, struct kr_request *req);
/** Removes the specified request task from the session's tasklist. The session
* must be outgoing. If the session is UDP, a signal to close is also sent to it. */
void session2_kill_ioreq(struct session2 *session, struct qr_task *task);
/** Update `last_activity` to the current timestamp. */
static inline void session2_touch(struct session2 *session)
{
session->last_activity = kr_now();
}
/*
 * Copyright (C) 2016 American Civil Liberties Union (ACLU)
 * Copyright (C) CZ.NIC, z.s.p.o
 *
 * Initial Author: Daniel Kahn Gillmor <dkg@fifthhorseman.net>
 *                 Ondřej Surý <ondrej@sury.org>
 *
 * SPDX-License-Identifier: GPL-3.0-or-later
 */
#include <gnutls/gnutls.h>
#include <gnutls/x509.h>
#include <gnutls/abstract.h>
#include <gnutls/crypto.h>
#include <uv.h>

#include <errno.h>
#include <stdalign.h>
#include <stdlib.h>

#include "contrib/ucw/lib.h"
#include "contrib/base64.h"
#include "daemon/io.h"
#include "daemon/session2.h"
#include "daemon/tls.h"
#include "daemon/worker.h"
#define EPHEMERAL_CERT_EXPIRATION_SECONDS_RENEW_BEFORE ((time_t)60*60*24*7)
#define GNUTLS_PIN_MIN_VERSION 0x030400
#define UNWRAP_BUF_SIZE 131072
/* gnutls_record_recv and gnutls_record_send */
#define TLS_CHUNK_SIZE ((size_t)16 * 1024)

#define VERBOSE_MSG(cl_side, ...)\
	if (cl_side) \
		kr_log_debug(TLSCLIENT, __VA_ARGS__); \
	else \
		kr_log_debug(TLS, __VA_ARGS__);

static const gnutls_datum_t tls_grp_alpn[KR_PROTO_COUNT] = {
	[KR_PROTO_DOT] = { (uint8_t *)"dot", 3 },
	[KR_PROTO_DOH] = { (uint8_t *)"h2", 2 },
};

typedef enum tls_client_hs_state {
	TLS_HS_NOT_STARTED = 0,
	TLS_HS_IN_PROGRESS,
	TLS_HS_DONE,
	TLS_HS_CLOSING,
	TLS_HS_LAST
} tls_hs_state_t;

struct pl_tls_sess_data {
	struct protolayer_data h;
	bool client_side;
	bool first_handshake_done;
	gnutls_session_t tls_session;
	tls_hs_state_t handshake_state;
	protolayer_iter_ctx_queue_t unwrap_queue;
	protolayer_iter_ctx_queue_t wrap_queue;
	struct wire_buf unwrap_buf;
	size_t write_queue_size;
	union {
		struct tls_credentials *server_credentials;
		tls_client_param_t *client_params; /**< Ref-counted. */
	};
};
/** @internal Debugging facility. */
#ifdef DEBUG
#define DEBUG_MSG(fmt...) fprintf(stderr, "[tls] " fmt)
#else
#define DEBUG_MSG(fmt...)
#endif
static void
kres_gnutls_log(int level, const char *message)
{
kr_log_error("[tls] gnutls: (%d) %s", level, message);
}
struct tls_credentials * tls_get_ephemeral_credentials(void);
void tls_credentials_log_pins(struct tls_credentials *tls_credentials);
static int client_verify_certificate(gnutls_session_t tls_session);
static struct tls_credentials *tls_credentials_reserve(struct tls_credentials *tls_credentials);
void tls_setup_logging(bool verbose)
{
	gnutls_global_set_log_function(kres_gnutls_log);
	gnutls_global_set_log_level(verbose ? 1 : 0);
}
/**
* Set restrictions on TLS features, in particular ciphers.
*
* We explicitly disable features according to:
* https://datatracker.ietf.org/doc/html/rfc8310#section-9
* in case the gnutls+OS defaults allow them.
* Performance optimizations are not implemented at the moment.
*
* OS defaults are taken into account, e.g. on Red Hat family there's
* /etc/crypto-policies/back-ends/gnutls.config and update-crypto-policies tool.
*/
static int kres_gnutls_set_priority(gnutls_session_t session) {
static const char * const extra_prio =
"-VERS-TLS1.0:-VERS-TLS1.1:-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
const char *errpos = NULL;
int err = gnutls_set_default_priority_append(session, extra_prio, &errpos, 0);
if (err != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "setting priority '%s' failed at character %zd (...'%s') with %s (%d)\n",
extra_prio, errpos - extra_prio, errpos, gnutls_strerror_name(err), err);
}
return err;
}
static ssize_t kres_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len)
{
	struct pl_tls_sess_data *tls = h;
	if (kr_fails_assert(tls)) {
		errno = EFAULT;
		return -1;
	}

	bool avail = protolayer_queue_has_payload(&tls->unwrap_queue);
	VERBOSE_MSG(tls->client_side, "pull wanted: %zu avail: %s\n",
			len, avail ? "yes" : "no");
	if (!avail) {
		errno = EAGAIN;
		return -1;
	}

	char *dest = buf;
	size_t transfer = 0;
	while (queue_len(tls->unwrap_queue) > 0 && len > 0) {
		struct protolayer_iter_ctx *ctx = queue_head(tls->unwrap_queue);
		struct protolayer_payload *pld = &ctx->payload;
		bool fully_consumed = false;
		if (pld->type == PROTOLAYER_PAYLOAD_BUFFER) {
			size_t to_copy = MIN(len, pld->buffer.len);
			memcpy(dest, pld->buffer.buf, to_copy);
			dest += to_copy;
			len -= to_copy;
			pld->buffer.buf = (char *)pld->buffer.buf + to_copy;
			pld->buffer.len -= to_copy;
			transfer += to_copy;
			if (pld->buffer.len == 0)
				fully_consumed = true;
		} else if (pld->type == PROTOLAYER_PAYLOAD_IOVEC) {
			while (pld->iovec.cnt && len > 0) {
				struct iovec *iov = pld->iovec.iov;
				size_t to_copy = MIN(len, iov->iov_len);
				memcpy(dest, iov->iov_base, to_copy);
				dest += to_copy;
				len -= to_copy;
				iov->iov_base = ((char *)iov->iov_base) + to_copy;
				iov->iov_len -= to_copy;
				transfer += to_copy;
				if (iov->iov_len == 0) {
					pld->iovec.iov++;
					pld->iovec.cnt--;
				}
			}
			if (pld->iovec.cnt == 0)
				fully_consumed = true;
		} else if (pld->type == PROTOLAYER_PAYLOAD_WIRE_BUF) {
			size_t wbl = wire_buf_data_length(pld->wire_buf);
			size_t to_copy = MIN(len, wbl);
			memcpy(dest, wire_buf_data(pld->wire_buf), to_copy);
			dest += to_copy;
			len -= to_copy;
			transfer += to_copy;
			wire_buf_trim(pld->wire_buf, to_copy);
			if (wire_buf_data_length(pld->wire_buf) == 0) {
				wire_buf_reset(pld->wire_buf);
				fully_consumed = true;
			}
		} else if (!pld->type) {
			fully_consumed = true;
		} else {
			kr_assert(false && "Unsupported payload type");
			errno = EFAULT;
			return -1;
		}

		if (!fully_consumed) /* `len` was smaller than the sum of payloads */
			break;

		if (queue_len(tls->unwrap_queue) > 1) {
			/* Finalize queued contexts, except for the last one. */
			protolayer_break(ctx, kr_ok());
			queue_pop(tls->unwrap_queue);
		} else {
			/* The last queued context will `continue` on the next
			 * `gnutls_record_recv`. */
			ctx->payload.type = PROTOLAYER_PAYLOAD_NULL;
			break;
		}
	}

	VERBOSE_MSG(tls->client_side, "pull transfer: %zu\n", transfer);
	return transfer;
}
struct kres_gnutls_push_ctx {
	struct pl_tls_sess_data *sess_data;
	struct iovec iov[];
};

static void kres_gnutls_push_finished(int status, struct session2 *session,
		const struct comm_info *comm, void *baton)
{
	struct kres_gnutls_push_ctx *push_ctx = baton;
	struct pl_tls_sess_data *tls = push_ctx->sess_data;
	while (queue_len(tls->wrap_queue)) {
		struct protolayer_iter_ctx *ctx = queue_head(tls->wrap_queue);
		protolayer_break(ctx, kr_ok());
		queue_pop(tls->wrap_queue);
	}
	free(push_ctx);
}
static ssize_t kres_gnutls_vec_push(gnutls_transport_ptr_t h, const giovec_t * iov, int iovcnt)
{
	struct pl_tls_sess_data *tls = h;
	if (kr_fails_assert(tls)) {
		errno = EFAULT;
		return -1;
	}

	if (iovcnt == 0) {
		return 0;
	}

	if (kr_fails_assert(iovcnt > 0)) {
		errno = EINVAL;
		return -1;
	}

	size_t total_len = 0;
	for (int i = 0; i < iovcnt; i++)
		total_len += iov[i].iov_len;

	struct kres_gnutls_push_ctx *push_ctx =
		malloc(sizeof(*push_ctx) + sizeof(struct iovec[iovcnt]));
	kr_require(push_ctx);
	push_ctx->sess_data = tls;
	memcpy(push_ctx->iov, iov, sizeof(struct iovec[iovcnt]));

	session2_wrap_after(tls->h.session, PROTOLAYER_TYPE_TLS,
			protolayer_payload_iovec(push_ctx->iov, iovcnt, true),
			NULL, kres_gnutls_push_finished, push_ctx);

	return total_len;
}
static void tls_handshake_success(struct pl_tls_sess_data *tls,
		struct session2 *session)
{
	if (tls->client_side) {
		tls_client_param_t *tls_params = tls->client_params;
		gnutls_session_t tls_session = tls->tls_session;
		if (gnutls_session_is_resumed(tls_session) != 0) {
			kr_log_debug(TLSCLIENT, "TLS session has resumed\n");
		} else {
			kr_log_debug(TLSCLIENT, "TLS session has not resumed\n");
			/* session wasn't resumed, delete old session data ... */
			if (tls_params->session_data.data != NULL) {
				gnutls_free(tls_params->session_data.data);
				tls_params->session_data.data = NULL;
				tls_params->session_data.size = 0;
			}
			/* ... and get the new session data */
			gnutls_datum_t tls_session_data = { NULL, 0 };
			int ret = gnutls_session_get_data2(tls_session, &tls_session_data);
			if (ret == 0) {
				tls_params->session_data = tls_session_data;
			}
		}
	}
	if (!tls->first_handshake_done) {
		session2_event_after(session, PROTOLAYER_TYPE_TLS,
				PROTOLAYER_EVENT_CONNECT, NULL);
		tls->first_handshake_done = true;
	}
}
/** Perform TLS handshake and handle error codes according to the documentation.
 * See https://gnutls.org/manual/html_node/TLS-handshake.html#TLS-handshake
 * The function returns kr_ok() on success or non-fatal error,
 * kr_error(EAGAIN) when blocked, or kr_error(EIO) on fatal error. */
static int tls_handshake(struct pl_tls_sess_data *tls, struct session2 *session)
{
int err = gnutls_handshake(tls->tls_session);
if (err == GNUTLS_E_SUCCESS) {
/* Handshake finished, return success */
tls->handshake_state = TLS_HS_DONE;
struct sockaddr *peer = session2_get_peer(session);
VERBOSE_MSG(tls->client_side, "TLS handshake with %s has completed\n",
kr_straddr(peer));
tls_handshake_success(tls, session);
} else if (err == GNUTLS_E_AGAIN) {
return kr_error(EAGAIN);
} else if (gnutls_error_is_fatal(err)) {
/* Fatal errors, return error as it's not recoverable */
VERBOSE_MSG(tls->client_side, "gnutls_handshake failed: %s (%d)\n",
gnutls_strerror_name(err), err);
/* Notify the peer about handshake failure via an alert. */
gnutls_alert_send_appropriate(tls->tls_session, err);
enum protolayer_event_type etype = (tls->first_handshake_done)
? PROTOLAYER_EVENT_DISCONNECT
: PROTOLAYER_EVENT_CONNECT_FAIL;
session2_event(session, etype,
(void *)KR_SELECTION_TLS_HANDSHAKE_FAILED);
return kr_error(EIO);
} else if (err == GNUTLS_E_WARNING_ALERT_RECEIVED) {
/* Handle warning when in verbose mode */
const char *alert_name = gnutls_alert_get_name(gnutls_alert_get(tls->tls_session));
if (alert_name != NULL) {
struct sockaddr *peer = session2_get_peer(session);
VERBOSE_MSG(tls->client_side, "TLS alert from %s received: %s\n",
kr_straddr(peer), alert_name);
}
}
return kr_ok();
}
/*! Close a TLS context (call gnutls_bye()) */
static void tls_close(struct pl_tls_sess_data *tls, struct session2 *session, bool allow_bye)
{
	if (tls == NULL || tls->tls_session == NULL || kr_fails_assert(session))
		return;

	/* Store the current session data for potential resumption of this session */
	if (session->outgoing && tls->client_params) {
		gnutls_free(tls->client_params->session_data.data);
		tls->client_params->session_data.data = NULL;
		tls->client_params->session_data.size = 0;
		gnutls_session_get_data2(
				tls->tls_session,
				&tls->client_params->session_data);
	}

	const struct sockaddr *peer = session2_get_peer(session);
	if (allow_bye && tls->handshake_state == TLS_HS_DONE) {
		VERBOSE_MSG(tls->client_side, "closing tls connection to `%s`\n",
				kr_straddr(peer));
		tls->handshake_state = TLS_HS_CLOSING;
		gnutls_bye(tls->tls_session, GNUTLS_SHUT_RDWR);
	} else {
		VERBOSE_MSG(tls->client_side, "closing tls connection to `%s` (without bye)\n",
				kr_straddr(peer));
	}
}
#if GNUTLS_VERSION_NUMBER >= 0x030400
/*
DNS-over-TLS Out of band key-pinned authentication profile uses the
same form of pins as HPKP:
e.g. pin-sha256="FHkyLhvI0n70E47cJlRTamTrnYVcsYdjUGbr79CfAVI="
DNS-over-TLS OOB key-pins: https://tools.ietf.org/html/rfc7858#appendix-A
HPKP pin reference: https://tools.ietf.org/html/rfc7469#appendix-A
*/
#define PINLEN ((((32) * 8 + 4)/6) + 3 + 1)
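/* A SHA-256 digest has 32 bytes; its padded base64 form takes
 * 4 * ceil(32/3) = 44 characters. The macro computes (32*8 + 4)/6 = 43 and
 * the "+ 3 + 1" rounds up past the '=' padding and the terminating NUL,
 * yielding 47 - comfortably enough for the 44 characters plus NUL. */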
/* Compute pin_sha256 for the certificate.
 * It may be in raw format - just TLS_SHA256_RAW_LEN bytes without termination,
 * or it may be a base64 0-terminated string requiring up to
 * TLS_SHA256_BASE64_BUFLEN bytes.
 * \return error code */
static int get_oob_key_pin(gnutls_x509_crt_t crt, char *outchar, ssize_t outchar_len, bool raw)
{
	if (kr_fails_assert(!raw || outchar_len >= TLS_SHA256_RAW_LEN)) {
		return kr_error(ENOSPC);
		/* With !raw we have the check inside kr_base64_encode. */
	}
	gnutls_pubkey_t key;
	int err = gnutls_pubkey_init(&key);
	if (err != GNUTLS_E_SUCCESS) return err;

	gnutls_datum_t datum = { .data = NULL, .size = 0 };
	err = gnutls_pubkey_import_x509(key, crt, 0);
	if (err != GNUTLS_E_SUCCESS) goto leave;

	err = gnutls_pubkey_export2(key, GNUTLS_X509_FMT_DER, &datum);
	if (err != GNUTLS_E_SUCCESS) goto leave;

	char raw_pin[TLS_SHA256_RAW_LEN]; /* TMP buffer if raw == false */
	err = gnutls_hash_fast(GNUTLS_DIG_SHA256, datum.data, datum.size,
				(raw ? outchar : raw_pin));
	if (err != GNUTLS_E_SUCCESS || raw/*success*/)
		goto leave;

	/* Convert to non-raw. */
	err = kr_base64_encode((uint8_t *)raw_pin, sizeof(raw_pin),
				(uint8_t *)outchar, outchar_len);
	if (err >= 0 && err < outchar_len) {
		outchar[err] = '\0'; /* kr_base64_encode() doesn't do it */
		err = GNUTLS_E_SUCCESS;
	} else if (kr_fails_assert(err < 0)) {
		outchar[outchar_len - 1] = '\0';
		err = kr_error(ENOSPC); /* base64 fits but '\0' doesn't */
	}
leave:
gnutls_free(datum.data);
@@ -309,38 +395,38 @@ static int get_oob_key_pin(gnutls_x509_crt_t crt, char *outchar, ssize_t outchar
return err;
}
/*! Log DNS-over-TLS OOB key-pin form of current credentials:
* https://tools.ietf.org/html/rfc7858#appendix-A */
void tls_credentials_log_pins(struct tls_credentials *tls_credentials)
{
for (int index = 0;; index++) {
gnutls_x509_crt_t *certs = NULL;
unsigned int cert_count = 0;
int err = gnutls_certificate_get_x509_crt(tls_credentials->credentials,
		index, &certs, &cert_count);
if (err != GNUTLS_E_SUCCESS) {
	if (err != GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE) {
		kr_log_error(TLS, "could not get X.509 certificates (%d) %s\n",
				err, gnutls_strerror_name(err));
}
return;
}
for (int i = 0; i < cert_count; i++) {
char pin[TLS_SHA256_BASE64_BUFLEN] = { 0 };
err = get_oob_key_pin(certs[i], pin, sizeof(pin), false);
if (err != GNUTLS_E_SUCCESS) {
	kr_log_error(TLS, "could not calculate RFC 7858 OOB key-pin from cert %d (%d) %s\n",
			i, err, gnutls_strerror_name(err));
} else {
	kr_log_info(TLS, "RFC 7858 OOB key-pin (%d): pin-sha256=\"%s\"\n",
			i, pin);
}
gnutls_x509_crt_deinit(certs[i]);
}
gnutls_free(certs);
}
}
#else
void tls_credentials_log_pins(struct tls_credentials *tls_credentials)
{
kr_log_error("[tls] could not calculate RFC 7858 OOB key-pin; GnuTLS 3.4.0+ required\n");
}
#endif
static int str_replace(char **where_ptr, const char *with)
{
@@ -354,9 +440,40 @@ static int str_replace(char **where_ptr, const char *with)
return kr_ok();
}
static time_t get_end_entity_expiration(gnutls_certificate_credentials_t creds)
{
gnutls_datum_t data;
gnutls_x509_crt_t cert = NULL;
int err;
time_t ret = GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION;
if ((err = gnutls_certificate_get_crt_raw(creds, 0, 0, &data)) != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "failed to get cert to check expiration: (%d) %s\n",
err, gnutls_strerror_name(err));
goto done;
}
if ((err = gnutls_x509_crt_init(&cert)) != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "failed to initialize cert: (%d) %s\n",
err, gnutls_strerror_name(err));
goto done;
}
if ((err = gnutls_x509_crt_import(cert, &data, GNUTLS_X509_FMT_DER)) != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "failed to construct cert while checking expiration: (%d) %s\n",
err, gnutls_strerror_name(err));
goto done;
}
ret = gnutls_x509_crt_get_expiration_time(cert);
done:
/* do not free data; g_c_get_crt_raw() says to treat it as
* constant. */
gnutls_x509_crt_deinit(cert);
return ret;
}
int tls_certificate_set(const char *tls_cert, const char *tls_key)
{
if (kr_fails_assert(the_network)) {
return kr_error(EINVAL);
}
@@ -366,15 +483,15 @@ int tls_certificate_set(struct network *net, const char *tls_cert, const char *t
}
int err = 0;
if ((err = gnutls_certificate_allocate_credentials(&tls_credentials->credentials)) != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "gnutls_certificate_allocate_credentials() failed: (%d) %s\n",
err, gnutls_strerror_name(err));
tls_credentials_free(tls_credentials);
return kr_error(ENOMEM);
}
if ((err = gnutls_certificate_set_x509_system_trust(tls_credentials->credentials)) < 0) {
if (err != GNUTLS_E_UNIMPLEMENTED_FEATURE) {
kr_log_error("[tls] warning: gnutls_certificate_set_x509_system_trust() failed: (%d) %s\n",
kr_log_warning(TLS, "warning: gnutls_certificate_set_x509_system_trust() failed: (%d) %s\n",
err, gnutls_strerror_name(err));
tls_credentials_free(tls_credentials);
return err;
@@ -386,20 +503,23 @@ int tls_certificate_set(struct network *net, const char *tls_cert, const char *t
tls_credentials_free(tls_credentials);
return kr_error(ENOMEM);
}
if ((err = gnutls_certificate_set_x509_key_file(tls_credentials->credentials,
tls_cert, tls_key, GNUTLS_X509_FMT_PEM)) != GNUTLS_E_SUCCESS) {
tls_credentials_free(tls_credentials);
kr_log_error("[tls] gnutls_certificate_set_x509_key_file(%s,%s) failed: %d (%s)\n",
kr_log_error(TLS, "gnutls_certificate_set_x509_key_file(%s,%s) failed: %d (%s)\n",
tls_cert, tls_key, err, gnutls_strerror_name(err));
return kr_error(EINVAL);
}
/* record the expiration date: */
tls_credentials->valid_until = get_end_entity_expiration(tls_credentials->credentials);
/* Exchange the x509 credentials */
struct tls_credentials *old_credentials = the_network->tls_credentials;
/* Start using the new x509_credentials */
the_network->tls_credentials = tls_credentials;
tls_credentials_log_pins(the_network->tls_credentials);
if (old_credentials) {
err = tls_credentials_release(old_credentials);
@@ -411,7 +531,9 @@ int tls_certificate_set(struct network *net, const char *tls_cert, const char *t
return kr_ok();
}
/*! Borrow TLS credentials for context. */
static struct tls_credentials *tls_credentials_reserve(struct tls_credentials *tls_credentials)
{
if (!tls_credentials) {
return NULL;
}
@@ -419,7 +541,9 @@ struct tls_credentials *tls_credentials_reserve(struct tls_credentials *tls_cred
return tls_credentials;
}
/*! Release TLS credentials for context (decrements refcount or frees). */
int tls_credentials_release(struct tls_credentials *tls_credentials)
{
if (!tls_credentials) {
return kr_error(EINVAL);
}
@@ -431,7 +555,9 @@ int tls_credentials_release(struct tls_credentials *tls_credentials) {
return kr_ok();
}
/*! Free TLS credentials, must not be called if it holds positive refcount. */
void tls_credentials_free(struct tls_credentials *tls_credentials)
{
if (!tls_credentials) {
return;
}
@@ -445,7 +571,816 @@ void tls_credentials_free(struct tls_credentials *tls_credentials) {
if (tls_credentials->tls_key) {
free(tls_credentials->tls_key);
}
if (tls_credentials->ephemeral_servicename) {
free(tls_credentials->ephemeral_servicename);
}
free(tls_credentials);
}
#undef DEBUG_MSG
void tls_client_param_unref(tls_client_param_t *entry)
{
if (!entry || kr_fails_assert(entry->refs)) return;
--(entry->refs);
if (entry->refs) return;
VERBOSE_MSG(true, "freeing TLS parameters %p\n", (void *)entry);
for (int i = 0; i < entry->ca_files.len; ++i) {
free_const(entry->ca_files.at[i]);
}
array_clear(entry->ca_files);
free_const(entry->hostname);
for (int i = 0; i < entry->pins.len; ++i) {
free_const(entry->pins.at[i]);
}
array_clear(entry->pins);
if (entry->credentials) {
gnutls_certificate_free_credentials(entry->credentials);
}
if (entry->session_data.data) {
gnutls_free(entry->session_data.data);
}
free(entry);
}
static int param_free(void **param, void *null)
{
if (kr_fails_assert(param && *param))
return -1;
tls_client_param_unref(*param);
return 0;
}
void tls_client_params_free(tls_client_params_t *params)
{
if (!params) return;
trie_apply(params, param_free, NULL);
trie_free(params);
}
tls_client_param_t * tls_client_param_new(void)
{
tls_client_param_t *e = calloc(1, sizeof(*e));
if (kr_fails_assert(e))
return NULL;
/* Note: those array_t don't need further initialization. */
e->refs = 1;
int ret = gnutls_certificate_allocate_credentials(&e->credentials);
if (ret != GNUTLS_E_SUCCESS) {
kr_log_error(TLSCLIENT, "error: gnutls_certificate_allocate_credentials() fails (%s)\n",
gnutls_strerror_name(ret));
free(e);
return NULL;
}
gnutls_certificate_set_verify_function(e->credentials, client_verify_certificate);
return e;
}
/**
 * Convert an IP address and port number to a binary key.
 *
 * \precond the buffer \p key must have sufficient size
 * \param addr[in]  address (with port) to convert
 * \param len[out]  output length
 * \param key[out]  output buffer
 */
static bool construct_key(const union kr_sockaddr *addr, uint32_t *len, char *key)
{
switch (addr->ip.sa_family) {
case AF_INET:
memcpy(key, &addr->ip4.sin_port, sizeof(addr->ip4.sin_port));
memcpy(key + sizeof(addr->ip4.sin_port), &addr->ip4.sin_addr,
sizeof(addr->ip4.sin_addr));
*len = sizeof(addr->ip4.sin_port) + sizeof(addr->ip4.sin_addr);
return true;
case AF_INET6:
memcpy(key, &addr->ip6.sin6_port, sizeof(addr->ip6.sin6_port));
memcpy(key + sizeof(addr->ip6.sin6_port), &addr->ip6.sin6_addr,
sizeof(addr->ip6.sin6_addr));
*len = sizeof(addr->ip6.sin6_port) + sizeof(addr->ip6.sin6_addr);
return true;
default:
kr_assert(!EINVAL);
return false;
}
}
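/* E.g. for IPv4 the key is 2 bytes of port followed by 4 bytes of address
 * (6 bytes total); for IPv6 it is 2 + 16 = 18 bytes. */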
tls_client_param_t **tls_client_param_getptr(tls_client_params_t **params,
const struct sockaddr *addr, bool do_insert)
{
if (kr_fails_assert(params && addr))
return NULL;
/* We accept NULL for empty map; ensure the map exists if needed. */
if (!*params) {
if (!do_insert) return NULL;
*params = trie_create(NULL);
if (kr_fails_assert(*params))
return NULL;
}
/* Construct the key. */
const union kr_sockaddr *ia = (const union kr_sockaddr *)addr;
char key[sizeof(ia->ip6.sin6_port) + sizeof(ia->ip6.sin6_addr)];
uint32_t len;
if (!construct_key(ia, &len, key))
return NULL;
/* Get the entry. */
return (tls_client_param_t **)
(do_insert ? trie_get_ins : trie_get_try)(*params, key, len);
}
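/* Illustrative usage sketch (hypothetical variables; the params-map member
 * name is assumed): look up - or create - the TLS parameters for an
 * upstream address: */
#if 0
tls_client_param_t **entry =
	tls_client_param_getptr(&the_network->tls_client_params, addr,
			true /* insert if missing */);
if (entry && !*entry)
	*entry = tls_client_param_new();
#endif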
int tls_client_param_remove(tls_client_params_t *params, const struct sockaddr *addr)
{
const union kr_sockaddr *ia = (const union kr_sockaddr *)addr;
char key[sizeof(ia->ip6.sin6_port) + sizeof(ia->ip6.sin6_addr)];
uint32_t len;
if (!construct_key(ia, &len, key))
return kr_error(EINVAL);
trie_val_t param_ptr;
int ret = trie_del(params, key, len, &param_ptr);
if (ret != KNOT_EOK)
return kr_error(ret);
tls_client_param_unref(param_ptr);
return kr_ok();
}
static void log_all_pins(tls_client_param_t *params)
{
uint8_t buffer[TLS_SHA256_BASE64_BUFLEN + 1];
for (int i = 0; i < params->pins.len; i++) {
int len = kr_base64_encode(params->pins.at[i], TLS_SHA256_RAW_LEN,
buffer, TLS_SHA256_BASE64_BUFLEN);
if (!kr_fails_assert(len > 0)) {
buffer[len] = '\0';
kr_log_error(TLSCLIENT, "pin no. %d: %s\n", i, buffer);
}
}
}
static void log_all_certificates(const unsigned int cert_list_size,
const gnutls_datum_t *cert_list)
{
for (int i = 0; i < cert_list_size; i++) {
gnutls_x509_crt_t cert;
if (gnutls_x509_crt_init(&cert) != GNUTLS_E_SUCCESS) {
return;
}
if (gnutls_x509_crt_import(cert, &cert_list[i], GNUTLS_X509_FMT_DER) != GNUTLS_E_SUCCESS) {
gnutls_x509_crt_deinit(cert);
return;
}
char cert_pin[TLS_SHA256_BASE64_BUFLEN];
if (get_oob_key_pin(cert, cert_pin, sizeof(cert_pin), false) != GNUTLS_E_SUCCESS) {
gnutls_x509_crt_deinit(cert);
return;
}
kr_log_error(TLSCLIENT, "Certificate: %s\n", cert_pin);
gnutls_x509_crt_deinit(cert);
}
}
/**
* Verify that at least one certificate in the certificate chain matches
* at least one certificate pin in the non-empty params->pins array.
* \returns GNUTLS_E_SUCCESS if pin matches, any other value is an error
*/
static int client_verify_pin(const unsigned int cert_list_size,
const gnutls_datum_t *cert_list,
tls_client_param_t *params)
{
if (kr_fails_assert(params->pins.len > 0))
return GNUTLS_E_CERTIFICATE_ERROR;
for (int i = 0; i < cert_list_size; i++) {
gnutls_x509_crt_t cert;
int ret = gnutls_x509_crt_init(&cert);
if (ret != GNUTLS_E_SUCCESS) {
return ret;
}
ret = gnutls_x509_crt_import(cert, &cert_list[i], GNUTLS_X509_FMT_DER);
if (ret != GNUTLS_E_SUCCESS) {
gnutls_x509_crt_deinit(cert);
return ret;
}
char cert_pin[TLS_SHA256_RAW_LEN];
/* Get raw pin and compare. */
ret = get_oob_key_pin(cert, cert_pin, sizeof(cert_pin), true);
gnutls_x509_crt_deinit(cert);
if (ret != GNUTLS_E_SUCCESS) {
return ret;
}
for (size_t j = 0; j < params->pins.len; ++j) {
const uint8_t *pin = params->pins.at[j];
if (memcmp(cert_pin, pin, TLS_SHA256_RAW_LEN) != 0)
continue; /* mismatch */
VERBOSE_MSG(true, "matched a configured pin no. %zd\n", j);
return GNUTLS_E_SUCCESS;
}
VERBOSE_MSG(true, "none of %zd configured pin(s) matched\n",
params->pins.len);
}
kr_log_error(TLSCLIENT, "no pin matched: %zu pins * %d certificates\n",
params->pins.len, cert_list_size);
log_all_pins(params);
log_all_certificates(cert_list_size, cert_list);
return GNUTLS_E_CERTIFICATE_ERROR;
}
/**
 * Verify that the session's peer certificate chain is a valid X.509 chain
 * for the given hostname.
 *
 * \returns GNUTLS_E_SUCCESS if the certificate chain is valid, any other value is an error
 */
static int client_verify_certchain(struct pl_tls_sess_data *tls, const char *hostname)
{
if (kr_fails_assert(hostname)) {
kr_log_error(TLSCLIENT, "internal config inconsistency: no hostname set\n");
return GNUTLS_E_CERTIFICATE_ERROR;
}
unsigned int status;
int ret = gnutls_certificate_verify_peers3(tls->tls_session, hostname, &status);
if ((ret == GNUTLS_E_SUCCESS) && (status == 0)) {
return GNUTLS_E_SUCCESS;
}
const char *addr_str = kr_straddr(session2_get_peer(tls->h.session));
if (ret == GNUTLS_E_SUCCESS) {
gnutls_datum_t msg;
ret = gnutls_certificate_verification_status_print(
status, gnutls_certificate_type_get(tls->tls_session), &msg, 0);
if (ret == GNUTLS_E_SUCCESS) {
kr_log_error(TLSCLIENT, "failed to verify peer certificate of %s: "
"%s\n", addr_str, msg.data);
gnutls_free(msg.data);
} else {
kr_log_error(TLSCLIENT, "failed to verify peer certificate of %s: "
"unable to print reason: %s (%s)\n",
addr_str,
gnutls_strerror(ret), gnutls_strerror_name(ret));
} /* gnutls_certificate_verification_status_print end */
} else {
kr_log_error(TLSCLIENT, "failed to verify peer certificate of %s: "
"gnutls_certificate_verify_peers3 error: %s (%s)\n",
addr_str,
gnutls_strerror(ret), gnutls_strerror_name(ret));
} /* gnutls_certificate_verify_peers3 end */
return GNUTLS_E_CERTIFICATE_ERROR;
}
/**
 * Verify that the actual TLS security parameters of \p tls_session
 * match the requirements provided by the user in `tls->client_params`.
 * \returns GNUTLS_E_SUCCESS if the requirements were met, any other value is an error
 */
static int client_verify_certificate(gnutls_session_t tls_session)
{
struct pl_tls_sess_data *tls = gnutls_session_get_ptr(tls_session);
if (kr_fails_assert(tls->client_params))
return GNUTLS_E_CERTIFICATE_ERROR;
if (tls->client_params->insecure) {
return GNUTLS_E_SUCCESS;
}
gnutls_certificate_type_t cert_type = gnutls_certificate_type_get(tls_session);
if (cert_type != GNUTLS_CRT_X509) {
kr_log_error(TLSCLIENT, "invalid certificate type %i has been received\n",
cert_type);
return GNUTLS_E_CERTIFICATE_ERROR;
}
unsigned int cert_list_size = 0;
const gnutls_datum_t *cert_list =
gnutls_certificate_get_peers(tls_session, &cert_list_size);
if (cert_list == NULL || cert_list_size == 0) {
kr_log_error(TLSCLIENT, "empty certificate list\n");
return GNUTLS_E_CERTIFICATE_ERROR;
}
if (tls->client_params->pins.len > 0)
/* check hash of the certificate but ignore everything else */
return client_verify_pin(cert_list_size, cert_list, tls->client_params);
else
return client_verify_certchain(tls, tls->client_params->hostname);
}
static int tls_pull_timeout_func(gnutls_transport_ptr_t h, unsigned int ms)
{
struct pl_tls_sess_data *tls = h;
if (kr_fails_assert(tls)) {
errno = EFAULT;
return -1;
}
size_t avail = protolayer_queue_count_payload(&tls->unwrap_queue);
VERBOSE_MSG(tls->client_side, "timeout check: available: %zu\n", avail);
if (!avail) {
errno = EAGAIN;
return -1;
}
return avail;
}
static int pl_tls_sess_data_deinit(struct pl_tls_sess_data *tls)
{
if (tls->tls_session) {
/* Don't terminate TLS connection, just tear it down */
gnutls_deinit(tls->tls_session);
tls->tls_session = NULL;
}
if (tls->client_side) {
tls_client_param_unref(tls->client_params);
} else {
tls_credentials_release(tls->server_credentials);
}
wire_buf_deinit(&tls->unwrap_buf);
while (queue_len(tls->unwrap_queue) > 0) {
struct protolayer_iter_ctx *ctx = queue_head(tls->unwrap_queue);
protolayer_break(ctx, kr_error(EIO));
queue_pop(tls->unwrap_queue);
}
queue_deinit(tls->unwrap_queue);
while (queue_len(tls->wrap_queue)) {
struct protolayer_iter_ctx *ctx = queue_head(tls->wrap_queue);
protolayer_break(ctx, kr_error(EIO));
queue_pop(tls->wrap_queue);
}
queue_deinit(tls->wrap_queue);
return kr_ok();
}
static int pl_tls_sess_server_init(struct session2 *session,
struct pl_tls_sess_data *tls)
{
if (kr_fails_assert(the_worker && the_engine))
return kr_error(EINVAL);
if (!the_network->tls_credentials) {
the_network->tls_credentials = tls_get_ephemeral_credentials();
if (!the_network->tls_credentials) {
kr_log_error(TLS, "X.509 credentials are missing, and ephemeral credentials failed; no TLS\n");
return kr_error(EINVAL);
}
kr_log_info(TLS, "Using ephemeral TLS credentials\n");
tls_credentials_log_pins(the_network->tls_credentials);
}
time_t now = time(NULL);
if (the_network->tls_credentials->valid_until != GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION) {
if (the_network->tls_credentials->ephemeral_servicename) {
/* ephemeral cert: refresh if due to expire within a week */
if (now >= the_network->tls_credentials->valid_until - EPHEMERAL_CERT_EXPIRATION_SECONDS_RENEW_BEFORE) {
struct tls_credentials *newcreds = tls_get_ephemeral_credentials();
if (newcreds) {
tls_credentials_release(the_network->tls_credentials);
the_network->tls_credentials = newcreds;
kr_log_info(TLS, "Renewed expiring ephemeral X.509 cert\n");
} else {
kr_log_error(TLS, "Failed to renew expiring ephemeral X.509 cert, using existing one\n");
}
}
} else {
/* non-ephemeral cert: warn once when certificate expires */
if (now >= the_network->tls_credentials->valid_until) {
kr_log_error(TLS, "X.509 certificate has expired!\n");
the_network->tls_credentials->valid_until = GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION;
}
}
}
int flags = GNUTLS_SERVER | GNUTLS_NONBLOCK;
#if GNUTLS_VERSION_NUMBER >= 0x030705
if (gnutls_check_version("3.7.5"))
flags |= GNUTLS_NO_TICKETS_TLS12;
#endif
int ret = gnutls_init(&tls->tls_session, flags);
if (ret != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "gnutls_init(): %s (%d)\n", gnutls_strerror_name(ret), ret);
pl_tls_sess_data_deinit(tls);
return ret;
}
tls->server_credentials = tls_credentials_reserve(the_network->tls_credentials);
ret = gnutls_credentials_set(tls->tls_session, GNUTLS_CRD_CERTIFICATE,
tls->server_credentials->credentials);
if (ret != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "gnutls_credentials_set(): %s (%d)\n", gnutls_strerror_name(ret), ret);
pl_tls_sess_data_deinit(tls);
return ret;
}
ret = kres_gnutls_set_priority(tls->tls_session);
if (ret != GNUTLS_E_SUCCESS) {
pl_tls_sess_data_deinit(tls);
return ret;
}
tls->client_side = false;
wire_buf_init(&tls->unwrap_buf, UNWRAP_BUF_SIZE);
gnutls_transport_set_pull_function(tls->tls_session, kres_gnutls_pull);
gnutls_transport_set_vec_push_function(tls->tls_session, kres_gnutls_vec_push);
gnutls_transport_set_ptr(tls->tls_session, tls);
if (the_network->tls_session_ticket_ctx) {
tls_session_ticket_enable(the_network->tls_session_ticket_ctx,
tls->tls_session);
}
const gnutls_datum_t *alpn = &tls_grp_alpn[session->proto];
if (alpn->size) { /* ALPN is a non-empty string */
flags = 0;
#if GNUTLS_VERSION_NUMBER >= 0x030500
/* Mandatory ALPN means the protocol must match if and
* only if ALPN extension is used by the client. */
flags |= GNUTLS_ALPN_MANDATORY;
#endif
ret = gnutls_alpn_set_protocols(tls->tls_session, alpn, 1, flags);
if (ret != GNUTLS_E_SUCCESS) {
kr_log_error(TLS, "gnutls_alpn_set_protocols(): %s (%d)\n", gnutls_strerror_name(ret), ret);
pl_tls_sess_data_deinit(tls);
return ret;
}
}
return kr_ok();
}
static int pl_tls_sess_client_init(struct session2 *session,
struct pl_tls_sess_data *tls,
tls_client_param_t *param)
{
unsigned int flags = GNUTLS_CLIENT | GNUTLS_NONBLOCK
#ifdef GNUTLS_ENABLE_FALSE_START
| GNUTLS_ENABLE_FALSE_START
#endif
;
#if GNUTLS_VERSION_NUMBER >= 0x030705
if (gnutls_check_version("3.7.5"))
flags |= GNUTLS_NO_TICKETS_TLS12;
#endif
int ret = gnutls_init(&tls->tls_session, flags);
if (ret != GNUTLS_E_SUCCESS) {
pl_tls_sess_data_deinit(tls);
return ret;
}
ret = kres_gnutls_set_priority(tls->tls_session);
if (ret != GNUTLS_E_SUCCESS) {
pl_tls_sess_data_deinit(tls);
return ret;
}
/* Must take a reference on parameters as the credentials are owned by it
* and must not be freed while the session is active. */
++(param->refs);
tls->client_params = param;
ret = gnutls_credentials_set(tls->tls_session, GNUTLS_CRD_CERTIFICATE,
param->credentials);
if (ret == GNUTLS_E_SUCCESS && param->hostname) {
ret = gnutls_server_name_set(tls->tls_session, GNUTLS_NAME_DNS,
param->hostname, strlen(param->hostname));
kr_log_debug(TLSCLIENT, "set hostname, ret = %d\n", ret);
} else if (!param->hostname) {
kr_log_debug(TLSCLIENT, "no hostname\n");
}
if (ret != GNUTLS_E_SUCCESS) {
pl_tls_sess_data_deinit(tls);
return ret;
}
tls->client_side = true;
wire_buf_init(&tls->unwrap_buf, UNWRAP_BUF_SIZE);
gnutls_transport_set_pull_function(tls->tls_session, kres_gnutls_pull);
gnutls_transport_set_vec_push_function(tls->tls_session, kres_gnutls_vec_push);
gnutls_transport_set_ptr(tls->tls_session, tls);
return kr_ok();
}
static int pl_tls_sess_init(struct session2 *session,
void *sess_data,
void *param)
{
struct pl_tls_sess_data *tls = sess_data;
session->secure = true;
queue_init(tls->unwrap_queue);
queue_init(tls->wrap_queue);
if (session->outgoing)
return pl_tls_sess_client_init(session, tls, param);
else
return pl_tls_sess_server_init(session, tls);
}
static int pl_tls_sess_deinit(struct session2 *session,
void *sess_data)
{
return pl_tls_sess_data_deinit(sess_data);
}
static enum protolayer_iter_cb_result pl_tls_unwrap(void *sess_data, void *iter_data,
struct protolayer_iter_ctx *ctx)
{
int brstatus = kr_ok();
struct pl_tls_sess_data *tls = sess_data;
struct session2 *s = ctx->session;
queue_push(tls->unwrap_queue, ctx);
/* Ensure TLS handshake is performed before receiving data.
* See https://www.gnutls.org/manual/html_node/TLS-handshake.html */
while (tls->handshake_state <= TLS_HS_IN_PROGRESS) {
int err = tls_handshake(tls, s);
if (err == kr_error(EAGAIN)) {
return protolayer_async(); /* Wait for more data */
} else if (err != kr_ok()) {
brstatus = err;
goto exit_break;
}
}
/* See https://gnutls.org/manual/html_node/Data-transfer-and-termination.html#Data-transfer-and-termination */
while (true) {
ssize_t count = gnutls_record_recv(tls->tls_session,
wire_buf_free_space(&tls->unwrap_buf),
wire_buf_free_space_length(&tls->unwrap_buf));
if (count == GNUTLS_E_AGAIN) {
if (!protolayer_queue_has_payload(&tls->unwrap_queue)) {
/* See https://www.gnutls.org/manual/html_node/Asynchronous-operation.html */
break;
}
continue;
} else if (count == GNUTLS_E_INTERRUPTED) {
continue;
} else if (count == GNUTLS_E_REHANDSHAKE) {
/* See https://www.gnutls.org/manual/html_node/Re_002dauthentication.html */
struct sockaddr *peer = session2_get_peer(s);
VERBOSE_MSG(tls->client_side, "TLS rehandshake with %s has started\n",
kr_straddr(peer));
tls->handshake_state = TLS_HS_IN_PROGRESS;
int err = kr_ok();
while (tls->handshake_state <= TLS_HS_IN_PROGRESS) {
err = tls_handshake(tls, s);
if (err == kr_error(EAGAIN)) {
break;
} else if (err != kr_ok()) {
brstatus = err;
goto exit_break;
}
}
if (err == kr_error(EAGAIN)) {
/* pull function is out of data */
break;
}
/* There can still be data available, check again. */
continue;
} else if (count < 0) {
VERBOSE_MSG(tls->client_side, "gnutls_record_recv failed: %s (%zd)\n",
gnutls_strerror_name(count), count);
brstatus = kr_error(EIO);
goto exit_break;
} else if (count == 0) {
break;
}
VERBOSE_MSG(tls->client_side, "received %zd data\n", count);
wire_buf_consume(&tls->unwrap_buf, count);
if (wire_buf_free_space_length(&tls->unwrap_buf) == 0 && protolayer_queue_has_payload(&tls->unwrap_queue) > 0) {
/* wire buffer is full but not all data was consumed */
brstatus = kr_error(ENOSPC);
goto exit_break;
}
if (kr_fails_assert(queue_len(tls->unwrap_queue) == 1)) {
brstatus = kr_error(EINVAL);
goto exit_break;
}
struct protolayer_iter_ctx *ctx_head = queue_head(tls->unwrap_queue);
if (kr_fails_assert(ctx == ctx_head)) {
protolayer_break(ctx, kr_error(EINVAL));
ctx = ctx_head;
}
}
/* Here all data must be consumed. */
if (protolayer_queue_has_payload(&tls->unwrap_queue)) {
/* Something went wrong, better return an error.
 * Most probably gnutls_record_recv() did not consume all
 * available network data via kres_gnutls_pull().
 * TODO: assess the need for additional buffering. */
brstatus = kr_error(ENOSPC);
goto exit_break;
}
struct protolayer_iter_ctx *ctx_head = queue_head(tls->unwrap_queue);
if (!kr_fails_assert(ctx == ctx_head))
queue_pop(tls->unwrap_queue);
ctx->payload = protolayer_payload_wire_buf(&tls->unwrap_buf, false);
return protolayer_continue(ctx);
exit_break:
ctx_head = queue_head(tls->unwrap_queue);
if (!kr_fails_assert(ctx == ctx_head))
queue_pop(tls->unwrap_queue);
return protolayer_break(ctx, brstatus);
}
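/* For illustration only (not part of the original source): the
 * receive-into-wire-buffer step from the unwrap loop above, isolated.
 * Decrypted bytes are written into the buffer's free space and then
 * wire_buf_consume() marks that many bytes as valid payload. */
static ssize_t tls_recv_into_wire_buf_example(gnutls_session_t s,
		struct wire_buf *wb)
{
	ssize_t count = gnutls_record_recv(s, wire_buf_free_space(wb),
			wire_buf_free_space_length(wb));
	if (count > 0)
		wire_buf_consume(wb, count);
	return count; /* 0 on clean EOF, negative GnuTLS code on error */
}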
static ssize_t pl_tls_submit(gnutls_session_t tls_session,
struct protolayer_payload payload)
{
if (payload.type == PROTOLAYER_PAYLOAD_WIRE_BUF)
payload = protolayer_payload_as_buffer(&payload);
// TODO: the handling of a positive gnutls_record_send() return value is
// weird/confusing, but mismatches seem to be caught later when checking
// gnutls_record_uncork()
if (payload.type == PROTOLAYER_PAYLOAD_BUFFER) {
ssize_t count = gnutls_record_send(tls_session,
payload.buffer.buf, payload.buffer.len);
if (count < 0)
return count;
return payload.buffer.len;
} else if (payload.type == PROTOLAYER_PAYLOAD_IOVEC) {
ssize_t total_submitted = 0;
for (int i = 0; i < payload.iovec.cnt; i++) {
struct iovec iov = payload.iovec.iov[i];
ssize_t count = gnutls_record_send(tls_session,
iov.iov_base, iov.iov_len);
if (count < 0)
return count;
total_submitted += iov.iov_len;
}
return total_submitted;
}
kr_assert(false && "Invalid payload");
return kr_error(EINVAL);
}
static enum protolayer_iter_cb_result pl_tls_wrap(
void *sess_data, void *iter_data,
struct protolayer_iter_ctx *ctx)
{
struct pl_tls_sess_data *tls = sess_data;
gnutls_session_t tls_session = tls->tls_session;
gnutls_record_cork(tls_session);
ssize_t submitted = pl_tls_submit(tls_session, ctx->payload);
if (submitted < 0) {
VERBOSE_MSG(tls->client_side, "pl_tls_submit failed: %s (%zd)\n",
gnutls_strerror_name(submitted), submitted);
return protolayer_break(ctx, submitted);
}
queue_push(tls->wrap_queue, ctx);
int ret = gnutls_record_uncork(tls_session, GNUTLS_RECORD_WAIT);
if (ret < 0) {
if (!gnutls_error_is_fatal(ret)) {
queue_pop(tls->wrap_queue);
return protolayer_break(ctx, kr_error(EAGAIN));
} else {
queue_pop(tls->wrap_queue);
VERBOSE_MSG(tls->client_side, "gnutls_record_uncork failed: %s (%d)\n",
gnutls_strerror_name(ret), ret);
return protolayer_break(ctx, kr_error(EIO));
}
}
if (ret != submitted) {
kr_log_error(TLS, "gnutls_record_uncork didn't send all data (%d of %zd)\n", ret, submitted);
return protolayer_break(ctx, kr_error(EIO));
}
return protolayer_async();
}
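/* For illustration only (not part of the original source): the cork/uncork
 * pattern from pl_tls_wrap() reduced to plain GnuTLS calls. While corked,
 * gnutls_record_send() only buffers plaintext; gnutls_record_uncork() with
 * GNUTLS_RECORD_WAIT then flushes it as TLS records and returns the number
 * of bytes uncorked (or a negative GnuTLS error code). */
static int corked_send_example(gnutls_session_t s, const uint8_t *buf, size_t len)
{
	gnutls_record_cork(s);
	ssize_t submitted = gnutls_record_send(s, buf, len); /* buffered only */
	if (submitted < 0)
		return kr_error(EIO);
	int ret = gnutls_record_uncork(s, GNUTLS_RECORD_WAIT); /* flush */
	return ret < 0 ? kr_error(EIO) : kr_ok();
}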
static enum protolayer_event_cb_result pl_tls_client_connect_start(
struct pl_tls_sess_data *tls, struct session2 *session)
{
if (tls->handshake_state != TLS_HS_NOT_STARTED)
return PROTOLAYER_EVENT_CONSUME;
if (kr_fails_assert(session->outgoing))
return PROTOLAYER_EVENT_CONSUME;
gnutls_session_set_ptr(tls->tls_session, tls);
gnutls_handshake_set_timeout(tls->tls_session, the_network->tcp.tls_handshake_timeout);
gnutls_transport_set_pull_timeout_function(tls->tls_session, tls_pull_timeout_func);
tls->handshake_state = TLS_HS_IN_PROGRESS;
tls_client_param_t *tls_params = tls->client_params;
if (tls_params->session_data.data != NULL) {
gnutls_session_set_data(tls->tls_session, tls_params->session_data.data,
tls_params->session_data.size);
}
/* See https://www.gnutls.org/manual/html_node/Asynchronous-operation.html */
while (tls->handshake_state <= TLS_HS_IN_PROGRESS) {
int ret = tls_handshake(tls, session);
if (ret != kr_ok()) {
if (ret == kr_error(EAGAIN)) {
session2_timer_stop(session);
session2_timer_start(session,
PROTOLAYER_EVENT_GENERAL_TIMEOUT,
MAX_TCP_INACTIVITY, MAX_TCP_INACTIVITY);
}
return PROTOLAYER_EVENT_CONSUME;
}
}
return PROTOLAYER_EVENT_CONSUME;
}
static enum protolayer_event_cb_result pl_tls_event_unwrap(
enum protolayer_event_type event, void **baton,
struct session2 *s, void *sess_data)
{
struct pl_tls_sess_data *tls = sess_data;
if (event == PROTOLAYER_EVENT_CLOSE) {
tls_close(tls, s, true); /* WITH gnutls_bye */
return PROTOLAYER_EVENT_PROPAGATE;
}
if (event == PROTOLAYER_EVENT_FORCE_CLOSE) {
tls_close(tls, s, false); /* WITHOUT gnutls_bye */
return PROTOLAYER_EVENT_PROPAGATE;
}
if (event == PROTOLAYER_EVENT_EOF) {
// TCP half-closed state not allowed
session2_force_close(s);
return PROTOLAYER_EVENT_CONSUME;
}
if (tls->client_side) {
if (event == PROTOLAYER_EVENT_CONNECT)
return pl_tls_client_connect_start(tls, s);
} else {
if (event == PROTOLAYER_EVENT_CONNECT) {
/* TLS sends its own _CONNECT event when the handshake
* is finished. */
return PROTOLAYER_EVENT_CONSUME;
}
}
return PROTOLAYER_EVENT_PROPAGATE;
}
static enum protolayer_event_cb_result pl_tls_event_wrap(
enum protolayer_event_type event, void **baton,
struct session2 *session, void *sess_data)
{
if (event == PROTOLAYER_EVENT_STATS_SEND_ERR) {
the_worker->stats.err_tls += 1;
return PROTOLAYER_EVENT_CONSUME;
} else if (event == PROTOLAYER_EVENT_STATS_QRY_OUT) {
the_worker->stats.tls += 1;
return PROTOLAYER_EVENT_CONSUME;
}
return PROTOLAYER_EVENT_PROPAGATE;
}
static void pl_tls_request_init(struct session2 *session,
struct kr_request *req,
void *sess_data)
{
req->qsource.comm_flags.tls = true;
}
__attribute__((constructor))
static void tls_protolayers_init(void)
{
protolayer_globals[PROTOLAYER_TYPE_TLS] = (struct protolayer_globals){
.sess_size = sizeof(struct pl_tls_sess_data),
.sess_deinit = pl_tls_sess_deinit,
.wire_buf_overhead = TLS_CHUNK_SIZE,
.sess_init = pl_tls_sess_init,
.unwrap = pl_tls_unwrap,
.wrap = pl_tls_wrap,
.event_unwrap = pl_tls_event_unwrap,
.event_wrap = pl_tls_event_wrap,
.request_init = pl_tls_request_init
};
}
#undef VERBOSE_MSG
/* Copyright (C) 2016 American Civil Liberties Union (ACLU)
* Copyright (C) CZ.NIC, z.s.p.o
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#pragma once
#include <uv.h>
#include <gnutls/gnutls.h>
#include <libknot/packet/pkt.h>
struct tls_ctx_t;
struct tls_credentials;
#include "lib/defines.h"
#include "lib/generic/array.h"
#include "lib/generic/trie.h"
#include "lib/utils.h"
#define MAX_TLS_PADDING KR_EDNS_PAYLOAD
#define TLS_MAX_UNCORK_RETRIES 100
/* RFC 5246, section 7.3 - Handshake protocol overview
 * https://tools.ietf.org/html/rfc5246#page-33
 * Message flow for a full handshake (only mandatory messages):
 * ClientHello -------->
 * ServerHello
 * <-------- ServerHelloDone
 * ClientKeyExchange
 * Finished -------->
 * <-------- Finished
 *
 * See also https://blog.cloudflare.com/keyless-ssl-the-nitty-gritty-technical-details/
 * So a full handshake takes 2 RTTs. As we use session tickets, there are
 * additional messages; add one more RTT as a margin.
 */
#define TLS_MAX_HANDSHAKE_TIME (KR_CONN_RTT_MAX * (uint64_t)3)
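/* Worked example (illustrative, assuming KR_CONN_RTT_MAX is 2000 ms as in
 * lib/defines.h): 3 * 2000 ms = 6000 ms, i.e. a handshake that has not
 * finished within 6 seconds is timed out. */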
/** Transport session (opaque). */
struct session2;
struct tls_ctx;
struct tls_client_ctx;
struct tls_credentials {
int count;
char *tls_cert;
char *tls_key;
gnutls_certificate_credentials_t credentials;
time_t valid_until;
char *ephemeral_servicename;
};
/*! Toggle verbose logging from TLS context. */
void tls_setup_logging(bool verbose);
/*! Create an empty TLS context in query context */
struct tls_ctx_t* tls_new(struct worker_ctx *worker);
#define TLS_SHA256_RAW_LEN 32 /* gnutls_hash_get_len(GNUTLS_DIG_SHA256) */
/** Required buffer length for pin_sha256, including the zero terminator. */
#define TLS_SHA256_BASE64_BUFLEN (((TLS_SHA256_RAW_LEN * 8 + 4) / 6) + 3 + 1)
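/* Worked derivation (illustrative): base64 of 32 raw bytes needs
 * ceil(32/3) * 4 = 44 characters. The macro computes (32*8 + 4)/6 = 43
 * rounded-up sextets, then adds 3 bytes of rounding/padding slack and 1 for
 * the NUL terminator: 43 + 3 + 1 = 47 >= 44 + 1, so the buffer always
 * suffices. */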
/** TLS authentication parameters for a single address-port pair. */
typedef struct {
uint32_t refs; /**< Reference count; consider TLS sessions in progress. */
bool insecure; /**< Use no authentication. */
const char *hostname; /**< Server name for SNI and certificate check, lowercased. */
array_t(const char *) ca_files; /**< Paths to certificate files; not really used. */
array_t(const uint8_t *) pins; /**< Certificate pins as raw unterminated strings.*/
gnutls_certificate_credentials_t credentials; /**< CA creds. in gnutls format. */
gnutls_datum_t session_data; /**< Session-resumption data gets stored here. */
} tls_client_param_t;
/** Holds configuration for TLS authentication for all potential servers.
* Special case: NULL pointer also means empty. */
typedef trie_t tls_client_params_t;
/** Get a pointer-to-pointer to TLS auth params.
* If it didn't exist, it returns NULL (if !do_insert) or pointer to NULL. */
tls_client_param_t **tls_client_param_getptr(tls_client_params_t **params,
const struct sockaddr *addr, bool do_insert);
/** Get a pointer to TLS auth params or NULL. */
static inline tls_client_param_t *
tls_client_param_get(tls_client_params_t *params, const struct sockaddr *addr)
{
tls_client_param_t **pe = tls_client_param_getptr(&params, addr, false);
return pe ? *pe : NULL;
}
/** Allocate and initialize the structure (with ->ref = 1). */
tls_client_param_t * tls_client_param_new(void);
/** Reference-counted free(); any inside data is freed alongside. */
void tls_client_param_unref(tls_client_param_t *entry);
int tls_client_param_remove(tls_client_params_t *params, const struct sockaddr *addr);
/** Free TLS authentication parameters. */
void tls_client_params_free(tls_client_params_t *params);
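/* For illustration only (not part of this header): a minimal sketch of the
 * intended call pattern for the declarations above. `addr` is assumed to be
 * a prepared sockaddr of the target server. */
static inline int tls_client_param_example(tls_client_params_t **params,
		const struct sockaddr *addr)
{
	/* Look up the per-address slot, inserting an empty one if missing. */
	tls_client_param_t **slot = tls_client_param_getptr(params, addr, true);
	if (!slot)
		return kr_error(ENOMEM);
	if (!*slot) {
		*slot = tls_client_param_new(); /* comes with refs == 1 */
		if (!*slot)
			return kr_error(ENOMEM);
	}
	(*slot)->insecure = true; /* e.g. opt out of authentication */
	return kr_ok();
}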
/*! Close a TLS context */
void tls_free(struct tls_ctx_t* tls);
/*! Set TLS certificate and key from files. */
int tls_certificate_set(const char *tls_cert, const char *tls_key);
/*! Push new data to TLS context for sending */
int tls_push(struct qr_task *task, uv_handle_t *handle, knot_pkt_t *pkt);
/*! Release TLS credentials for context (decrements refcount or frees). */
int tls_credentials_release(struct tls_credentials *tls_credentials);
/*! Unwrap incoming data from a TLS stream and pass them to TCP session. */
int tls_process(struct worker_ctx *worker, uv_stream_t *handle, const uint8_t *buf, ssize_t nread);
/*! Generate new ephemeral TLS credentials. */
struct tls_credentials * tls_get_ephemeral_credentials(void);
/* Session tickets, server side. Implementation in ./tls_session_ticket-srv.c */
/*! Borrow TLS credentials for context (increments refcount). */
struct tls_credentials *tls_credentials_reserve(struct tls_credentials *tls_credentials);
/*! Opaque struct used by tls_session_ticket_* functions. */
struct tls_session_ticket_ctx;
/*! Suggested maximum reasonable secret length. */
#define TLS_SESSION_TICKET_SECRET_MAX_LEN 1024
/*! Create a session ticket context and initialize it (secret gets copied inside).
*
* Passing zero-length secret implies using a random key, i.e. not synchronized
* between multiple instances.
*
* Beware that knowledge of the secret (if nonempty) breaks forward secrecy,
* so you should rotate the secret regularly and securely erase all past secrets.
* With TLS < 1.3 it's probably too risky to set nonempty secret.
*/
struct tls_session_ticket_ctx * tls_session_ticket_ctx_create(
uv_loop_t *loop, const char *secret, size_t secret_len);
/*! Try to enable session tickets for a server session. */
void tls_session_ticket_enable(struct tls_session_ticket_ctx *ctx, gnutls_session_t session);
/*! Free all resources of the session ticket context. NULL is accepted as well. */
void tls_session_ticket_ctx_destroy(struct tls_session_ticket_ctx *ctx);
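/* For illustration only (not part of this header): the expected lifecycle of
 * a session ticket context. `loop` is assumed to be the daemon's uv loop and
 * `session` an accepted server-side TLS session. */
static inline void tls_session_ticket_example(uv_loop_t *loop,
		gnutls_session_t session)
{
	/* A zero-length secret means a random per-epoch key (not synchronized
	 * across instances). */
	struct tls_session_ticket_ctx *stc =
		tls_session_ticket_ctx_create(loop, NULL, 0);
	if (stc)
		tls_session_ticket_enable(stc, session); /* call per session */
	tls_session_ticket_ctx_destroy(stc); /* accepts NULL */
}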
/*! Free TLS credentials, must not be called if it holds a positive refcount. */
void tls_credentials_free(struct tls_credentials *tls_credentials);
/*! Log DNS-over-TLS OOB key-pin form of current credentials:
* https://tools.ietf.org/html/rfc7858#appendix-A */
void tls_credentials_log_pins(struct tls_credentials *tls_credentials);
/*
* Copyright (C) 2016 American Civil Liberties Union (ACLU)
* Copyright (C) CZ.NIC, z.s.p.o.
*
* Initial Author: Daniel Kahn Gillmor <dkg@fifthhorseman.net>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include <sys/file.h>
#include <unistd.h>
#include <gnutls/gnutls.h>
#include <gnutls/x509.h>
#include <gnutls/crypto.h>
#include "daemon/engine.h"
#include "daemon/tls.h"
#define EPHEMERAL_PRIVKEY_FILENAME "ephemeral_key.pem"
#define INVALID_HOSTNAME "dns-over-tls.invalid"
#define EPHEMERAL_CERT_EXPIRATION_SECONDS ((time_t)60*60*24*90)
/* This is an attempt to grab an exclusive, advisory, non-blocking
* lock based on a filename. At the moment it's POSIX-only, but it
* should be abstract enough of an interface to make an implementation
* for non-posix systems if anyone cares. */
typedef int lock_t;
static bool lock_is_invalid(lock_t lock)
{
return lock == -1;
}
/* try to take a non-blocking, exclusive lock on a given filename */
static lock_t lock_filename(const char *fname)
{
lock_t lockfd = open(fname, O_RDONLY|O_CREAT, 0400);
if (lockfd == -1)
return lockfd;
/* this should be a non-blocking lock */
if (flock(lockfd, LOCK_EX | LOCK_NB) != 0) {
close(lockfd);
return -1;
}
return lockfd; /* for cleanup later */
}
static void lock_unlock(lock_t *lock, const char *fname)
{
if (lock && !lock_is_invalid(*lock)) {
flock(*lock, LOCK_UN);
close(*lock);
*lock = -1;
unlink(fname); /* ignore errors */
}
}
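/* For illustration only (not in the original): how the helpers above compose.
 * The lock is advisory, so every cooperating process must use the same
 * lockfile name; "example.lock" is a hypothetical path. */
static void lock_example(void)
{
	lock_t lock = lock_filename("example.lock");
	if (lock_is_invalid(lock))
		return; /* someone else holds the lock; give up */
	/* ... exclusive work, e.g. generating a shared private key ... */
	lock_unlock(&lock, "example.lock"); /* unlocks, closes and unlinks */
}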
static gnutls_x509_privkey_t get_ephemeral_privkey (void)
{
gnutls_x509_privkey_t privkey = NULL;
int err;
gnutls_datum_t data = { .data = NULL, .size = 0 };
lock_t lock;
int datafd = -1;
/* Take a lock to ensure that two daemons started concurrently
* with a shared cache don't both create the same privkey: */
lock = lock_filename(EPHEMERAL_PRIVKEY_FILENAME ".lock");
if (lock_is_invalid(lock)) {
kr_log_error(TLS, "unable to lock lockfile " EPHEMERAL_PRIVKEY_FILENAME ".lock\n");
goto done;
}
if ((err = gnutls_x509_privkey_init (&privkey)) < 0) {
kr_log_error(TLS, "gnutls_x509_privkey_init() failed: %d (%s)\n",
err, gnutls_strerror_name(err));
goto done;
}
/* Read from the cache file. (We assume that we've already chdir'ed,
 * so we're just looking for the file in the cache dir.) */
datafd = open(EPHEMERAL_PRIVKEY_FILENAME, O_RDONLY);
if (datafd != -1) {
struct stat stat;
ssize_t bytes_read;
if (fstat(datafd, &stat)) {
kr_log_error(TLS, "unable to stat ephemeral private key " EPHEMERAL_PRIVKEY_FILENAME "\n");
goto bad_data;
}
data.data = gnutls_malloc(stat.st_size);
if (data.data == NULL) {
kr_log_error(TLS, "unable to allocate memory for reading ephemeral private key\n");
goto bad_data;
}
data.size = stat.st_size;
bytes_read = read(datafd, data.data, stat.st_size);
if (bytes_read < 0 || bytes_read != stat.st_size) {
kr_log_error(TLS, "unable to read ephemeral private key\n");
goto bad_data;
}
if ((err = gnutls_x509_privkey_import (privkey, &data, GNUTLS_X509_FMT_PEM)) < 0) {
kr_log_error(TLS, "gnutls_x509_privkey_import() failed: %d (%s)\n",
err, gnutls_strerror_name(err));
/* goto bad_data; */
bad_data:
close(datafd);
datafd = -1;
}
if (data.data != NULL) {
gnutls_free(data.data);
data.data = NULL;
}
}
if (datafd == -1) {
/* if loading failed, then generate ... */
#if GNUTLS_VERSION_NUMBER >= 0x030500
if ((err = gnutls_x509_privkey_generate(privkey, GNUTLS_PK_ECDSA, GNUTLS_CURVE_TO_BITS(GNUTLS_ECC_CURVE_SECP256R1), 0)) < 0) {
#else
if ((err = gnutls_x509_privkey_generate(privkey, GNUTLS_PK_RSA, gnutls_sec_param_to_pk_bits(GNUTLS_PK_RSA, GNUTLS_SEC_PARAM_MEDIUM), 0)) < 0) {
#endif
kr_log_error(TLS, "gnutls_x509_privkey_init() failed: %d (%s)\n",
err, gnutls_strerror_name(err));
gnutls_x509_privkey_deinit(privkey);
goto done;
}
/* ... and save */
kr_log_info(TLS, "Stashing ephemeral private key in " EPHEMERAL_PRIVKEY_FILENAME "\n");
if ((err = gnutls_x509_privkey_export2(privkey, GNUTLS_X509_FMT_PEM, &data)) < 0) {
kr_log_error(TLS, "gnutls_x509_privkey_export2() failed: %d (%s), not storing\n",
err, gnutls_strerror_name(err));
} else {
datafd = open(EPHEMERAL_PRIVKEY_FILENAME, O_WRONLY|O_CREAT, 0600);
if (datafd == -1) {
kr_log_error(TLS, "failed to open " EPHEMERAL_PRIVKEY_FILENAME " to store the ephemeral key\n");
} else {
ssize_t bytes_written;
bytes_written = write(datafd, data.data, data.size);
if (bytes_written != data.size)
kr_log_error(TLS, "failed to write %d octets to "
EPHEMERAL_PRIVKEY_FILENAME
" (%zd written)\n",
data.size, bytes_written);
}
}
}
done:
lock_unlock(&lock, EPHEMERAL_PRIVKEY_FILENAME ".lock");
if (datafd != -1) {
close(datafd);
}
if (data.data != NULL) {
gnutls_free(data.data);
}
return privkey;
}
static gnutls_x509_crt_t get_ephemeral_cert(gnutls_x509_privkey_t privkey, const char *servicename, time_t invalid_before, time_t valid_until)
{
gnutls_x509_crt_t cert = NULL;
int err;
/* need a random buffer of bytes */
uint8_t serial[16];
gnutls_rnd(GNUTLS_RND_NONCE, serial, sizeof(serial));
/* clear the left-most bit to avoid signedness confusion: */
serial[0] &= 0x7f;
size_t namelen = strlen(servicename);
#define gtx(fn, ...) \
if ((err = fn ( __VA_ARGS__ )) != GNUTLS_E_SUCCESS) { \
kr_log_error(TLS, #fn "() failed: %d (%s)\n", \
err, gnutls_strerror_name(err)); \
goto bad; }
gtx(gnutls_x509_crt_init, &cert);
gtx(gnutls_x509_crt_set_activation_time, cert, invalid_before);
gtx(gnutls_x509_crt_set_ca_status, cert, 0);
gtx(gnutls_x509_crt_set_expiration_time, cert, valid_until);
gtx(gnutls_x509_crt_set_key, cert, privkey);
gtx(gnutls_x509_crt_set_key_purpose_oid, cert, GNUTLS_KP_TLS_WWW_CLIENT, 0);
gtx(gnutls_x509_crt_set_key_purpose_oid, cert, GNUTLS_KP_TLS_WWW_SERVER, 0);
gtx(gnutls_x509_crt_set_key_usage, cert, GNUTLS_KEY_DIGITAL_SIGNATURE);
gtx(gnutls_x509_crt_set_serial, cert, serial, sizeof(serial));
gtx(gnutls_x509_crt_set_subject_alt_name, cert, GNUTLS_SAN_DNSNAME, servicename, namelen, GNUTLS_FSAN_SET);
gtx(gnutls_x509_crt_set_dn_by_oid,cert, GNUTLS_OID_X520_COMMON_NAME, 0, servicename, namelen);
gtx(gnutls_x509_crt_set_version, cert, 3);
gtx(gnutls_x509_crt_sign2,cert, cert, privkey, GNUTLS_DIG_SHA256, 0); /* self-sign, since it doesn't look like we can just stub-sign */
#undef gtx
return cert;
bad:
gnutls_x509_crt_deinit(cert);
return NULL;
}
/*! Generate new ephemeral TLS credentials. */
struct tls_credentials * tls_get_ephemeral_credentials(void)
{
struct tls_credentials *creds = NULL;
gnutls_x509_privkey_t privkey = NULL;
gnutls_x509_crt_t cert = NULL;
int err;
time_t now = time(NULL);
creds = calloc(1, sizeof(*creds));
if (!creds) {
kr_log_error(TLS, "failed to allocate memory for ephemeral credentials\n");
return NULL;
}
if ((err = gnutls_certificate_allocate_credentials(&(creds->credentials))) < 0) {
kr_log_error(TLS, "failed to allocate memory for ephemeral credentials\n");
goto failure;
}
creds->valid_until = now + EPHEMERAL_CERT_EXPIRATION_SECONDS;
creds->ephemeral_servicename = strdup(engine_get_hostname());
if (creds->ephemeral_servicename == NULL) {
kr_log_error(TLS, "could not get server's hostname, using '" INVALID_HOSTNAME "' instead\n");
if ((creds->ephemeral_servicename = strdup(INVALID_HOSTNAME)) == NULL) {
kr_log_error(TLS, "failed to allocate memory for ephemeral credentials\n");
goto failure;
}
}
if ((privkey = get_ephemeral_privkey()) == NULL) {
goto failure;
}
if ((cert = get_ephemeral_cert(privkey, creds->ephemeral_servicename, now - ((time_t)60 * 15), creds->valid_until)) == NULL) {
goto failure;
}
if ((err = gnutls_certificate_set_x509_key(creds->credentials, &cert, 1, privkey)) < 0) {
kr_log_error(TLS, "failed to set up ephemeral credentials\n");
goto failure;
}
gnutls_x509_privkey_deinit(privkey);
gnutls_x509_crt_deinit(cert);
return creds;
failure:
gnutls_x509_privkey_deinit(privkey);
gnutls_x509_crt_deinit(cert);
tls_credentials_free(creds);
return NULL;
}
/* Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
* SPDX-License-Identifier: GPL-3.0-or-later
*/
#include <inttypes.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <gnutls/gnutls.h>
#include <gnutls/crypto.h>
#include <uv.h>
#include "lib/utils.h"
/* Style: "local/static" identifiers are usually named tst_* */
/** The number of seconds between synchronized rotation of TLS session ticket key. */
#define TST_KEY_LIFETIME 4096
/** Value from gnutls: lib/ext/session_ticket.c
 * Beware: changing this would require changing the hashing implementation. */
#define SESSION_KEY_SIZE 64
/** Compile-time support for setting the secret. */
/* This is not secure with TLS <= 1.2, and TLS 1.3 with a secure configuration
 * is not available in GnuTLS yet. See https://gitlab.com/gnutls/gnutls/issues/477 */
#define TLS_SESSION_RESUMPTION_SYNC (GNUTLS_VERSION_NUMBER >= 0x030603)
#if TLS_SESSION_RESUMPTION_SYNC
#define TST_HASH GNUTLS_DIG_SHA3_512
#else
#define TST_HASH abort()
#endif
/** Fields are internal to tst_key_* functions. */
typedef struct tls_session_ticket_ctx {
uv_timer_t timer; /**< timer for rotation of the key */
unsigned char key[SESSION_KEY_SIZE]; /**< the key itself */
bool has_secret; /**< false -> key is random for each epoch */
uint16_t hash_len; /**< length of `hash_data` */
char hash_data[]; /**< data to hash to obtain `key`;
* it's `time_t epoch` and then the secret string */
} tst_ctx_t;
/** Check invariants, based on gnutls version. */
static bool tst_key_invariants(void)
{
static int result = 0; /*< cache for multiple invocations */
if (result) return result > 0;
bool ok = true;
#if TLS_SESSION_RESUMPTION_SYNC
/* SHA3-512 output size may never change, but let's check it anyway :-) */
ok = ok && gnutls_hash_get_len(TST_HASH) == SESSION_KEY_SIZE;
#endif
/* The ticket key size might change in a different gnutls version. */
gnutls_datum_t key = { 0, 0 };
ok = ok && gnutls_session_ticket_key_generate(&key) == 0
&& key.size == SESSION_KEY_SIZE;
free(key.data);
result = ok ? 1 : -1;
return ok;
}
/** Create the internal structures and copy the secret. Beware: secret must be kept secure. */
static tst_ctx_t * tst_key_create(const char *secret, size_t secret_len, uv_loop_t *loop)
{
const size_t hash_len = sizeof(time_t) + secret_len;
if (kr_fails_assert(!secret_len || (secret && hash_len >= secret_len && hash_len <= UINT16_MAX))) {
return NULL;
/* reasonable secret_len is best enforced in config API */
}
if (kr_fails_assert(tst_key_invariants()))
return NULL;
#if !TLS_SESSION_RESUMPTION_SYNC
if (secret_len) {
kr_log_error(TLS, "session ticket: secrets were not enabled at compile-time (your GnuTLS version is not supported)\n");
return NULL; /* ENOTSUP */
}
#endif
tst_ctx_t *ctx = malloc(sizeof(*ctx) + hash_len); /* can be slightly longer */
if (!ctx) return NULL;
ctx->has_secret = secret_len > 0;
ctx->hash_len = hash_len;
if (secret_len) {
memcpy(ctx->hash_data + sizeof(time_t), secret, secret_len);
}
if (uv_timer_init(loop, &ctx->timer) != 0) {
free(ctx);
return NULL;
}
ctx->timer.data = ctx;
return ctx;
}
/** Random variant of secret rotation: generate into key_tmp and copy. */
static int tst_key_get_random(tst_ctx_t *ctx)
{
gnutls_datum_t key_tmp = { NULL, 0 };
int err = gnutls_session_ticket_key_generate(&key_tmp);
if (err) return kr_error(err);
if (kr_fails_assert(key_tmp.size == SESSION_KEY_SIZE))
return kr_error(EFAULT);
memcpy(ctx->key, key_tmp.data, SESSION_KEY_SIZE);
gnutls_memset(key_tmp.data, 0, SESSION_KEY_SIZE);
free(key_tmp.data);
return kr_ok();
}
/** Recompute the session ticket key, if epoch has changed or forced. */
static int tst_key_update(tst_ctx_t *ctx, time_t epoch, bool force_update)
{
if (kr_fails_assert(ctx && ctx->hash_len >= sizeof(epoch)))
return kr_error(EINVAL);
/* documented limitation: time_t and endianness must match
* on instances sharing a secret */
if (!force_update && memcmp(ctx->hash_data, &epoch, sizeof(epoch)) == 0) {
return kr_ok(); /* we are up to date */
}
memcpy(ctx->hash_data, &epoch, sizeof(epoch));
if (!ctx->has_secret) {
return tst_key_get_random(ctx);
}
/* Otherwise, deterministic variant of secret rotation, if supported. */
#if !TLS_SESSION_RESUMPTION_SYNC
kr_assert(!ENOTSUP);
return kr_error(ENOTSUP);
#else
int err = gnutls_hash_fast(TST_HASH, ctx->hash_data,
ctx->hash_len, ctx->key);
return err == 0 ? kr_ok() : kr_error(err);
#endif
}
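/* For illustration only (not in the original): the deterministic branch of
 * tst_key_update() as a standalone computation, assuming a GnuTLS build with
 * SHA3 support. The ticket key is SHA3-512(epoch || secret), so instances
 * sharing the secret (and time_t width/endianness) derive identical keys for
 * the same epoch. */
static int tst_key_derive_example(time_t epoch, const char *secret,
		size_t secret_len, unsigned char key[SESSION_KEY_SIZE])
{
	/* 1024 mirrors TLS_SESSION_TICKET_SECRET_MAX_LEN from tls.h. */
	char buf[sizeof(time_t) + 1024];
	if (secret_len > 1024)
		return kr_error(EINVAL);
	memcpy(buf, &epoch, sizeof(epoch));
	memcpy(buf + sizeof(epoch), secret, secret_len);
	int err = gnutls_hash_fast(GNUTLS_DIG_SHA3_512, buf,
			sizeof(epoch) + secret_len, key);
	return err == 0 ? kr_ok() : kr_error(err);
}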
/** Free all resources of the key (securely). */
static void tst_key_destroy(uv_handle_t *timer)
{
if (kr_fails_assert(timer))
return;
tst_ctx_t *ctx = timer->data;
if (kr_fails_assert(ctx))
return;
gnutls_memset(ctx, 0, offsetof(tst_ctx_t, hash_data) + ctx->hash_len);
free(ctx);
}
static void tst_key_check(uv_timer_t *timer, bool force_update);
static void tst_timer_callback(uv_timer_t *timer)
{
tst_key_check(timer, false);
}
/** Update the ST key if needed and reschedule itself via the timer. */
static void tst_key_check(uv_timer_t *timer, bool force_update)
{
tst_ctx_t *stst = (tst_ctx_t *)timer->data;
/* Compute the current epoch. */
struct timeval now;
if (gettimeofday(&now, NULL)) {
kr_log_error(TLS, "session ticket: gettimeofday failed, %s\n",
strerror(errno));
return;
}
uv_update_time(timer->loop); /* to have sync. between real and mono time */
const time_t epoch = now.tv_sec / TST_KEY_LIFETIME;
/* Update the key; new sessions will fetch it from this location.
 * Old sessions hopefully can't be broken by that; the documentation
 * for gnutls_session_ticket_enable_server() doesn't say. */
int err = tst_key_update(stst, epoch, force_update);
if (err) {
kr_log_error(TLS, "session ticket: failed rotation, %s\n",
kr_strerror(err));
if (kr_fails_assert(err != kr_error(EINVAL)))
return;
}
/* Reschedule. */
const time_t tv_sec_next = (epoch + 1) * TST_KEY_LIFETIME;
const uint64_t ms_until_second = 1000 - (now.tv_usec + 501) / 1000;
const uint64_t remain_ms = (tv_sec_next - now.tv_sec - 1) * (uint64_t)1000
+ ms_until_second + 1;
/* ^ +1 because we don't want to wake up half a millisecond before the epoch! */
if (kr_fails_assert(remain_ms < ((uint64_t)TST_KEY_LIFETIME + 1 /*rounding tolerance*/) * 1000))
return;
kr_log_debug(TLS, "session ticket: epoch %"PRIu64
", scheduling rotation check in %"PRIu64" ms\n",
(uint64_t)epoch, remain_ms);
err = uv_timer_start(timer, &tst_timer_callback, remain_ms, 0);
if (kr_fails_assert(err == 0)) {
kr_log_error(TLS, "session ticket: failed to schedule, %s\n",
uv_strerror(err));
return;
}
}
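/* Worked example (illustrative numbers, not in the original): with
 * TST_KEY_LIFETIME = 4096, now.tv_sec = 10000 and now.tv_usec = 250000:
 * epoch = 10000 / 4096 = 2, so the next boundary is (2+1) * 4096 = 12288 s.
 * ms_until_second = 1000 - (250000+501)/1000 = 750, and remain_ms =
 * (12288 - 10000 - 1) * 1000 + 750 + 1 = 2287751 ms, so the timer fires
 * just after the epoch boundary, at roughly t = 12288.001 s. */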
/* Implementation for prototypes from ./tls.h */
void tls_session_ticket_enable(struct tls_session_ticket_ctx *ctx, gnutls_session_t session)
{
if (kr_fails_assert(ctx && session))
return;
const gnutls_datum_t gd = {
.size = SESSION_KEY_SIZE,
.data = ctx->key,
};
int err = gnutls_session_ticket_enable_server(session, &gd);
if (err) {
kr_log_error(TLS, "failed to enable session tickets: %s (%d)\n",
gnutls_strerror_name(err), err);
/* but continue without tickets */
}
}
tst_ctx_t * tls_session_ticket_ctx_create(uv_loop_t *loop, const char *secret,
size_t secret_len)
{
if (kr_fails_assert(loop && (!secret_len || secret)))
return NULL;
#if GNUTLS_VERSION_NUMBER < 0x030500
/* We would need different SESSION_KEY_SIZE; avoid an error. */
return NULL;
#endif
tst_ctx_t *ctx = tst_key_create(secret, secret_len, loop);
if (ctx) {
tst_key_check(&ctx->timer, true);
}
return ctx;
}
void tls_session_ticket_ctx_destroy(tst_ctx_t *ctx)
{
if (ctx == NULL) {
return;
}
uv_close((uv_handle_t *)&ctx->timer, &tst_key_destroy);
}