diff --git a/daemon/tls.c b/daemon/tls.c index 5424cfc59f6419b4f114b3f222c69978a97bd3b7..e034b0594e23d6a807bb268f1a1c664f9269b75e 100644 --- a/daemon/tls.c +++ b/daemon/tls.c @@ -208,6 +208,7 @@ int tls_push(struct qr_task *task, uv_handle_t *handle, knot_pkt_t *pkt) return kr_error(EINVAL); } + int ret = kr_ok(); struct session *session = handle->data; struct tls_common_ctx *tls_ctx = session->outgoing ? &session->tls_client_ctx->c : &session->tls_ctx->c; @@ -215,55 +216,28 @@ int tls_push(struct qr_task *task, uv_handle_t *handle, knot_pkt_t *pkt) assert (tls_ctx); assert (session->outgoing == tls_ctx->client_side); - const uint16_t pkt_size = htons(pkt->size); + uint8_t *tcp_len_start = (uint8_t *)pkt->wire - 2; const char *logstring = tls_ctx->client_side ? client_logstring : server_logstring; gnutls_session_t tls_session = tls_ctx->tls_session; + const size_t size_to_send = pkt->size + 2; tls_ctx->task = task; + knot_wire_write_u16(tcp_len_start, pkt->size); - assert(gnutls_record_check_corked(tls_session) == 0); - - gnutls_record_cork(tls_session); - ssize_t count = 0; - if ((count = gnutls_record_send(tls_session, &pkt_size, sizeof(pkt_size)) < 0) || - (count = gnutls_record_send(tls_session, pkt->wire, pkt->size) < 0)) { + /* gnutls_record_send() calls worker_gnutls_push() which actually sends data. + * It either sends all the data either sends nothing and returns error. */ + ssize_t count = gnutls_record_send(tls_session, tcp_len_start, size_to_send); + if (count < 0) { kr_log_error("[%s] gnutls_record_send failed: %s (%zd)\n", - logstring, gnutls_strerror_name(count), count); - return kr_error(EIO); - } - - ssize_t submitted = 0; - ssize_t retries = 0; - do { - count = gnutls_record_uncork(tls_session, 0); - if (count < 0) { - if (gnutls_error_is_fatal(count)) { - kr_log_error("[%s] gnutls_record_uncork failed: %s (%zd)\n", - logstring, gnutls_strerror_name(count), count); - return kr_error(EIO); - } - if (++retries > TLS_MAX_UNCORK_RETRIES) { - kr_log_error("[%s] gnutls_record_uncork: too many sequential non-fatal errors (%zd), last error is: %s (%zd)\n", - logstring, retries, gnutls_strerror_name(count), count); - return kr_error(EIO); - } - } else if (count != 0) { - submitted += count; - retries = 0; - } else if (gnutls_record_check_corked(tls_session) != 0) { - if (++retries > TLS_MAX_UNCORK_RETRIES) { - kr_log_error("[%s] gnutls_record_uncork: too many retries (%zd)\n", - logstring, retries); - return kr_error(EIO); - } - } else if (submitted != sizeof(pkt_size) + pkt->size) { - kr_log_error("[%s] gnutls_record_uncork didn't send all data(%zd of %zd)\n", - logstring, submitted, sizeof(pkt_size) + pkt->size); - return kr_error(EIO); - } - } while (submitted != sizeof(pkt_size) + pkt->size); + logstring, gnutls_strerror_name(count), count); + ret = kr_error(EIO); + } else if (count != size_to_send) { + kr_log_error("[%s] gnutls_record_send didn't send all data (%zd of %zd)\n", + logstring, count, size_to_send); + ret = kr_error(EIO); + } - return kr_ok(); + return ret; } int tls_process(struct worker_ctx *worker, uv_stream_t *handle, const uint8_t *buf, ssize_t nread) diff --git a/daemon/worker.c b/daemon/worker.c index 850d02f560b7239e1c8a20090d092e7e46aa070d..219e19acc109351cda0a9f3b3aca90e8bcdd4eb2 100644 --- a/daemon/worker.c +++ b/daemon/worker.c @@ -62,6 +62,7 @@ struct qr_task knot_pkt_t *pktbuf; qr_tasklist_t waiting; uv_handle_t *pending[MAX_PENDING]; + void *write_req_buf; uint16_t pending_count; uint16_t addrlist_count; uint16_t addrlist_turn; @@ -586,7 +587,11 @@ static int request_start(struct request_ctx *ctx, knot_pkt_t *query) } req->qsource.size = query->size; - req->answer = knot_pkt_new(NULL, answer_max, &req->pool); + /* Wire buffer + placeholder for tcp message length field. + * Placeholder allows to avoid usage of gnutls_record_cork() \ gnutls_record_cork() + * when TLS is used. */ + char *wire = mm_alloc(&req->pool, answer_max + 2); + req->answer = knot_pkt_new(wire + 2, answer_max, &req->pool); if (!req->answer) { return kr_error(ENOMEM); } @@ -690,8 +695,12 @@ static struct qr_task *qr_task_create(struct request_ctx *ctx) } memset(task, 0, sizeof(*task)); /* avoid accidentally unitialized fields */ - /* Create packet buffers for answer and subrequests */ - knot_pkt_t *pktbuf = knot_pkt_new(NULL, pktbuf_max, &ctx->req.pool); + /* Create packet buffers for answer and subrequests. + * Don't forget about placeholder for tcp message length field. + * This allows to avoid usage of gnutls_record_cork() \ gnutls_record_cork() + * when TLS is used. */ + char *wire = mm_alloc(&ctx->req.pool, pktbuf_max + 2); + knot_pkt_t *pktbuf = knot_pkt_new(wire + 2, pktbuf_max, &ctx->req.pool); if (!pktbuf) { mm_free(&ctx->req.pool, task); return NULL; @@ -719,6 +728,10 @@ static void qr_task_free(struct qr_task *task) assert(ctx); + if (task->write_req_buf != NULL) { + free(task->write_req_buf); + } + /* Process outbound session. */ struct session *source_session = ctx->source.session; struct worker_ctx *worker = ctx->worker; @@ -844,25 +857,26 @@ static int qr_task_on_send(struct qr_task *task, uv_handle_t *handle, int status if (ret != kr_ok()) { while (session->waiting.len > 0) { struct qr_task *t = session->waiting.at[0]; + array_del(session->waiting, 0); + session_del_tasks(session, t); if (session->outgoing) { qr_task_finalize(t, KR_STATE_FAIL); } else { assert(t->ctx->source.session == session); t->ctx->source.session = NULL; } - array_del(session->waiting, 0); - session_del_tasks(session, t); qr_task_unref(t); } while (session->tasks.len > 0) { struct qr_task *t = session->tasks.at[0]; + array_del(session->tasks, 0); if (session->outgoing) { qr_task_finalize(t, KR_STATE_FAIL); } else { assert(t->ctx->source.session == session); t->ctx->source.session = NULL; } - session_del_tasks(session, t); + qr_task_unref(t); } session_close(session); return status; @@ -895,6 +909,10 @@ static void on_task_write(uv_write_t *req, int status) struct worker_ctx *worker = loop->data; assert(worker == get_worker()); struct qr_task *task = req->data; + if (task->write_req_buf) { + free(task->write_req_buf); + task->write_req_buf = NULL; + } qr_task_on_send(task, handle, status); qr_task_unref(task); iorequest_release(worker, req); @@ -904,6 +922,10 @@ static void on_nontask_write(uv_write_t *req, int status) { uv_handle_t *handle = (uv_handle_t *)(req->handle); uv_loop_t *loop = handle->loop; + if (req->data) { + free(req->data); + req->data = NULL; + } struct worker_ctx *worker = loop->data; assert(worker == get_worker()); iorequest_release(worker, req); @@ -912,15 +934,18 @@ static void on_nontask_write(uv_write_t *req, int status) ssize_t worker_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len) { struct tls_common_ctx *t = (struct tls_common_ctx *)h; - const uv_buf_t uv_buf[1] = { - { (char *)buf, len } - }; - if (t == NULL) { errno = EFAULT; return -1; } + void *buf_local = malloc(len); + memcpy(buf_local, buf, len); + + const uv_buf_t uv_buf[1] = { + { (char *)buf_local, len } + }; + assert(t->session && t->session->handle && t->session->handle->type == UV_TCP); @@ -933,6 +958,7 @@ ssize_t worker_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len void *ioreq = worker_iohandle_borrow(worker); if (!ioreq) { errno = EFAULT; + free(buf_local); return -1; } @@ -942,13 +968,17 @@ ssize_t worker_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len uv_write_cb write_cb = on_task_write; if (t->handshake_state == TLS_HS_DONE) { assert(task); + if (task->write_req_buf != NULL) { + free(task->write_req_buf); + } + write_req->data = task; + task->write_req_buf = buf_local; } else { task = NULL; write_cb = on_nontask_write; + write_req->data = buf_local; } - write_req->data = task; - ssize_t ret = -1; int res = uv_write(write_req, (uv_stream_t *)t->session->handle, uv_buf, 1, write_cb); if (res == 0) { @@ -975,6 +1005,7 @@ ssize_t worker_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len VERBOSE_MSG(NULL,"[%s] uv_write: %s\n", t->client_side ? "tls_client" : "tls", uv_strerror(res)); iorequest_release(worker, ioreq); + free(buf_local); errno = EIO; } return ret; @@ -1745,8 +1776,9 @@ static int qr_task_step(struct qr_task *task, session_del_tasks(session, task); while (session->tasks.len != 0) { struct qr_task *t = session->tasks.at[0]; + array_del(session->tasks, 0); qr_task_finalize(t, KR_STATE_FAIL); - session_del_tasks(session, t); + qr_task_unref(t); } subreq_finalize(task, packet_source, packet); session_close(session); @@ -1762,8 +1794,9 @@ static int qr_task_step(struct qr_task *task, session_del_tasks(session, task); while (session->tasks.len != 0) { struct qr_task *t = session->tasks.at[0]; + array_del(session->tasks, 0); qr_task_finalize(t, KR_STATE_FAIL); - session_del_tasks(session, t); + qr_task_unref(t); } subreq_finalize(task, packet_source, packet); session_close(session); @@ -2266,7 +2299,12 @@ int worker_process_tcp(struct worker_ctx *worker, uv_stream_t *handle, * Previous packet is allocated with mempool, so there's no need to free it manually. */ if (task->pktbuf->max_size < KNOT_WIRE_MAX_PKTSIZE) { knot_mm_t *pool = &task->pktbuf->mm; - pkt_buf = knot_pkt_new(NULL, KNOT_WIRE_MAX_PKTSIZE, pool); + /* Allocate wire buffer + placeholder for tcp message length field. + * Placeholder allows to avoid usage of + * gnutls_record_cork() \ gnutls_record_cork() + * when TLS is used. */ + char *wire = mm_alloc(pool, KNOT_WIRE_MAX_PKTSIZE + 2); + pkt_buf = knot_pkt_new(wire + 2, KNOT_WIRE_MAX_PKTSIZE, pool); if (!pkt_buf) { return kr_error(ENOMEM); }