Untitled

mail@pastecode.io avatar
unknown
plain_text
a month ago
9.5 kB
3
Indexable
Never
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <time.h>
#include <unistd.h>
#include <netdb.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <infiniband/verbs.h>

/* Upper bound for polling the CQ for a work completion.
 * NOTE(review): units presumably milliseconds (2000 = 2 s) — TODO confirm. */
#define MAX_POLL_CQ_TIMEOUT 2000
/* Size in bytes of the registered message buffer (conn->size); post_send()
 * uses the whole buffer as the SGE length. */
#define MSG_SIZE 64
/* TCP port on the server used for the out-of-band QP info exchange. */
#define TCP_PORT 18515

/*
 * Addressing data for one side's QP, exchanged with the peer over the TCP
 * socket by exchange_qp_info().
 *
 * The struct is written and read as raw bytes, so both peers must agree on
 * its exact layout, padding and byte order.
 * NOTE(review): no htonl/htons conversion is applied; this presumably assumes
 * both hosts share endianness and ABI — confirm against the server.
 */
struct qp_info {
    uint32_t qp_num;   /* number of the sender's QP (ibv_qp::qp_num) */
    uint16_t lid;      /* LID of the sender's HCA port (from ibv_query_port) */
    uint8_t gid[16];   /* port GID; this client never sets it (left zeroed) —
                          TODO confirm whether the peer relies on it */
};

/* All state for the single RC connection this client sets up. */
struct connection {
    struct ibv_context *context;     /* opened RDMA device */
    struct ibv_pd *pd;               /* protection domain for the QP and MR */
    struct ibv_cq *cq;               /* shared send/receive completion queue */
    struct ibv_qp *qp;               /* reliable-connected (RC) queue pair */
    struct ibv_mr *mr;               /* registration of buf */
    char *buf;                       /* message buffer, size bytes long */
    int size;                        /* length of buf and of each posted send */
    int sock;                        /* TCP socket used to exchange qp_info */
    struct qp_info local_qp_info;    /* what we send to the peer */
    struct qp_info remote_qp_info;   /* what the peer sent us */
    int ib_port;                     /* HCA port number (main() uses 1) */
};

static int modify_qp_to_init(struct ibv_qp *qp, int ib_port)
{
    struct ibv_qp_attr attr = {
        .qp_state        = IBV_QPS_INIT,
        .pkey_index      = 0,
        .port_num        = ib_port,
        .qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE
    };

    return ibv_modify_qp(qp, &attr,
                         IBV_QP_STATE      |
                         IBV_QP_PKEY_INDEX |
                         IBV_QP_PORT       |
                         IBV_QP_ACCESS_FLAGS);
}

static int modify_qp_to_rtr(struct ibv_qp *qp, int ib_port, uint32_t remote_qpn, uint16_t dlid, uint8_t *dgid)
{
    struct ibv_qp_attr attr = {
        .qp_state           = IBV_QPS_RTR,
        .path_mtu           = IBV_MTU_1024,
        .dest_qp_num        = remote_qpn,
        .rq_psn             = 0,
        .max_dest_rd_atomic = 1,
        .min_rnr_timer      = 0x12,
        .ah_attr            = {
            .is_global      = 0,
            .dlid           = dlid,
            .sl             = 0,
            .src_path_bits  = 0,
            .port_num       = ib_port
        }
    };

    if (dgid) {
        attr.ah_attr.is_global = 1;
        attr.ah_attr.grh.hop_limit = 1;
        memcpy(&attr.ah_attr.grh.dgid, dgid, 16);
        attr.ah_attr.grh.sgid_index = 0;
    }

    return ibv_modify_qp(qp, &attr,
                         IBV_QP_STATE              |
                         IBV_QP_AV                 |
                         IBV_QP_PATH_MTU           |
                         IBV_QP_DEST_QPN           |
                         IBV_QP_RQ_PSN             |
                         IBV_QP_MAX_DEST_RD_ATOMIC |
                         IBV_QP_MIN_RNR_TIMER);
}

static int modify_qp_to_rts(struct ibv_qp *qp)
{
    struct ibv_qp_attr attr = {
        .qp_state      = IBV_QPS_RTS,
        .timeout       = 0x12,
        .retry_cnt     = 7,
        .rnr_retry     = 7,
        .sq_psn        = 0,
        .max_rd_atomic = 1
    };

    return ibv_modify_qp(qp, &attr,
                         IBV_QP_STATE              |
                         IBV_QP_TIMEOUT            |
                         IBV_QP_RETRY_CNT          |
                         IBV_QP_RNR_RETRY          |
                         IBV_QP_SQ_PSN             |
                         IBV_QP_MAX_QP_RD_ATOMIC);
}

static int post_send(struct connection *conn)
{
    struct ibv_sge list = {
        .addr   = (uintptr_t)conn->buf,
        .length = conn->size,
        .lkey   = conn->mr->lkey
    };
    struct ibv_send_wr wr = {
        .wr_id      = (uintptr_t)conn,
        .sg_list    = &list,
        .num_sge    = 1,
        .opcode     = IBV_WR_SEND,
        .send_flags = IBV_SEND_SIGNALED,
    };
    struct ibv_send_wr *bad_wr;
    return ibv_post_send(conn->qp, &wr, &bad_wr);
}

static int setup_connection(struct connection *conn, const char *dev_name, int ib_port)
{
    struct ibv_device **dev_list;
    struct ibv_device *ib_dev;
    int num_devices;

    dev_list = ibv_get_device_list(&num_devices);
    if (!dev_list) {
        fprintf(stderr, "Failed to get IB devices list: %s\n", strerror(errno));
        return 1;
    }

    ib_dev = *dev_list;
    if (!ib_dev) {
        fprintf(stderr, "No IB devices found\n");
        return 1;
    }

    conn->context = ibv_open_device(ib_dev);
    if (!conn->context) {
        fprintf(stderr, "Failed to open device: %s\n", strerror(errno));
        return 1;
    }

    conn->pd = ibv_alloc_pd(conn->context);
    if (!conn->pd) {
        fprintf(stderr, "Failed to allocate PD: %s\n", strerror(errno));
        return 1;
    }

    conn->cq = ibv_create_cq(conn->context, 10, NULL, NULL, 0);
    if (!conn->cq) {
        fprintf(stderr, "Failed to create CQ: %s\n", strerror(errno));
        return 1;
    }

    struct ibv_qp_init_attr qp_init_attr = {
        .send_cq = conn->cq,
        .recv_cq = conn->cq,
        .qp_type = IBV_QPT_RC,
        .cap     = {
            .max_send_wr  = 10,
            .max_recv_wr  = 10,
            .max_send_sge = 1,
            .max_recv_sge = 1
        }
    };

    conn->qp = ibv_create_qp(conn->pd, &qp_init_attr);
    if (!conn->qp) {
        fprintf(stderr, "Failed to create QP: %s\n", strerror(errno));
        return 1;
    }

    conn->size = MSG_SIZE;
    conn->buf = malloc(conn->size);
    if (!conn->buf) {
        fprintf(stderr, "Failed to allocate memory: %s\n", strerror(errno));
        return 1;
    }

    conn->mr = ibv_reg_mr(conn->pd, conn->buf, conn->size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
    if (!conn->mr) {
        fprintf(stderr, "Failed to register MR: %s\n", strerror(errno));
        return 1;
    }

    if (modify_qp_to_init(conn->qp, ib_port)) {
        fprintf(stderr, "Failed to modify QP to INIT\n");
        return 1;
    }

    struct ibv_port_attr port_attr;
    if (ibv_query_port(conn->context, ib_port, &port_attr)) {
        fprintf(stderr, "Failed to query port attributes\n");
        return 1;
    }
    conn->local_qp_info.lid = port_attr.lid;
    conn->local_qp_info.qp_num = conn->qp->qp_num;

    printf("Local QP number: %u\n", conn->local_qp_info.qp_num);
    printf("Local LID: %u\n", conn->local_qp_info.lid);

    return 0;
}

/*
 * Open a TCP connection to server_name on TCP_PORT and store the socket in
 * conn->sock.
 *
 * Every address returned by getaddrinfo() is tried in order (IPv4 or IPv6),
 * and the first one that connects wins. Returns 0 on success. On failure it
 * returns -1, prints an error, closes any socket it created, and leaves
 * conn->sock == -1.
 */
static int setup_socket(struct connection *conn, const char *server_name)
{
    struct addrinfo hints;
    struct addrinfo *res = NULL;
    struct addrinfo *ai;
    char port_str[16];
    int sock_errno = 0;   /* errno from the last failed socket() */
    int conn_errno = 0;   /* errno from the last failed connect() */
    int ret;

    conn->sock = -1;

    /* Without hints, getaddrinfo may return AF_INET6 or non-stream entries;
     * asking for SOCK_STREAM and using each ai_addr as-is avoids
     * misinterpreting an IPv6 address as a sockaddr_in. */
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;

    snprintf(port_str, sizeof(port_str), "%d", TCP_PORT);

    ret = getaddrinfo(server_name, port_str, &hints, &res);
    if (ret) {
        fprintf(stderr, "getaddrinfo failed: %s\n", gai_strerror(ret));
        return -1;
    }

    for (ai = res; ai != NULL; ai = ai->ai_next) {
        int fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        if (fd < 0) {
            sock_errno = errno;
            continue;
        }
        if (connect(fd, ai->ai_addr, ai->ai_addrlen) == 0) {
            conn->sock = fd;
            break;
        }
        conn_errno = errno;
        close(fd);
    }

    freeaddrinfo(res);

    if (conn->sock < 0) {
        if (conn_errno)
            fprintf(stderr, "Failed to connect to server: %s\n", strerror(conn_errno));
        else
            fprintf(stderr, "Failed to create socket: %s\n", strerror(sock_errno));
        return -1;
    }

    printf("Connected to server\n");
    return 0;
}

/* Write exactly len bytes to fd, retrying short writes and EINTR.
 * Returns 0 on success, -1 on error. */
static int write_all(int fd, const void *buf, size_t len)
{
    const char *p = buf;

    while (len > 0) {
        ssize_t n = write(fd, p, len);
        if (n < 0 && errno == EINTR)
            continue;
        if (n <= 0)
            return -1;
        p += n;
        len -= (size_t)n;
    }
    return 0;
}

/* Read exactly len bytes from fd, retrying short reads and EINTR.
 * Returns 0 on success, -1 on error or if the peer closes first. */
static int read_all(int fd, void *buf, size_t len)
{
    char *p = buf;

    while (len > 0) {
        ssize_t n = read(fd, p, len);
        if (n < 0 && errno == EINTR)
            continue;
        if (n <= 0)   /* error, or EOF before the full message arrived */
            return -1;
        p += n;
        len -= (size_t)n;
    }
    return 0;
}

/*
 * Send our qp_info to the peer over conn->sock, then receive the peer's into
 * conn->remote_qp_info. Both are transferred as raw struct bytes.
 * NOTE(review): no byte-order conversion; presumably both ends share layout
 * and endianness — confirm against the server.
 *
 * TCP is a byte stream, so a single read()/write() may move fewer bytes than
 * requested; the *_all helpers loop until the whole struct is transferred.
 * Returns 0 on success, -1 on failure.
 */
static int exchange_qp_info(struct connection *conn)
{
    if (write_all(conn->sock, &conn->local_qp_info, sizeof(conn->local_qp_info))) {
        fprintf(stderr, "Failed to send local QP info\n");
        return -1;
    }

    if (read_all(conn->sock, &conn->remote_qp_info, sizeof(conn->remote_qp_info))) {
        fprintf(stderr, "Failed to receive remote QP info\n");
        return -1;
    }

    printf("QP info exchanged\n");
    return 0;
}

int main(int argc, char *argv[])
{
    struct connection conn;
    struct ibv_wc wc;
    int ne;

    if (argc != 2) {
        fprintf(stderr, "Usage: %s <server-address>\n", argv[0]);
        return 1;
    }

    memset(&conn, 0, sizeof(conn));
    conn.ib_port = 1; // Use port 1 by default

    if (setup_socket(&conn, argv[1])) {
        fprintf(stderr, "Failed to setup socket\n");
        return 1;
    }

    if (setup_connection(&conn, NULL, conn.ib_port)) {
        fprintf(stderr, "Failed to setup RDMA connection\n");
        return 1;
    }

    if (exchange_qp_info(&conn)) {
        fprintf(stderr, "Failed to exchange QP info\n");
        return 1;
    }

    if (modify_qp_to_rtr(conn.qp, conn.ib_port, conn.remote_qp_info.qp_num, conn.remote_qp_info.lid, NULL)) {
        fprintf(stderr, "Failed to modify QP to RTR\n");
        return 1;
    }

    if (modify_qp_to_rts(conn.qp)) {
        fprintf(stderr, "Failed to modify QP to RTS\n");
        return 1;
    }

    strcpy(conn.buf, "Hello from client!");
    printf("Sending message: %s\n", conn.buf);

    if (post_send(&conn)) {
        fprintf(stderr, "Failed to post send\n");
        return 1;
    }

    do {
        ne = ibv_poll_cq(conn.cq, 1, &wc);
    } while (ne == 0);

    if (ne < 0) {
        fprintf(stderr, "Failed to poll CQ: %s\n", strerror(errno));
        return 1;
    }

    if (wc.status != IBV_WC_SUCCESS) {
        fprintf(stderr, "Work completion failed with status %s (%d)\n", 
                ibv_wc_status_str(wc.status), wc.status);
        return 1;
    }

    printf("Message sent successfully\n");

    // Clean up
    ibv_destroy_qp(conn.qp);
    ibv_destroy_cq(conn.cq);
    ibv_dereg_mr(conn.mr);
    ibv_dealloc_pd(conn.pd);
    ibv_close_device(conn.context);
    free(conn.buf);
    close(conn.sock);

    return 0;
}
Leave a Comment