main.rs

// Cargo.toml
// [package]
// name = "levenshtein"
//
// [dependencies]
// tfhe = { path = "../tfhe-rs/tfhe", features = ["boolean", "shortint", "integer", "x86_64-unix", "gpu"] }
mail@pastecode.io avatar
unknown
rust
a month ago
2.6 kB
2
Indexable
Never
extern crate tfhe;

use std::collections::HashMap;
use std::time::Instant;
use tfhe::prelude::*;
use tfhe::{set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, FheUint8, FheUint16};
use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS;

/// Encrypted operands for a homomorphic Levenshtein-distance computation.
///
/// Built by `Levenshtein::encrypt`; the distance itself is produced by
/// `Levenshtein::levenshtein` as an encrypted `FheUint16`.
struct Levenshtein {
    /// First input string: one ciphertext per byte, each byte passed through `encode` first.
    x: Vec<FheUint8>,
    /// Second input string, encrypted the same way as `x`.
    y: Vec<FheUint8>,
    /// Encryption of `0u16`; the recursion's base cases add a clear length to it
    /// to obtain an encrypted constant.
    zero: FheUint16,
}

/// Maps an input byte to the value that gets encrypted by shifting it down by 41.
///
/// The subtraction wraps modulo 256, so bytes below 41 (space, `!`, `(`, ...) no longer
/// overflow: a plain `c - 41` panics in debug builds and wraps silently in release builds.
/// Wrapping keeps the mapping one-to-one, which is all the equality comparisons performed
/// on the encoded bytes require.
fn encode(c: u8) -> u8 {
    c.wrapping_sub(41)
}

impl Levenshtein {
    /// Encrypts both input strings under `client_key`.
    ///
    /// Every byte of `x` and `y` is passed through `encode` and encrypted as its own
    /// `FheUint8`. An encrypted `0u16` is stored alongside so the recursion's base cases
    /// can build ciphertexts by adding a clear length to it.
    fn encrypt(x: &str, y: &str, client_key: &ClientKey) -> Self {
        // Shared byte-wise encryption routine for both strings.
        let encrypt_bytes = |text: &str| -> Vec<FheUint8> {
            text.bytes()
                .map(|byte| FheUint8::encrypt(encode(byte), client_key))
                .collect()
        };

        Self {
            x: encrypt_bytes(x),
            y: encrypt_bytes(y),
            zero: FheUint16::encrypt(0u16, client_key),
        }
    }

    /// Computes the encrypted Levenshtein distance between the two stored strings.
    ///
    /// Starts the memoized recursion with an empty cache and the full inputs.
    fn levenshtein(&self) -> FheUint16 {
        let mut cache: HashMap<(u16, u16), FheUint16> = HashMap::new();
        self._levenshtein(self.x.as_slice(), self.y.as_slice(), &mut cache)
    }

    /// Recursive step: distance between the suffixes `x` and `y`.
    ///
    /// `memo` is keyed by the pair of remaining lengths. Every call receives suffixes of
    /// the original inputs, so the lengths identify the sub-problem. The lengths are cast
    /// to `u16`, so inputs longer than `u16::MAX` would alias cache keys.
    fn _levenshtein(
        &self,
        x: &[FheUint8],
        y: &[FheUint8],
        memo: &mut HashMap<(u16, u16), FheUint16>
    ) -> FheUint16 {
        let key = (x.len() as u16, y.len() as u16);
        if let Some(cached) = memo.get(&key).cloned() {
            return cached;
        }

        let distance = match (x.split_first(), y.split_first()) {
            // One side exhausted: the distance is the clear length of the other side,
            // lifted into a ciphertext through the encrypted zero.
            (None, _) => &self.zero + y.len() as u16,
            (_, None) => &self.zero + x.len() as u16,
            (Some((x_head, x_tail)), Some((y_head, y_tail))) => {
                // Sub-problems: drop the head of `x`, drop the head of `y`, drop both.
                let skip_x = self._levenshtein(x_tail, y, memo);
                let skip_y = self._levenshtein(x, y_tail, memo);
                let skip_both = self._levenshtein(x_tail, y_tail, memo);

                // The comparison result is encrypted, so both outcomes are computed and
                // the right one is picked homomorphically with `select`.
                let heads_match = x_head.eq(y_head);
                let cheapest = skip_x.min(&skip_y).min(&skip_both);
                let with_edit = cheapest + 1;
                heads_match.select(&skip_both, &with_edit)
            }
        };

        memo.insert(key, distance.clone());
        distance
    }
}

/// Generates keys, installs the GPU server key, then encrypts two hard-coded strings,
/// evaluates their Levenshtein distance homomorphically, and prints the decrypted
/// distance followed by the elapsed time.
///
/// The timer starts after the server key is installed, so it covers encryption,
/// evaluation and decryption but not key generation. Elapsed time is printed in
/// whole seconds.
fn main() {
    let params = PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS;
    let config = ConfigBuilder::with_custom_parameters(params, None).build();

    let client_key = ClientKey::generate(config);

    // Derive the compressed server key, decompress it to the GPU and install it.
    set_server_key(CompressedServerKey::new(&client_key).decompress_to_gpu());

    let timer = Instant::now();
    let (x, y) = ("ABCDEFGHIKLMNOPQRSTVXYZ", "ABCDEFGHIJKLMNOPQRSTUVWXYZ");

    let distance: u16 = Levenshtein::encrypt(x, y, &client_key)
        .levenshtein()
        .decrypt(&client_key);

    println!("{}", distance);
    println!("Execution: {}s", timer.elapsed().as_secs());
}
Leave a Comment