/*
 * Partial (unfolded) ones'-complement sum of the IPv6 pseudo-header used by
 * upper-layer checksums: source address, destination address, upper-layer
 * length and next-header value.
 *
 * All words are added in host byte order; the caller folds and complements
 * the result later (the Internet checksum is byte-order independent as long
 * as every word is read the same way).
 *
 * @ipv6h:          IPv6 header whose saddr, daddr and nexthdr are summed.
 *                  NOTE(review): nexthdr is only the upper-layer protocol when
 *                  there are no extension headers — confirm callers guarantee it.
 * @payload_length: upper-layer length as stored in the packet (a 16-bit
 *                  network-order value widened to 32 bits).
 * Return: 64-bit running sum, not yet folded.
 */
static __always_inline __be64 get_pseudo_hdr_32sum(struct ipv6hdr *ipv6h, __be32 payload_length) {
	__u64 csum_buffer = 0;
	__u32 *saddr = (__u32 *) &ipv6h->saddr;
	__u32 *daddr = (__u32 *) &ipv6h->daddr;
	// Pseudo-header word carrying the next header: three zero bytes followed
	// by the protocol number. Building it in memory order gives the right
	// value on both little- and big-endian hosts (a plain `nexthdr << 24`
	// is only correct on little-endian).
	union {
		__u32 word;
		__u8 bytes[4];
	} next_hdr_word = { .word = 0 };

	next_hdr_word.bytes[3] = ipv6h->nexthdr;

	csum_buffer += saddr[0];
	csum_buffer += saddr[1];
	csum_buffer += saddr[2];
	csum_buffer += saddr[3];
	csum_buffer += daddr[0];
	csum_buffer += daddr[1];
	csum_buffer += daddr[2];
	csum_buffer += daddr[3];
	csum_buffer += next_hdr_word.word;
	// A 16-bit network-order length widened to 32 bits is congruent (mod 0xffff)
	// to the pseudo-header's 32-bit length word on either byte order, so it can
	// be added as-is.
	csum_buffer += payload_length;
	return csum_buffer;
}
// Upper bound, in bytes, on how much data get_data_32sum() sums; it keeps that
// loop bounded. Data past this limit is not included in the checksum.
// NOTE(review): 1480 = 1500 - 20 (an IPv4 header); with a 1500-byte MTU an IPv6
// payload is at most 1460 bytes — confirm the intended value.
#define MAX_MSG_LENGTH 1480
/*
 * Fold a 64-bit ones'-complement running sum down to 16 bits with
 * end-around carry, then return its ones' complement.
 *
 * The result is the final checksum value to store in a header; callers must
 * NOT complement it again.
 *
 * Four fold rounds are enough for any 64-bit input. Once the value fits in
 * 16 bits a further round leaves it unchanged (its high part is zero).
 */
static __always_inline __u16 csum_fold64(__u64 csum) {
	int round;

	#pragma unroll
	for (round = 0; round < 4; round++)
		csum = (csum >> 16) + (csum & 0xffff);

	return (__u16) ~csum;
}
/*
 * Add the bytes in [data, data_end) to a ones'-complement running sum.
 *
 * Whole 32-bit words are read in host byte order. A trailing 1-3 bytes are
 * treated as a word zero-padded at the END, laid out in memory exactly like
 * the full words (RFC 1071), so the result is independent of host endianness.
 * At most MAX_MSG_LENGTH bytes of full words are summed; longer data yields
 * an incomplete sum.
 *
 * @csum_buffer: running 64-bit sum to add to (not folded).
 * @data:        first byte to sum.
 * @data_end:    one past the last byte that may be read.
 * Return: the updated, unfolded 64-bit sum.
 */
static __always_inline __u64 get_data_32sum(__u64 csum_buffer, void *data, void *data_end) {
	int i;
	__u32 *buff = (__u32 *) data;

	// Sum whole 32-bit words; the per-word data_end check keeps every read in bounds.
	for (i = 0; i < MAX_MSG_LENGTH; i += 4) {
		if ((void *)(buff + 1) > data_end) {
			break;
		}
		csum_buffer += *buff;
		buff++;
	}

	// Tail: only when fewer than 4 bytes remain. If the loop stopped at the
	// MAX_MSG_LENGTH cap with full words left, there is no tail to add.
	if ((void *)(buff + 1) > data_end) {
		__u8 *byte_buff = (__u8 *) buff;
		union {
			__u32 word;
			__u8 bytes[4];
		} tail = { .word = 0 };

		// Fixed 3-iteration loop with a per-byte bounds check so the BPF
		// verifier can prove each access is inside the packet.
		#pragma unroll
		for (i = 0; i < 3; i++) {
			if ((void *)(byte_buff + i + 1) > data_end) {
				break;
			}
			tail.bytes[i] = byte_buff[i];
		}
		csum_buffer += tail.word;
	}
	return csum_buffer;
}
static __always_inline void set_udp6_csum(struct udphdr* udph, void* data_end, struct ipv6hdr* ipv6h) {
// Compute sum on pseudo header
__u64 csum_buffer = get_pseudo_hdr_32sum(ipv6h, (__be32) udph->len);
// Compute sum on data
csum_buffer = get_data_32sum(csum_buffer, (void *) udph, data_end);
// Fold
csum_buffer = csum_fold64(csum_buffer);
if (csum_buffer == 0xffff) {
udph->check = (__u16) csum_buffer;
return;
}
udph->check = ~((__u16) csum_buffer);
}
/*
 * Recompute and store the ICMPv6 checksum of an ICMPv6-over-IPv6 packet.
 *
 * @icmp6h:   ICMPv6 header; the checksum covers everything from here to data_end.
 * @data_end: one past the last byte of packet data.
 * @ipv6h:    enclosing IPv6 header, used for the pseudo-header.
 *
 * NOTE(review): ipv6h->payload_len and ipv6h->nexthdr feed the pseudo-header,
 * which is only right when the ICMPv6 message directly follows the fixed IPv6
 * header (no extension headers) and ends at data_end — confirm with callers.
 */
static __always_inline void set_icmp6_csum(struct icmp6hdr* icmp6h, void* data_end, struct ipv6hdr* ipv6h) {
	__u64 csum_buffer;

	// The checksum field is part of the summed data; it must read as zero
	// while summing, otherwise the stale value leaks into the result.
	icmp6h->icmp6_cksum = 0;
	// Compute sum on pseudo header
	csum_buffer = get_pseudo_hdr_32sum(ipv6h, (__be32) ipv6h->payload_len);
	// Compute sum on data
	csum_buffer = get_data_32sum(csum_buffer, (void *) icmp6h, data_end);
	// csum_fold64() folds AND complements, so its result is stored as-is.
	// Unlike UDP, ICMPv6 has no "zero means no checksum" rule, so no
	// 0 -> 0xffff substitution is needed.
	icmp6h->icmp6_cksum = csum_fold64(csum_buffer);
}