optimal_batch inference

 avatar
unknown
c_cpp
5 months ago
4.3 kB
9
Indexable
#include <torch/script.h> // One-stop header.
#include <iostream>
#include <chrono>
#include <vector>
#include <ATen/Parallel.h> // Required for get_num_threads and set_num_threads
#include <random> // Required for random number generation
// Function to perform batch inference
// Runs the TorchScript model on a batch of 2D point pairs and returns one
// distance value per pair.
//
// @param model        Loaded TorchScript module whose forward() takes a
//                     [batch, 4] float32 tensor (as traced upstream).
// @param point_pairs  One inner vector per pair; assumes each holds exactly
//                     {x1, y1, x2, y2} — TODO confirm with callers.
// @return             The model output flattened into a std::vector<float>;
//                     empty vector for an empty batch.
std::vector<float> batch_infer_distance(torch::jit::Module& model, const std::vector<std::vector<float>>& point_pairs) {
    const int64_t batch_size = static_cast<int64_t>(point_pairs.size());
    if (batch_size == 0) {
        return {};  // nothing to infer; avoid calling forward() on a zero-row tensor
    }

    // Pack all pairs into one contiguous buffer and wrap it with from_blob.
    // The original filled the tensor row-by-row via a temporary torch::tensor
    // per pair, which performs an allocation and a dispatcher round-trip per
    // element; this does a single bulk copy instead.
    std::vector<float> flat;
    flat.reserve(static_cast<size_t>(batch_size) * 4);
    for (const auto& pair : point_pairs) {
        flat.insert(flat.end(), pair.begin(), pair.end());
    }
    // clone() gives the tensor its own storage so it no longer aliases `flat`.
    torch::Tensor input_tensor =
        torch::from_blob(flat.data(), {batch_size, 4}, torch::kFloat32).clone();

    // Inference only: disable autograd bookkeeping for the forward pass.
    torch::NoGradGuard no_grad;
    torch::Tensor output = model.forward({input_tensor}).toTensor().contiguous();

    // Copy the results out into a plain vector for the caller.
    const float* out = output.data_ptr<float>();
    return std::vector<float>(out, out + output.numel());
}


// Function to generate random point pairs
// Generates `batch_size` random point pairs, each encoded as a flat
// {x1, y1, x2, y2} vector with coordinates drawn uniformly from [0, 100).
//
// @param batch_size  Number of pairs to generate; 0 yields an empty vector.
// @return            batch_size inner vectors of exactly 4 floats each.
std::vector<std::vector<float>> generate_random_points(int batch_size) {
    std::vector<std::vector<float>> point_pairs;
    point_pairs.reserve(batch_size);

    // Non-deterministic seed + Mersenne Twister engine.
    std::random_device rd;
    std::mt19937 gen(rd());
    // Use the float specialization directly — the original used the default
    // uniform_real_distribution<> (double), silently narrowing every sample
    // from double to float.
    std::uniform_real_distribution<float> dis(0.0f, 100.0f);

    for (int i = 0; i < batch_size; ++i) {
        point_pairs.push_back({dis(gen), dis(gen), dis(gen), dis(gen)});
    }

    return point_pairs;
}

// Benchmarks batch inference latency/throughput for a traced distance model
// across a range of batch sizes, averaging many trials per size.
//
// Usage: ./bench [model_path]   (defaults to distance_model_traced.pt)
int main(int argc, const char* argv[]) {
    // Report the intra-op thread count LibTorch will use for inference.
    std::cout << "Current number of threads: " << at::get_num_threads() << std::endl;

    // Model path may be overridden on the command line; the default keeps
    // the original hard-coded behavior.
    const char* model_path = (argc > 1) ? argv[1] : "distance_model_traced.pt";

    torch::jit::Module model;
    try {
        model = torch::jit::load(model_path);
    } catch (const c10::Error& e) {
        // Include the path and the underlying error — the original discarded
        // the caught exception's message entirely.
        std::cerr << "Error loading the model '" << model_path << "': " << e.what() << "\n";
        return -1;
    }

    std::cout << "Model loaded successfully.\n";

    // Batch sizes to benchmark.
    const std::vector<int> batch_sizes = {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024};

    for (int batch_size : batch_sizes) {
        double total_time = 0.0;
        const int trial_num = 10000;  // trials averaged per batch size

        for (int trial = 0; trial < trial_num; ++trial) {
            // Fresh random input each trial so a cached result cannot skew timing.
            std::vector<std::vector<float>> point_pairs = generate_random_points(batch_size);

            // Time only the inference call; input generation is excluded.
            auto start = std::chrono::high_resolution_clock::now();
            batch_infer_distance(model, point_pairs);
            auto end = std::chrono::high_resolution_clock::now();
            total_time += std::chrono::duration<double>(end - start).count();
        }

        // Average over all trials; throughput is samples processed per second.
        const double average_time = total_time / trial_num;
        const double throughput = batch_size / average_time;
        std::cout << "Batch Size: " << batch_size
                  << ", Average Time: " << average_time << " seconds"
                  << ", Throughput: " << throughput << " samples/second" << std::endl;
    }

    return 0;
}







Editor is loading...
Leave a Comment