# Scraped paste-page header (page metadata, not code):
#   Untitled · unknown · python · a year ago · 988 B · 16 · Indexable
# Restore the trained float weights, then extract every layer's weights as int8
# integers for the FHE circuit below.
# map_location="cpu": quantized kernels run on CPU, and this avoids needing CUDA
# just to read a checkpoint that may have been saved from GPU tensors.
model.load_state_dict(torch.load("mnist_cnn.pt", map_location="cpu"))
# With no qconfig_spec, quantize_dynamic only swaps nn.Linear and the RNN layer
# types; nn.Conv2d layers are left as float modules.
model = torch.quantization.quantize_dynamic(model, dtype=torch.qint8)


def _conv_weight_int8(conv):
    """Return ``conv.weight`` quantized to a symmetric, per-tensor int8 tensor.

    ``quantize_dynamic`` leaves conv layers in float, so they are quantized here by
    hand: zero point 0, scale chosen so the largest-magnitude weight lands at the
    edge of the int8 range. ``int_repr()`` yields the raw integer values.
    """
    w = conv.weight.detach()
    # Guard against an all-zero weight tensor, which would otherwise give scale 0.
    scale = max(w.abs().max().item() / 127.5, torch.finfo(torch.float32).eps)
    return torch.quantize_per_tensor(w, scale, 0, torch.qint8).int_repr()


weight1 = _conv_weight_int8(model.conv1).tolist()
weight2 = _conv_weight_int8(model.conv2).tolist()
# On a dynamic-quantized Linear, `weight` is a *method* returning a quantized
# tensor (attribute access would give a bound method with no `.T`). int_repr()
# exposes its int8 values; the transpose lets the circuit compute x @ weight.
weight3 = model.fc1.weight().int_repr().T.tolist()
weight4 = model.fc2.weight().int_repr().T.tolist()
@fhe.compiler({"x": "encrypted"})
def f(x):
    """CNN forward pass traced by Concrete and compiled to an FHE circuit.

    Applies, in order: conv(weight1) -> relu -> conv(weight2) -> relu ->
    2x2/stride-2 max-pool -> flatten -> matmul(weight3) -> relu -> matmul(weight4).
    The weights are the module-level lists weight1..weight4.

    Args:
        x: the encrypted input. The inputset and sample in this script wrap one
            2-D image as a (1, 1, H, W) batch; presumably 28x28 MNIST, TODO confirm.

    Returns:
        The output of the final matmul (one row per batch element).

    NOTE(review): no layer biases are ever applied, so outputs differ from the
    torch model's logits. Confirm this simplification is intended.
    """
    # Two 3x3, stride-1 convolutions, each followed by ReLU.
    x = fhe.conv(x, weight1, kernel_shape=(3, 3), strides=(1, 1))
    x = fhe.relu(x)
    x = fhe.conv(x, weight2, kernel_shape=(3, 3), strides=(1, 1))
    x = fhe.relu(x)
    # Downsample by 2 in each spatial dimension.
    x = fhe.maxpool(x, kernel_shape=(2, 2), strides=(2, 2))
    # Flatten every axis except the batch axis before the dense layers.
    x = x.reshape(x.shape[0], -1)
    x = np.matmul(x, weight3)  # pyright: ignore
    x = fhe.relu(x)
    x = np.matmul(x, weight4)  # pyright: ignore
    return x
def _to_circuit_input(image):
    """Wrap one 2-D image as a (1, 1, H, W) int64 array: batch of one, one channel.

    Used for both the compile-time inputset and the run-time sample, so the two
    always agree on shape and dtype.
    """
    return np.asarray(image, dtype=np.int64)[np.newaxis, np.newaxis]


# Inputset passed to f.compile; Concrete uses it to bound intermediate values
# when choosing integer bit-widths.
inputset = [_to_circuit_input(image) for image in dataset1.data[:1000]]
circuit = f.compile(inputset, relu_on_bits_threshold=9, use_gpu=True)

# Encrypt, evaluate and decrypt the first dataset2 sample, then print the L2
# norm of the decrypted output as a quick sanity value.
sample = _to_circuit_input(dataset2.data[0])
print(np.linalg.norm(np.asarray(circuit.encrypt_run_decrypt(sample)).flatten()))
# (Scraped paste-page footer, not code: "Leave a Comment")