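// Paste-site metadata captured along with this snippet (not part of the program):
// Untitled
// c_cpp
// a month ago
// 2.8 kB
// 1
// Indexable
// Never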
#include <pthread.h>
#include <sys/stat.h>

#include <errno.h>
#include <math.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>

#include <xgboost/c_api.h>
#include <zmq.h>

// Online ad-model pipeline, all threads sharing one ZMQ context:
//   ingest_thread  (one per partner): SUB ad requests -> per-partner FeatureBatch -> PUSH to trainer
//   trainer_thread                  : PULL batches -> accumulate -> periodically retrain, publish model file
//   predict_thread (worker pool)    : REP ad request -> features -> prediction from latest published model

// ---------------------------------------------------------------------------
// Configuration
//
// The original file used these names without defining them. The values below
// are placeholders; each can be overridden at build time with -DNAME=value.
// ---------------------------------------------------------------------------
#ifndef MAX_FEATURES
#define MAX_FEATURES 16   // feature slots per sample
#endif
#ifndef BATCH_SIZE
#define BATCH_SIZE 256    // samples per batch shipped to the trainer
#endif
#ifndef NUM_PARTNERS
#define NUM_PARTNERS 4    // ingest threads, one per partner id in [0, NUM_PARTNERS)
#endif
#ifndef NUM_PREDICTORS
#define NUM_PREDICTORS 4  // prediction worker threads
#endif

static_assert(MAX_FEATURES >= 2, "extract_features writes feature slots 0 and 1");
static_assert(BATCH_SIZE > 0, "BATCH_SIZE must be positive");
static_assert(NUM_PARTNERS > 0 && NUM_PREDICTORS > 0,
              "need at least one ingest thread and one predictor thread");

// Endpoints. 5557 and 5558 come from the original file. The prediction
// front-end port is a placeholder: the original never said where predictions
// were served.
constexpr const char* kAdRequestEndpoint      = "tcp://localhost:5557";
constexpr const char* kTrainerBindEndpoint    = "tcp://*:5558";
constexpr const char* kTrainerConnectEndpoint = "tcp://localhost:5558";
constexpr const char* kPredictBindEndpoint    = "tcp://*:5559";
constexpr const char* kPredictWorkersEndpoint = "inproc://predictors";

// The trainer publishes models here and predictors load them from here. The
// temporary name keeps the ".json" extension because XGBoost picks the
// serialization format from the file extension on both save and load.
constexpr const char* kModelPath    = "model.json";
constexpr const char* kModelTmpPath = "model.tmp.json";

// Boosting rounds per retrain. The original called Update() once, i.e. one round.
constexpr int kNumBoostRounds = 1;

// ---------------------------------------------------------------------------
// Data types
// ---------------------------------------------------------------------------

/// Wire format of one ad request as published (raw bytes) on kAdRequestEndpoint.
/// NOTE(review): the original file never defined AdRequest; this is the minimal
/// layout the code below reads. Confirm field types/order against the publisher.
struct AdRequest {
    int partner_id;   ///< Must stay the first member: ingest threads subscribe on its bytes.
    int country;      ///< Assumed to be a numeric country code — TODO confirm.
    int device_type;  ///< Assumed to be a numeric device-type code — TODO confirm.
};
static_assert(offsetof(AdRequest, partner_id) == 0,
              "partner_id doubles as the ZMQ subscription prefix");

/// One partner's batch of extracted feature rows. It is sent as a raw struct
/// from an ingest thread to the trainer.
struct FeatureBatch {
    int partner_id;                            ///< Partner whose requests filled this batch.
    float features[BATCH_SIZE][MAX_FEATURES];  ///< One row per sample; rows [0, num_samples) are valid.
    int num_samples;                           ///< Number of filled rows.
};

/// Start-up arguments handed to each pthread entry point below.
struct ThreadArgs {
    void* zmq_ctx;  ///< Shared ZMQ context (contexts are thread-safe; sockets are not).
    int index;      ///< Partner id for ingest threads; worker index otherwise.
};

// ---------------------------------------------------------------------------
// Project hooks. The original called these without declaring them. They are
// declared here with the contract this file relies on and are defined elsewhere.
// NOTE(review): confirm these signatures against the real implementations.
// ---------------------------------------------------------------------------

/// Adds the first `batch->num_samples` rows of `batch` to the training set.
void AccumulateTrainingBatch(const FeatureBatch* batch);

/// Returns true when enough new data has accumulated to justify retraining.
bool ShouldTrainModel();

/// Builds a labelled training DMatrix from the accumulated data. Returns 0 on
/// success; the caller then owns `*out` and releases it with XGDMatrixFree.
int CreateDMatrix(DMatrixHandle* out);

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// Logs the current ZMQ error for `what`. ETERM is suppressed because every
/// blocking call returns it during an orderly context shutdown.
static void report_zmq_error(const char* what) {
    const int err = zmq_errno();
    if (err != ETERM) {
        fprintf(stderr, "%s: %s\n", what, zmq_strerror(err));
    }
}

/// Fills `features` (MAX_FEATURES floats) from one ad request.
/// Slots that are not explicitly extracted are zeroed, so a batch never ships
/// uninitialized stack memory to the trainer.
void extract_features(const AdRequest* request, float* features) {
    std::fill_n(features, MAX_FEATURES, 0.0f);
    features[0] = static_cast<float>(request->country);
    features[1] = static_cast<float>(request->device_type);
    // TODO: remaining features (the original noted "location, device, etc").
}

// ---------------------------------------------------------------------------
// Ingest
// ---------------------------------------------------------------------------

/// Ingest worker for one partner (partner id == ThreadArgs::index).
///
/// Subscribes to kAdRequestEndpoint and keeps only requests for this thread's
/// partner. It extracts their features into a private FeatureBatch and pushes
/// every full batch to the trainer. `args` must point to a ThreadArgs that
/// outlives the thread.
///
/// Each thread filters to a single partner and owns its batch, so ingest
/// threads share no mutable state. Requests whose partner_id is outside
/// [0, NUM_PARTNERS) match no subscription and are dropped.
void* ingest_thread(void* args) {
    const auto* targs = static_cast<const ThreadArgs*>(args);
    const int partner_id = targs->index;

    void* subscriber = zmq_socket(targs->zmq_ctx, ZMQ_SUB);
    void* to_trainer = zmq_socket(targs->zmq_ctx, ZMQ_PUSH);
    // SUB filters are byte-prefix matches. partner_id is the first member of
    // the raw AdRequest, so subscribing to its bytes selects exactly this
    // partner. (This assumes the publisher shares our int size/endianness,
    // which the raw-struct wire format already assumes.)
    if (subscriber == nullptr || to_trainer == nullptr
        || zmq_setsockopt(subscriber, ZMQ_SUBSCRIBE, &partner_id, sizeof partner_id) != 0
        || zmq_connect(subscriber, kAdRequestEndpoint) != 0
        || zmq_connect(to_trainer, kTrainerConnectEndpoint) != 0) {
        report_zmq_error("ingest: socket setup");
        if (subscriber != nullptr) zmq_close(subscriber);
        if (to_trainer != nullptr) zmq_close(to_trainer);
        return nullptr;
    }

    FeatureBatch batch{};
    batch.partner_id = partner_id;

    for (;;) {
        AdRequest request;
        const int rc = zmq_recv(subscriber, &request, sizeof request, 0);
        if (rc == -1) {
            if (zmq_errno() == EINTR) continue;
            report_zmq_error("ingest: recv");
            break;  // ETERM: context is shutting down
        }
        if (rc != static_cast<int>(sizeof request)) {
            // zmq_recv reports the full message size even when it truncated it.
            fprintf(stderr, "ingest: dropping %d-byte message (expected %zu)\n", rc,
                    sizeof request);
            continue;
        }

        extract_features(&request, batch.features[batch.num_samples]);
        if (++batch.num_samples == BATCH_SIZE) {
            if (zmq_send(to_trainer, &batch, sizeof batch, 0) == -1) {
                report_zmq_error("ingest: send");
                break;
            }
            batch.num_samples = 0;
        }
    }

    zmq_close(subscriber);
    zmq_close(to_trainer);
    return nullptr;
}

// ---------------------------------------------------------------------------
// Training
// ---------------------------------------------------------------------------

/// Trains a fresh booster on everything accumulated so far and publishes it
/// at kModelPath. The model is written to kModelTmpPath and then rename()d
/// into place, so predictors never read a half-written file. Failures are
/// logged, and the previously published model (if any) stays in place.
static void train_and_publish_model() {
    DMatrixHandle dtrain = nullptr;
    if (CreateDMatrix(&dtrain) != 0 || dtrain == nullptr) {
        fprintf(stderr, "trainer: CreateDMatrix failed\n");
        return;
    }

    BoosterHandle booster = nullptr;
    bool ok = XGBoosterCreate(&dtrain, 1, &booster) == 0;
    // NOTE(review): no training parameters are set (the original's Configure()
    // passed none either). Set e.g. "objective" with XGBoosterSetParam here.
    for (int iter = 0; ok && iter < kNumBoostRounds; ++iter) {
        ok = XGBoosterUpdateOneIter(booster, iter, dtrain) == 0;
    }
    if (ok) {
        ok = XGBoosterSaveModel(booster, kModelTmpPath) == 0;
    }

    if (!ok) {
        fprintf(stderr, "trainer: %s\n", XGBGetLastError());
    } else if (rename(kModelTmpPath, kModelPath) != 0) {
        perror("trainer: publishing model");
    }

    if (booster != nullptr) XGBoosterFree(booster);
    XGDMatrixFree(dtrain);
}

/// Trainer: pulls FeatureBatches from the ingest threads on kTrainerBindEndpoint,
/// accumulates them, and retrains whenever ShouldTrainModel() says so.
/// `args` must point to a ThreadArgs that outlives the thread.
void* trainer_thread(void* args) {
    const auto* targs = static_cast<const ThreadArgs*>(args);

    // PULL (not SUB): every batch goes to the trainer exactly once, and no
    // subscription is needed for delivery.
    void* receiver = zmq_socket(targs->zmq_ctx, ZMQ_PULL);
    if (receiver == nullptr || zmq_bind(receiver, kTrainerBindEndpoint) != 0) {
        report_zmq_error("trainer: socket setup");
        if (receiver != nullptr) zmq_close(receiver);
        return nullptr;
    }

    FeatureBatch batch;
    for (;;) {
        const int rc = zmq_recv(receiver, &batch, sizeof batch, 0);
        if (rc == -1) {
            if (zmq_errno() == EINTR) continue;
            report_zmq_error("trainer: recv");
            break;  // ETERM: context is shutting down
        }
        // The endpoint is TCP, so validate before trusting the row count.
        if (rc != static_cast<int>(sizeof batch)
            || batch.num_samples < 0 || batch.num_samples > BATCH_SIZE) {
            fprintf(stderr, "trainer: dropping malformed %d-byte batch\n", rc);
            continue;
        }

        AccumulateTrainingBatch(&batch);
        if (ShouldTrainModel()) {
            train_and_publish_model();
        }
    }

    zmq_close(receiver);
    return nullptr;
}

// ---------------------------------------------------------------------------
// Prediction
// ---------------------------------------------------------------------------

/// Identity of the model file a predictor last loaded. Each publish renames a
/// freshly written file over kModelPath, so the inode normally changes; mtime
/// is compared as well.
struct ModelStamp {
    time_t mtime;
    ino_t inode;
};

/// Returns the booster predictions should use. This is `current`, or a
/// freshly loaded one if a different model file has been published since
/// `stamp` was recorded. On any failure the previous model (possibly none)
/// stays in service.
static BoosterHandle refresh_model(BoosterHandle current, ModelStamp* stamp) {
    struct stat st;
    if (stat(kModelPath, &st) != 0) {
        return current;  // nothing published yet
    }
    if (current != nullptr && st.st_mtime == stamp->mtime && st.st_ino == stamp->inode) {
        return current;
    }

    BoosterHandle fresh = nullptr;
    if (XGBoosterCreate(nullptr, 0, &fresh) != 0 || XGBoosterLoadModel(fresh, kModelPath) != 0) {
        fprintf(stderr, "predict: loading %s failed: %s\n", kModelPath, XGBGetLastError());
        if (fresh != nullptr) XGBoosterFree(fresh);
        return current;
    }

    if (current != nullptr) XGBoosterFree(current);
    stamp->mtime = st.st_mtime;
    stamp->inode = st.st_ino;
    return fresh;
}

/// Scores one feature row (MAX_FEATURES floats) with `booster`.
/// Returns NaN when there is no model yet or XGBoost reports an error.
static float predict_one(BoosterHandle booster, const float* features) {
    if (booster == nullptr) {
        return NAN;
    }

    // NaN marks missing cells. NOTE(review): keep this consistent with the
    // `missing` value CreateDMatrix uses for training data.
    DMatrixHandle row = nullptr;
    if (XGDMatrixCreateFromMat(features, 1, MAX_FEATURES, NAN, &row) != 0) {
        fprintf(stderr, "predict: %s\n", XGBGetLastError());
        return NAN;
    }

    bst_ulong out_len = 0;
    const float* out = nullptr;
    float pred = NAN;
    // option_mask 0 = plain prediction, ntree_limit 0 = all trees, training 0.
    if (XGBoosterPredict(booster, row, 0, 0, 0, &out_len, &out) == 0 && out_len > 0) {
        pred = out[0];  // `out` is owned by the booster; copy before the next call
    } else {
        fprintf(stderr, "predict: %s\n", XGBGetLastError());
    }

    XGDMatrixFree(row);
    return pred;
}

/// Prediction worker. It answers requests routed from kPredictBindEndpoint
/// through the ROUTER/DEALER proxy in main(). Each request is a raw AdRequest
/// and each reply is one float: NaN when no model is available or the request
/// was malformed. The worker picks up newly published models between requests.
/// `args` must point to a ThreadArgs that outlives the thread.
void* predict_thread(void* args) {
    const auto* targs = static_cast<const ThreadArgs*>(args);

    // ZMQ sockets are not thread-safe, so every worker owns its own REP socket.
    void* rep = zmq_socket(targs->zmq_ctx, ZMQ_REP);
    if (rep == nullptr || zmq_connect(rep, kPredictWorkersEndpoint) != 0) {
        report_zmq_error("predict: socket setup");
        if (rep != nullptr) zmq_close(rep);
        return nullptr;
    }

    BoosterHandle booster = nullptr;
    ModelStamp stamp{};
    for (;;) {
        AdRequest request;
        const int rc = zmq_recv(rep, &request, sizeof request, 0);
        if (rc == -1) {
            if (zmq_errno() == EINTR) continue;
            report_zmq_error("predict: recv");
            break;  // ETERM: context is shutting down
        }

        // A REP socket must answer every request, even a malformed one.
        float pred = NAN;
        if (rc == static_cast<int>(sizeof request)) {
            booster = refresh_model(booster, &stamp);
            float features[MAX_FEATURES];
            extract_features(&request, features);
            pred = predict_one(booster, features);
        } else {
            fprintf(stderr, "predict: malformed %d-byte request\n", rc);
        }

        if (zmq_send(rep, &pred, sizeof pred, 0) == -1) {
            report_zmq_error("predict: send");
            break;
        }
    }

    if (booster != nullptr) XGBoosterFree(booster);
    zmq_close(rep);
    return nullptr;
}

// ---------------------------------------------------------------------------
// Entry point
// ---------------------------------------------------------------------------

/// Starts the pipeline, with all threads sharing one ZMQ context:
///   - NUM_PARTNERS ingest threads, one per partner id;
///   - one trainer;
///   - NUM_PREDICTORS prediction workers behind a ROUTER/DEALER proxy.
/// Runs until the context is shut down, then joins every thread it started.
int main() {
    void* ctx = zmq_ctx_new();
    if (ctx == nullptr) {
        report_zmq_error("main: zmq_ctx_new");
        return EXIT_FAILURE;
    }

    // Clients send REQ messages to the ROUTER; the DEALER load-balances them
    // across the REP workers. The inproc side is bound before any worker
    // connects to it.
    void* clients = zmq_socket(ctx, ZMQ_ROUTER);
    void* workers = zmq_socket(ctx, ZMQ_DEALER);
    if (clients == nullptr || workers == nullptr
        || zmq_bind(clients, kPredictBindEndpoint) != 0
        || zmq_bind(workers, kPredictWorkersEndpoint) != 0) {
        report_zmq_error("main: prediction front end");
        if (clients != nullptr) zmq_close(clients);
        if (workers != nullptr) zmq_close(workers);
        zmq_ctx_term(ctx);
        return EXIT_FAILURE;
    }

    constexpr int kNumThreads = NUM_PARTNERS + 1 + NUM_PREDICTORS;
    pthread_t threads[kNumThreads];
    ThreadArgs thread_args[kNumThreads];
    int started = 0;

    // Launches `entry` with its own ThreadArgs slot. The slots live in main's
    // frame, which outlives every thread because main joins them all below.
    auto spawn = [&](void* (*entry)(void*), int index) {
        thread_args[started] = ThreadArgs{ctx, index};
        const int err = pthread_create(&threads[started], nullptr, entry, &thread_args[started]);
        if (err != 0) {
            fprintf(stderr, "main: pthread_create failed: %s\n", strerror(err));
            return false;
        }
        ++started;
        return true;
    };

    bool ok = true;
    for (int partner = 0; ok && partner < NUM_PARTNERS; ++partner) {
        ok = spawn(ingest_thread, partner);
    }
    if (ok) {
        ok = spawn(trainer_thread, 0);
    }
    for (int worker = 0; ok && worker < NUM_PREDICTORS; ++worker) {
        ok = spawn(predict_thread, worker);
    }

    if (ok) {
        // Blocks until the context is shut down (e.g. by zmq_ctx_shutdown from
        // a signal-handling thread, which this file does not install).
        zmq_proxy(clients, workers, nullptr);
    } else {
        zmq_ctx_shutdown(ctx);  // make the threads that did start return ETERM
    }

    zmq_close(clients);
    zmq_close(workers);
    for (int i = 0; i < started; ++i) {
        pthread_join(threads[i], nullptr);
    }
    zmq_ctx_term(ctx);
    return ok ? EXIT_SUCCESS : EXIT_FAILURE;
}