/*
 * Paste-site metadata captured with this file (not part of the program):
 *
 *   Untitled
 *   mail@pastecode.io avatar
 *   unknown
 *   plain_text
 *   23 days ago
 *   4.7 kB
 *   5
 *   Indexable
 *   Never
 */
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

#include <infiniband/verbs.h>

/* Program-wide integer tunables. */
enum {
    /* Completion-queue polling timeout.
     * NOTE(review): units are not stated here; presumably milliseconds —
     * confirm at the use site. */
    MAX_POLL_CQ_TIMEOUT = 2000,
    /* Size in bytes of the message buffer and of each posted SEND/RECV. */
    MSG_SIZE = 64
};

/*
 * Run-time settings for the program.
 *
 * NOTE(review): only server_name is read anywhere in this file (main sets
 * it from argv[1], and a non-NULL value selects client mode). dev_name,
 * tcp_port, ib_port and gid_idx are initialized but not yet consumed.
 */
struct config_t {
    const char *dev_name;    /* RDMA device name; presumably NULL = first device */
    const char *server_name; /* argv[1]; non-NULL selects client mode */
    uint32_t tcp_port;       /* presumably the TCP port for out-of-band setup */
    int ib_port;             /* presumably the HCA port number to use */
    int gid_idx;             /* presumably a GID index; negative = no GID — confirm */
};

/*
 * Verbs resources for one endpoint. setup_connection() fills them in and
 * the end of main() releases them.
 */
struct connection {
    struct ibv_context *context; /* opened RDMA device */
    struct ibv_pd *pd;           /* protection domain owning qp and mr */
    struct ibv_cq *cq;           /* one CQ used for both send and recv completions */
    struct ibv_qp *qp;           /* reliable-connected (IBV_QPT_RC) queue pair */
    struct ibv_mr *mr;           /* registration of buf; its lkey goes into each SGE */
    char *buf;                   /* heap message buffer of `size` bytes */
    int size;                    /* length of buf in bytes (MSG_SIZE) */
    int sock;                    /* NOTE(review): not used for any I/O in this file; presumably meant for a TCP socket to the peer */
};

static int post_send(struct connection *conn) {
    struct ibv_send_wr wr, *bad_wr = NULL;
    struct ibv_sge sge;

    memset(&wr, 0, sizeof(wr));

    wr.wr_id = (uintptr_t)conn;
    wr.opcode = IBV_WR_SEND;
    wr.sg_list = &sge;
    wr.num_sge = 1;
    wr.send_flags = IBV_SEND_SIGNALED;

    sge.addr = (uintptr_t)conn->buf;
    sge.length = MSG_SIZE;
    sge.lkey = conn->mr->lkey;

    return ibv_post_send(conn->qp, &wr, &bad_wr);
}

static int post_recv(struct connection *conn) {
    struct ibv_recv_wr wr, *bad_wr = NULL;
    struct ibv_sge sge;

    memset(&wr, 0, sizeof(wr));

    wr.wr_id = (uintptr_t)conn;
    wr.sg_list = &sge;
    wr.num_sge = 1;

    sge.addr = (uintptr_t)conn->buf;
    sge.length = MSG_SIZE;
    sge.lkey = conn->mr->lkey;

    return ibv_post_recv(conn->qp, &wr, &bad_wr);
}

static int setup_connection(struct connection *conn) {
    struct ibv_device **dev_list;
    struct ibv_device *ib_dev;
    int num_devices;

    dev_list = ibv_get_device_list(&num_devices);
    if (!dev_list) {
        fprintf(stderr, "Failed to get IB devices list\n");
        return 1;
    }

    ib_dev = *dev_list;
    if (!ib_dev) {
        fprintf(stderr, "No IB devices found\n");
        return 1;
    }

    conn->context = ibv_open_device(ib_dev);
    if (!conn->context) {
        fprintf(stderr, "Failed to open device\n");
        return 1;
    }

    conn->pd = ibv_alloc_pd(conn->context);
    if (!conn->pd) {
        fprintf(stderr, "Failed to allocate PD\n");
        return 1;
    }

    conn->cq = ibv_create_cq(conn->context, 10, NULL, NULL, 0);
    if (!conn->cq) {
        fprintf(stderr, "Failed to create CQ\n");
        return 1;
    }

    struct ibv_qp_init_attr qp_init_attr = {
        .send_cq = conn->cq,
        .recv_cq = conn->cq,
        .qp_type = IBV_QPT_RC,
        .cap = {
            .max_send_wr = 10,
            .max_recv_wr = 10,
            .max_send_sge = 1,
            .max_recv_sge = 1
        }
    };

    conn->qp = ibv_create_qp(conn->pd, &qp_init_attr);
    if (!conn->qp) {
        fprintf(stderr, "Failed to create QP\n");
        return 1;
    }

    conn->size = MSG_SIZE;
    conn->buf = malloc(conn->size);
    if (!conn->buf) {
        fprintf(stderr, "Failed to allocate memory\n");
        return 1;
    }

    conn->mr = ibv_reg_mr(conn->pd, conn->buf, conn->size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
    if (!conn->mr) {
        fprintf(stderr, "Failed to register MR\n");
        return 1;
    }

    return 0;
}

int main(int argc, char *argv[]) {
    struct connection conn;
    struct config_t config = {
        .dev_name = NULL,
        .server_name = NULL,
        .tcp_port = 18515,
        .ib_port = 1,
        .gid_idx = -1
    };

    if (argc > 1) config.server_name = argv[1];

    if (setup_connection(&conn)) {
        fprintf(stderr, "Failed to setup connection\n");
        return 1;
    }

    if (config.server_name) {
        // Client mode
        strcpy(conn.buf, "Hello from client!");
        if (post_send(&conn)) {
            fprintf(stderr, "Failed to post send\n");
            return 1;
        }
        printf("Message sent: %s\n", conn.buf);
    } else {
        // Server mode
        if (post_recv(&conn)) {
            fprintf(stderr, "Failed to post receive\n");
            return 1;
        }
        printf("Waiting for message...\n");
    }

    struct ibv_wc wc;
    int ne;

    do {
        ne = ibv_poll_cq(conn.cq, 1, &wc);
    } while (ne == 0);

    if (ne < 0) {
        fprintf(stderr, "Failed to poll CQ\n");
        return 1;
    }

    if (wc.status != IBV_WC_SUCCESS) {
        fprintf(stderr, "Work completion status is not IBV_WC_SUCCESS\n");
        return 1;
    }

    if (!config.server_name) {
        printf("Message received: %s\n", conn.buf);
    }

    ibv_destroy_qp(conn.qp);
    ibv_destroy_cq(conn.cq);
    ibv_dereg_mr(conn.mr);
    ibv_dealloc_pd(conn.pd);
    ibv_close_device(conn.context);
    free(conn.buf);

    return 0;
}
/* Paste-site footer captured with this file (not part of the program): "Leave a Comment" */