diff --git a/src/libknot/packet/pkt.c b/src/libknot/packet/pkt.c index c25757a4ebfa0dad43f38266b46267b9c019719d..6a922f26dfb6aa062b8ead1227ae26fb350a6ca5 100644 --- a/src/libknot/packet/pkt.c +++ b/src/libknot/packet/pkt.c @@ -202,6 +202,37 @@ knot_pkt_t *knot_pkt_new(void *wire, uint16_t len, mm_ctx_t *mm) return pkt_new_mm(wire, len, mm); } +knot_pkt_t *knot_pkt_copy(const knot_pkt_t *pkt, mm_ctx_t *mm) +{ + dbg_packet("%s(%p, %p)\n", __func__, pkt, mm); + if (pkt == NULL) { + return NULL; + } + + knot_pkt_t *copy = knot_pkt_new(NULL, pkt->max_size, mm); + if (copy == NULL) { + return NULL; + } + + copy->size = pkt->size; + memcpy(copy->wire, pkt->wire, copy->size); + + /* Copy TSIG RR back to wire. */ + if (pkt->tsig_rr) { + int ret = knot_tsig_append(copy->wire, ©->size, copy->max_size, + pkt->tsig_rr); + if (ret != KNOT_EOK) { + knot_pkt_free(©); + return NULL; + } + } + + /* @note This could be done more effectively if needed. */ + knot_pkt_parse(copy, 0); + + return copy; +} + int knot_pkt_init_response(knot_pkt_t *pkt, const knot_pkt_t *query) { dbg_packet("%s(%p, %p)\n", __func__, pkt, query); diff --git a/src/libknot/packet/pkt.h b/src/libknot/packet/pkt.h index 0adb73b0d038a9e57ad9a837ef392076e22c7182..fa0188c40ffbfb10bb60167dc407929de7f0b608 100644 --- a/src/libknot/packet/pkt.h +++ b/src/libknot/packet/pkt.h @@ -130,6 +130,18 @@ typedef struct knot_pkt { */ knot_pkt_t *knot_pkt_new(void *wire, uint16_t len, mm_ctx_t *mm); +/*! + * \brief Copy packet. + * + * \note Current implementation is not very efficient, as it re-parses the wire. + * + * \param pkt Source packet. + * \param mm Memory context. + * + * \return new packet or NULL + */ +knot_pkt_t *knot_pkt_copy(const knot_pkt_t *pkt, mm_ctx_t *mm); + /*! * \brief Initialized response from query packet. * diff --git a/tests/pkt.c b/tests/pkt.c index ab9787cc4fa7bef089d96ce64b1b0a3d3edf465a..d0d5d58f3ba4d20504f89ee9743ebd4aef70abc3 100644 --- a/tests/pkt.c +++ b/tests/pkt.c @@ -41,9 +41,32 @@ const char *g_rdata[DATACOUNT] = { #define RDVAL(i) ((const uint8_t*)(g_rdata[(i)] + 1)) #define RDLEN(i) ((uint16_t)(g_rdata[(i)][0])) +/* @note Packet equivalence test, 5 checks. */ +static void packet_match(knot_pkt_t *in, knot_pkt_t *out) +{ + /* Check counts */ + is_int(knot_wire_get_qdcount(out->wire), + knot_wire_get_qdcount(in->wire), "pkt: QD match"); + is_int(knot_wire_get_ancount(out->wire), + knot_wire_get_ancount(in->wire), "pkt: AN match"); + is_int(knot_wire_get_nscount(out->wire), + knot_wire_get_nscount(in->wire), "pkt: NS match"); + is_int(knot_wire_get_arcount(out->wire), + knot_wire_get_arcount(in->wire), "pkt: AR match"); + + /* Check RRs */ + int rr_matched = 0; + for (unsigned i = 0; i < NAMECOUNT; ++i) { + if (knot_rrset_equal(&out->rr[i], &in->rr[i], KNOT_RRSET_COMPARE_WHOLE) > 0) { + ++rr_matched; + } + } + is_int(NAMECOUNT, rr_matched, "pkt: RR content match"); +} + int main(int argc, char *argv[]) { - plan(25); + plan(30); /* Create memory pool context. */ int ret = 0; @@ -146,34 +169,23 @@ int main(int argc, char *argv[]) ret = knot_pkt_parse_payload(in, 0); ok(ret == KNOT_EOK, "pkt: read payload"); - /* Check qname. */ - ok(knot_dname_is_equal(knot_pkt_qname(out), - knot_pkt_qname(in)), "pkt: equal qname"); + /* Compare parsed packet to written packet. */ + packet_match(in, out); - /* Check counts */ - is_int(knot_wire_get_qdcount(out->wire), - knot_wire_get_qdcount(in->wire), "pkt: QD match"); - is_int(knot_wire_get_ancount(out->wire), - knot_wire_get_ancount(in->wire), "pkt: AN match"); - is_int(knot_wire_get_nscount(out->wire), - knot_wire_get_nscount(in->wire), "pkt: NS match"); - is_int(knot_wire_get_arcount(out->wire), - knot_wire_get_arcount(in->wire), "pkt: AR match"); + /* + * Copied packet tests. + */ + knot_pkt_t *copy = knot_pkt_copy(in, &in->mm); + ok(copy != NULL, "pkt: create packet copy"); - /* Check RRs */ - int rr_matched = 0; - for (unsigned i = 0; i < NAMECOUNT; ++i) { - if (knot_rrset_equal(&out->rr[i], &in->rr[i], KNOT_RRSET_COMPARE_WHOLE) > 0) { - ++rr_matched; - } - } - is_int(NAMECOUNT, rr_matched, "pkt: RR content match"); + /* Compare copied packet to original. */ + packet_match(in, copy); /* Free packets. */ + knot_pkt_free(©); knot_pkt_free(&out); knot_pkt_free(&in); - ok(in == NULL, "pkt: free"); - ok(out == NULL, "pkt: free"); + ok(in == NULL && out == NULL && copy == NULL, "pkt: free"); /* Free extra data. */ for (unsigned i = 0; i < NAMECOUNT; ++i) {