// Untitled
// unknown
// c_cpp
// 2 years ago
// 2.8 kB
// 10
// Indexable
#include <errno.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>

#include <xgboost/c_api.h>
#include <zmq.h>
// Struct to hold batch of extracted features
typedef struct {
int partner_id;
float features[MAX_FEATURES];
int num_samples;
} FeatureBatch;
// Fill the leading feature slots for one ad request.
//
// `features` must have room for at least two floats. Slot 0 receives
// request->country and slot 1 receives request->device_type (each converted
// to float). No other slot is written yet.
void extract_features(AdRequest* request, float* features) {
    const AdRequest& req = *request;
    float* const out = features;

    out[0] = req.country;      // location
    out[1] = req.device_type;  // device
    // TODO: extract the remaining features (location, device, etc.).
}
// Partition and batch the feature extraction
void* ingest_thread(void* args) {
void *context = zmq_ctx_new();
void *subscriber = zmq_socket(context, ZMQ_SUB);
zmq_connect(subscriber, "tcp://localhost:5557");
zmq_setsockopt(subscriber, ZMQ_SUBSCRIBE, "", 0);
while (1) {
AdRequest request;
zmq_recv(subscriber, &request, sizeof(request));
FeatureBatch* batch = batches[request.partner_id];
extract_features(&request, batch->features[batch->num_samples]);
batch->num_samples++;
if (batch->num_samples >= BATCH_SIZE) {
// Send batch to trainer
zmq_send(trainer_socket, batch, sizeof(FeatureBatch), 0);
batch->num_samples = 0;
}
}
}
// Trainer thread
void* trainer_thread(void* args) {
void *context = zmq_ctx_new();
void *subscriber = zmq_socket(context, ZMQ_SUB);
zmq_bind(subscriber, "tcp://*:5558");
while (1) {
FeatureBatch batch;
zmq_recv(subscriber, &batch, sizeof(FeatureBatch), 0);
// Accumulate batch into training data
AccumulateTrainingBatch(batch);
// Periodically train model
if (ShouldTrainModel()) {
xgboost::DMatrix dmat;
CreateDMatrix(&dmat);
xgboost::Booster booster;
booster.Configure();
booster.Update(dmat);
xgboost::Model model = booster.GetModel();
SaveModel(model);
}
}
}
// Prediction thread.
//
// Loads a model once via LoadModel(). It then handles each AdRequest received
// on predictor_socket:
//   - extracts its features into a single-sample FeatureBatch,
//   - scores that batch with model.Predict,
//   - sends the resulting float on results_socket.
// Malformed requests are skipped without a reply. `args` is unused.
//
// Returns NULL if a receive fails with anything other than EINTR.
//
// NOTE(review): main() starts NUM_PREDICTORS copies of this thread, and they
// all share predictor_socket / results_socket. ZeroMQ sockets must not be
// shared between threads, so give each predictor its own sockets.
// NOTE(review): the model is loaded once and never reloaded, so models saved
// later by trainer_thread are not picked up.
void* predict_thread(void* args) {
    (void)args;

    // Load model
    xgboost::Model model = LoadModel();
    for (;;) {
        AdRequest request;
        // zmq_recv takes a flags argument; 0 means a blocking receive.
        const int received = zmq_recv(predictor_socket, &request, sizeof(request), 0);
        if (received < 0) {
            if (errno == EINTR) {
                continue;  // interrupted by a signal; just retry
            }
            break;
        }
        // zmq_recv returns the full message length even when it truncated the
        // copy, so anything but an exact size match is a malformed request.
        if ((size_t)received != sizeof(request)) {
            continue;
        }

        // Value-initialise so that feature slots extract_features does not
        // write are 0, and partner_id / num_samples are not indeterminate.
        FeatureBatch batch = {};
        batch.partner_id = request.partner_id;
        batch.num_samples = 1;
        extract_features(&request, batch.features);

        // Get prediction
        float pred = model.Predict(batch);
        // Return prediction
        zmq_send(results_socket, &pred, sizeof(float), 0);
    }
    return NULL;
}
// Starts `entry` on a new thread and records its handle in threads[*count],
// then increments *count. If pthread_create fails, it reports `role` and the
// error code on stderr and returns false.
static bool launch_pipeline_thread(pthread_t* threads, int* count,
                                   void* (*entry)(void*), const char* role) {
    const int rc = pthread_create(&threads[*count], NULL, entry, NULL);
    if (rc != 0) {
        fprintf(stderr, "failed to start %s thread (error %d)\n", role, rc);
        return false;
    }
    ++*count;
    return true;
}

// Starts NUM_PARTNERS ingest threads, one trainer thread and NUM_PREDICTORS
// predictor threads, then joins every thread it started. If any thread fails
// to start, it returns EXIT_FAILURE immediately. Returning from main ends the
// process, including any threads already running.
int main() {
    // One slot per thread launched below. Every handle is kept, so the join
    // loop waits on exactly the threads that were created.
    pthread_t threads[NUM_PARTNERS + 1 + NUM_PREDICTORS];
    int started = 0;

    // Start partition threads
    for (int i = 0; i < NUM_PARTNERS; i++) {
        if (!launch_pipeline_thread(threads, &started, ingest_thread, "ingest")) {
            return EXIT_FAILURE;
        }
    }
    // Start trainer thread
    if (!launch_pipeline_thread(threads, &started, trainer_thread, "trainer")) {
        return EXIT_FAILURE;
    }
    // Start predictor threads
    for (int i = 0; i < NUM_PREDICTORS; i++) {
        if (!launch_pipeline_thread(threads, &started, predict_thread, "predictor")) {
            return EXIT_FAILURE;
        }
    }
    // Join threads
    for (int i = 0; i < started; i++) {
        pthread_join(threads[i], NULL);
    }
    return 0;
}