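/*
 * Paste metadata scraped from pastecode.io (not part of the program):
 *   Title:  Untitled
 *   Author: mail@pastecode.io (avatar: unknown)
 *   Syntax: c_cpp
 *   Posted: a month ago
 *   Size:   2.8 kB
 *   1
 *   Indexable: Never
 */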
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>

#include <vector>

#include <xgboost/c_api.h>
#include <zmq.h>

// Batch of extracted features that ingest_thread fills and sends to the trainer.
//
// partner_id  - partner the batch belongs to (presumably matches the index used
//               into `batches` by ingest_thread; verify against the caller).
// features    - feature storage for the batch.
// num_samples - number of samples written so far; ingest_thread resets it to 0
//               after sending a full batch (BATCH_SIZE samples) to the trainer.
//
// NOTE(review): ingest_thread passes `batch->features[batch->num_samples]` to
// extract_features(), i.e. it treats `features` as one row per sample, but the
// field is declared as a single flat row of MAX_FEATURES floats. That call does
// not type-check (a float is not a float*) and this layout cannot hold
// BATCH_SIZE samples. A per-sample layout such as
// `float features[BATCH_SIZE][MAX_FEATURES]` looks intended; predict_thread
// would then need to pass `batch.features[0]` to extract_features().
typedef struct {
  int partner_id;
  float features[MAX_FEATURES]; 
  int num_samples;
} FeatureBatch;

// Writes the features read from an ad request into the leading slots of
// `features`.
//
// request  - ad request to read from; it is dereferenced, so it must be non-null.
// features - output row; the caller must provide room for every slot written
//            here (currently slots 0 and 1).
//
// TODO: only country and device type are extracted so far; further features
// (e.g. location) still need to be added.
void extract_features(AdRequest* request, float* features) {
  const AdRequest& req = *request;

  features[0] = static_cast<float>(req.country);
  features[1] = static_cast<float>(req.device_type);
}

// Partition and batch the feature extraction
void* ingest_thread(void* args) {

  void *context = zmq_ctx_new();
  void *subscriber = zmq_socket(context, ZMQ_SUB);
  
  zmq_connect(subscriber, "tcp://localhost:5557");
  zmq_setsockopt(subscriber, ZMQ_SUBSCRIBE, "", 0);

  while (1) {
    AdRequest request;
    zmq_recv(subscriber, &request, sizeof(request));
    
    FeatureBatch* batch = batches[request.partner_id]; 
    
    extract_features(&request, batch->features[batch->num_samples]);
    batch->num_samples++;

    if (batch->num_samples >= BATCH_SIZE) {
      // Send batch to trainer
      zmq_send(trainer_socket, batch, sizeof(FeatureBatch), 0); 
      batch->num_samples = 0; 
    }
  }
}

// Trainer thread
void* trainer_thread(void* args) {

  void *context = zmq_ctx_new();
  void *subscriber = zmq_socket(context, ZMQ_SUB);
  
  zmq_bind(subscriber, "tcp://*:5558");

  while (1) {
    FeatureBatch batch;
    zmq_recv(subscriber, &batch, sizeof(FeatureBatch), 0);

    // Accumulate batch into training data
    AccumulateTrainingBatch(batch); 

    // Periodically train model
    if (ShouldTrainModel()) {
      xgboost::DMatrix dmat;
      CreateDMatrix(&dmat);
      
      xgboost::Booster booster;
      booster.Configure();

      booster.Update(dmat);
      
      xgboost::Model model = booster.GetModel(); 
      SaveModel(model);
    }

  }

}

// Prediction thread
void* predict_thread(void* args) {
  
  // Load model  
  xgboost::Model model = LoadModel();

  while (1) {
    AdRequest request;
    zmq_recv(predictor_socket, &request, sizeof(request));
  
    FeatureBatch batch;
    extract_features(&request, batch.features);
    
    // Get prediction
    float pred = model.Predict(batch);
    
    // Return prediction  
    zmq_send(results_socket, &pred, sizeof(float), 0);
  }

}

// Entry point: starts NUM_PARTNERS ingest threads, one trainer thread and
// NUM_PREDICTORS prediction threads, then joins every thread it started.
//
// Returns EXIT_FAILURE immediately if any thread cannot be created. Returning
// from main ends the process, including threads that were already started.
// Returns EXIT_SUCCESS once every started thread has been joined; with the
// current infinite worker loops, that never happens.
int main() {
  // Every handle is kept so it can be joined. Reusing a single pthread_t
  // variable would overwrite the earlier handles.
  std::vector<pthread_t> threads;
  threads.reserve(NUM_PARTNERS + 1 + NUM_PREDICTORS);

  // Starts one thread running `entry` and records its handle.
  // Reports and returns false if pthread_create fails.
  auto spawn = [&threads](void* (*entry)(void*), const char* role) {
    pthread_t tid;
    const int rc = pthread_create(&tid, nullptr, entry, nullptr);
    if (rc != 0) {
      fprintf(stderr, "failed to start %s thread (error %d)\n", role, rc);
      return false;
    }
    threads.push_back(tid);
    return true;
  };

  // Start partition threads
  for (int i = 0; i < NUM_PARTNERS; i++) {
    if (!spawn(ingest_thread, "ingest")) {
      return EXIT_FAILURE;
    }
  }

  // Start trainer thread
  if (!spawn(trainer_thread, "trainer")) {
    return EXIT_FAILURE;
  }

  // Start predictor threads
  for (int i = 0; i < NUM_PREDICTORS; i++) {
    if (!spawn(predict_thread, "predict")) {
      return EXIT_FAILURE;
    }
  }

  // Join exactly the threads that were started.
  for (pthread_t tid : threads) {
    pthread_join(tid, nullptr);
  }
  return EXIT_SUCCESS;
}