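// Untitled
//
// Paste-site metadata carried over from pastecode.io (kept verbatim as comments so the file
// compiles as Rust):
// mail@pastecode.io avatar
// unknown
// plain_text
// a month ago
// 1.7 kB
// 8
// Indexable
// Never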
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rayon::prelude::*;
use std::time::Instant;
use tfhe::prelude::*;
use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS;
use tfhe::{set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, FheInt8, FheUint8};

fn similarity(a: &Vec<FheUint8>, clear_b: &Vec<u8>) -> FheInt8 {
    let c: Vec<FheUint8> = a
        .par_iter()
        .zip(clear_b.par_iter())
        .map(|(x, y)| FheUint8::cast_from(x.eq(y.clone())))
        .collect();
    FheInt8::cast_from(c.iter().sum::<FheUint8>() << 1u8) - (c.len() as i8)
}

fn main() {
    let config =
        ConfigBuilder::with_custom_parameters(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS)
            .build();
    let client_key = ClientKey::generate(config);
    let compressed_server_key = CompressedServerKey::new(&client_key);
    let gpu_key = || compressed_server_key.decompress_to_gpu();

    let mut rng = StdRng::seed_from_u64(0);
    let clear_a: Vec<u8> = (0..200).map(|_| rng.gen_range(0..2)).collect();
    let clear_b: Vec<u8> = (0..200).map(|_| rng.gen_range(0..2)).collect();

    let a = clear_a
        .iter()
        .map(|x| FheUint8::encrypt(x.clone(), &client_key))
        .collect();

    // Server-side
    let time = Instant::now();
    rayon::broadcast(|_| set_server_key(gpu_key()));
    set_server_key(gpu_key());
    let result = similarity(&a, &clear_b);
    let elapsed = time.elapsed();

    // Client-side
    let decrypted_result: i8 = FheInt8::decrypt(&result, &client_key);
    println!("Similarity: {:}", decrypted_result);
    println!("Elapsed: {:?}", elapsed);
}
// Leave a Comment  (paste-site residue, kept as a comment)