Untitled

 avatar
unknown
python
a year ago
660 B
4
Indexable
# Load the trained network's parameters into ``model`` (assumed to be built
# earlier in the file). The checkpoint is loaded as a plain state_dict, so
# ``weights_only=True`` restricts unpickling to tensors and primitive
# containers instead of arbitrary Python objects. ``map_location="cpu"`` lets a
# checkpoint saved from a CUDA device load on a host without CUDA; the weights
# are only read back as Python lists below, so CPU placement is sufficient.
model.load_state_dict(
    torch.load("mnist_cnn.pt", map_location="cpu", weights_only=True)
)

# fhe.conv consumes the kernels as nested Python int lists.
# NOTE(review): ``Tensor.int()`` truncates toward zero rather than rounding, so
# float-trained weights with magnitude < 1 would all become 0 — confirm the
# model was trained/quantized to integer-valued weights. Any conv biases
# (``model.conv1.bias`` / ``model.conv2.bias``) are not extracted here.
weight1 = model.conv1.weight.int().tolist()
weight2 = model.conv2.weight.int().tolist()

@fhe.compiler({"x": "encrypted"})
def f(x):
    """Apply two 3x3 convolution + ReLU stages to the encrypted input ``x``.

    The kernels are the module-level integer lists ``weight1`` and then
    ``weight2``. Each is convolved with stride (1, 1), and each convolution
    output goes through ``fhe.relu`` before the next stage.
    """
    for kernel in (weight1, weight2):
        x = fhe.relu(fhe.conv(x, kernel, kernel_shape=(3, 3), strides=(1, 1)))
    return x

# Calibration inputs: Concrete derives the bit widths of intermediate values
# from what it observes while evaluating ``f`` on this inputset.
# ``np.random.randint``'s upper bound is exclusive, so 256 is required to cover
# the full 8-bit pixel range 0..255 (a bound of 255 never produces 255).
# NOTE(review): 10 uniformly random images may under-estimate the ranges real
# MNIST digits reach; consider calibrating on actual training images.
inputset = [np.random.randint(0, 256, size=(1, 1, 28, 28)) for _ in range(10)]
circuit = f.compile(inputset, relu_on_bits_threshold=9, use_gpu=True)

# Sanity check: the encrypted execution must match evaluating ``f`` in the clear.
sample = np.random.randint(0, 256, size=(1, 1, 28, 28))
assert np.array_equal(circuit.encrypt_run_decrypt(sample), f(sample))
Editor is loading...
Leave a Comment