On Thu, Feb 29, 2024 at 05:59:53PM +0100, Laurent Vivier wrote: > The TCP and UDP checksums are computed using the data in the TCP/UDP > payload but also some information in the IP header (protocol, > length, source and destination addresses). > > We add two functions, proto_ipv4_header_psum() and > proto_ipv6_header_psum(), to compute the checksum of the IP > header part. > > Signed-off-by: Laurent Vivier > --- > > Notes: > v4: > - fix payload length endianness > > v3: > - function parameters provide tot_len, saddr, daddr and protocol > rather than an iphdr/ipv6hdr > > v2: > - move new function to checksum.c > - use _psum rather than _checksum in the name > - replace csum_udp4() and csum_udp6() by the new function > > checksum.c | 69 ++++++++++++++++++++++++++++++++++++++++++------------ > checksum.h | 4 ++++ > tcp.c | 45 ++++++++++++++++------------------- > udp.c | 13 ++++++---- > 4 files changed, 86 insertions(+), 45 deletions(-) > > diff --git a/checksum.c b/checksum.c > index 511b296a9a80..93c8d5205c2b 100644 > --- a/checksum.c > +++ b/checksum.c > @@ -134,6 +134,30 @@ uint16_t csum_ip4_header(uint16_t tot_len, uint8_t protocol, > return ~csum_fold(sum); > } > > +/** > + * proto_ipv4_header_psum() - Calculates the partial checksum of an > + * IPv4 header for UDP or TCP > + * @tot_len: IPv4 Payload length > + * @proto: Protocol number > + * @saddr: Source address > + * @daddr: Destination address > + * @proto: proto Protocol number Needs to note that tot_len is in host order, but saddr and daddr are in network order. Usually, I'd take host order as assumed for a plain integer type, but since it's mixed here, we should annotate them all. Alternatively, we could pass saddr and daddr as struct in_addr. In general I've tried to pass IPv4 addresses with that type, rather than in_addr_t or uint32_t. Looking at the callers, it seems like it's a mixed bag whether that's messier or cleaner in this case. 
> + * Returns: Partial checksum of the IPv4 header > + */ > +uint32_t proto_ipv4_header_psum(uint16_t tot_len, uint8_t protocol, > + uint32_t saddr, uint32_t daddr) > +{ > + uint32_t psum = htons(protocol); > + > + psum += (saddr >> 16) & 0xffff; > + psum += saddr & 0xffff; > + psum += (daddr >> 16) & 0xffff; > + psum += daddr & 0xffff; > + psum += htons(tot_len); > + > + return psum; > +} > + > /** > * csum_udp4() - Calculate and set checksum for a UDP over IPv4 packet > * @udp4hr: UDP header, initialised apart from checksum > @@ -150,14 +174,12 @@ void csum_udp4(struct udphdr *udp4hr, > udp4hr->check = 0; > > if (UDP4_REAL_CHECKSUMS) { > - /* UNTESTED: if we did want real UDPv4 checksums, this > - * is roughly what we'd need */ > - uint32_t psum = csum_fold(saddr.s_addr) > - + csum_fold(daddr.s_addr) > - + htons(len + sizeof(*udp4hr)) > - + htons(IPPROTO_UDP); > - /* Add in partial checksum for the UDP header alone */ > - psum += sum_16b(udp4hr, sizeof(*udp4hr)); > + uint16_t tot_len = len + sizeof(struct udphdr); > + uint32_t psum = proto_ipv4_header_psum(tot_len, > + IPPROTO_UDP, > + saddr.s_addr, > + daddr.s_addr); > + psum = csum_unfolded(udp4hr, sizeof(struct udphdr), psum); > udp4hr->check = csum(payload, len, psum); > } > } > @@ -180,6 +202,26 @@ void csum_icmp4(struct icmphdr *icmp4hr, const void *payload, size_t len) > icmp4hr->checksum = csum(payload, len, psum); > } > > +/** > + * proto_ipv6_header_psum() - Calculates the partial checksum of an > + * IPv6 header for UDP or TCP > + * @payload_len: IPv6 payload length > + * @proto: Protocol number > + * @saddr: Source address > + * @daddr: Destination address > + * Returns: Partial checksum of the IPv6 header > + */ > +uint32_t proto_ipv6_header_psum(uint16_t payload_len, uint8_t protocol, > + struct in6_addr saddr, struct in6_addr daddr) I don't see any point to passing the addresses by value here. You take their address, so they must be written back to memory if passed in registers. 
At the call sites, you still have the dereference so it doesn't help with alignment. > +{ > + uint32_t sum = htons(protocol) + htons(payload_len); > + > + sum += sum_16b(&saddr, sizeof(saddr)); > + sum += sum_16b(&daddr, sizeof(daddr)); > + > + return sum; > +} > + > /** > * csum_udp6() - Calculate and set checksum for a UDP over IPv6 packet > * @udp6hr: UDP header, initialised apart from checksum > @@ -190,14 +232,11 @@ void csum_udp6(struct udphdr *udp6hr, > const struct in6_addr *saddr, const struct in6_addr *daddr, > const void *payload, size_t len) > { > - /* Partial checksum for the pseudo-IPv6 header */ > - uint32_t psum = sum_16b(saddr, sizeof(*saddr)) + > - sum_16b(daddr, sizeof(*daddr)) + > - htons(len + sizeof(*udp6hr)) + htons(IPPROTO_UDP); > - > + uint32_t psum = proto_ipv6_header_psum(len + sizeof(struct udphdr), > + IPPROTO_UDP, *saddr, *daddr); > udp6hr->check = 0; > - /* Add in partial checksum for the UDP header alone */ > - psum += sum_16b(udp6hr, sizeof(*udp6hr)); > + > + psum = csum_unfolded(udp6hr, sizeof(struct udphdr), psum); > udp6hr->check = csum(payload, len, psum); > } > > diff --git a/checksum.h b/checksum.h > index 92db73612b6e..b2b5b8e8b77e 100644 > --- a/checksum.h > +++ b/checksum.h > @@ -15,10 +15,14 @@ uint16_t csum_fold(uint32_t sum); > uint16_t csum_unaligned(const void *buf, size_t len, uint32_t init); > uint16_t csum_ip4_header(uint16_t tot_len, uint8_t protocol, > uint32_t saddr, uint32_t daddr); > +uint32_t proto_ipv4_header_psum(uint16_t tot_len, uint8_t protocol, > + uint32_t saddr, uint32_t daddr); > void csum_udp4(struct udphdr *udp4hr, > struct in_addr saddr, struct in_addr daddr, > const void *payload, size_t len); > void csum_icmp4(struct icmphdr *ih, const void *payload, size_t len); > +uint32_t proto_ipv6_header_psum(uint16_t payload_len, uint8_t protocol, > + struct in6_addr saddr, struct in6_addr daddr); > void csum_udp6(struct udphdr *udp6hr, > const struct in6_addr *saddr, const struct in6_addr *daddr, > const 
void *payload, size_t len); > diff --git a/tcp.c b/tcp.c > index ea0802c6b102..d78efa5401bb 100644 > --- a/tcp.c > +++ b/tcp.c > @@ -939,39 +939,30 @@ static void tcp_sock_set_bufsize(const struct ctx *c, int s) > * tcp_update_check_tcp4() - Update TCP checksum from stored one > * @buf: L2 packet buffer with final IPv4 header Function comment no longer matches the parameters. > */ > -static void tcp_update_check_tcp4(struct tcp4_l2_buf_t *buf) > +static void tcp_update_check_tcp4(struct iphdr *iph) Hmm... so this takes only a pointer to iph, but writes to the TCP header it assumes is beyond that, and reads from the payload it assumes is beyond that. That seems like a dangerous interface to me (not to mention that I fear it could trigger type-based alias analysis (TBAA) traps). > { > - uint16_t tlen = ntohs(buf->iph.tot_len) - 20; > - uint32_t sum = htons(IPPROTO_TCP); > + uint16_t tlen = ntohs(iph->tot_len) - sizeof(struct iphdr); > + uint32_t sum = proto_ipv4_header_psum(tlen, IPPROTO_TCP, > + iph->saddr, iph->daddr); > + struct tcphdr *th = (struct tcphdr *)(iph + 1); > > - sum += (buf->iph.saddr >> 16) & 0xffff; > - sum += buf->iph.saddr & 0xffff; > - sum += (buf->iph.daddr >> 16) & 0xffff; > - sum += buf->iph.daddr & 0xffff; > - sum += htons(ntohs(buf->iph.tot_len) - 20); > - > - buf->th.check = 0; > - buf->th.check = csum(&buf->th, tlen, sum); > + th->check = 0; > + th->check = csum(th, tlen, sum); > } > > /** > * tcp_update_check_tcp6() - Calculate TCP checksum for IPv6 > * @buf: L2 packet buffer with final IPv6 header > */ > -static void tcp_update_check_tcp6(struct tcp6_l2_buf_t *buf) > +static void tcp_update_check_tcp6(struct ipv6hdr *ip6h) Same comments as for the IPv4 version. 
> { > - int len = ntohs(buf->ip6h.payload_len) + sizeof(struct ipv6hdr); > - > - buf->ip6h.hop_limit = IPPROTO_TCP; > - buf->ip6h.version = 0; > - buf->ip6h.nexthdr = 0; > + struct tcphdr *th = (struct tcphdr *)(ip6h + 1); > + uint16_t payload_len = ntohs(ip6h->payload_len); > + uint32_t sum = proto_ipv6_header_psum(payload_len, IPPROTO_TCP, > + ip6h->saddr, ip6h->daddr); > > - buf->th.check = 0; > - buf->th.check = csum(&buf->ip6h, len, 0); > - > - buf->ip6h.hop_limit = 255; > - buf->ip6h.version = 6; > - buf->ip6h.nexthdr = IPPROTO_TCP; > + th->check = 0; > + th->check = csum(th, payload_len, sum); > } > > /** > @@ -1383,7 +1374,7 @@ do { \ > > SET_TCP_HEADER_COMMON_V4_V6(b, conn, seq); > > - tcp_update_check_tcp4(b); > + tcp_update_check_tcp4(&b->iph); > > tlen = tap_iov_len(c, &b->taph, ip_len); > } else { > @@ -1402,7 +1393,11 @@ do { \ > > SET_TCP_HEADER_COMMON_V4_V6(b, conn, seq); > > - tcp_update_check_tcp6(b); > + tcp_update_check_tcp6(&b->ip6h); > + > + b->ip6h.hop_limit = 255; > + b->ip6h.version = 6; > + b->ip6h.nexthdr = IPPROTO_TCP; > > b->ip6h.flow_lbl[0] = (conn->sock >> 16) & 0xf; > b->ip6h.flow_lbl[1] = (conn->sock >> 8) & 0xff; > diff --git a/udp.c b/udp.c > index d517c99dcc69..410ace16a6a2 100644 > --- a/udp.c > +++ b/udp.c > @@ -625,6 +625,7 @@ static size_t udp_update_hdr6(const struct ctx *c, int n, in_port_t dstport, > { > struct udp6_l2_buf_t *b = &udp6_l2_buf[n]; > struct in6_addr *src; > + uint16_t payload_len; > in_port_t src_port; > size_t ip_len; > > @@ -633,7 +634,8 @@ static size_t udp_update_hdr6(const struct ctx *c, int n, in_port_t dstport, > > ip_len = udp6_l2_mh_sock[n].msg_len + sizeof(b->ip6h) + sizeof(b->uh); > > - b->ip6h.payload_len = htons(udp6_l2_mh_sock[n].msg_len + sizeof(b->uh)); > + payload_len = udp6_l2_mh_sock[n].msg_len + sizeof(b->uh); > + b->ip6h.payload_len = htons(payload_len); > > if (IN6_IS_ADDR_LINKLOCAL(src)) { > b->ip6h.daddr = c->ip6.addr_ll_seen; > @@ -675,10 +677,11 @@ static size_t 
udp_update_hdr6(const struct ctx *c, int n, in_port_t dstport, > b->uh.source = b->s_in6.sin6_port; > b->uh.dest = htons(dstport); > b->uh.len = b->ip6h.payload_len; > - > - b->ip6h.hop_limit = IPPROTO_UDP; > - b->ip6h.version = b->ip6h.nexthdr = b->uh.check = 0; > - b->uh.check = csum(&b->ip6h, ip_len, 0); > + b->uh.check = 0; > + b->uh.check = csum(&b->uh, payload_len, > + proto_ipv6_header_psum(payload_len, IPPROTO_UDP, > + b->ip6h.saddr, > + b->ip6h.daddr)); > b->ip6h.version = 6; > b->ip6h.nexthdr = IPPROTO_UDP; > b->ip6h.hop_limit = 255; -- David Gibson | I'll have my music baroque, and my code david AT gibson.dropbear.id.au | minimalist, thank you. NOT _the_ _other_ | _way_ _around_! http://www.ozlabs.org/~dgibson