Untitled

 avatar
unknown
python
a year ago
741 B
6
Indexable
# Restore the trained parameters. `map_location="cpu"` lets a checkpoint that
# was saved from a GPU run load on a CPU-only machine; `load_state_dict` then
# copies the values into the model's existing parameters.
# NOTE(security): `torch.load` unpickles the file -- only load checkpoints you
# trust (or pass `weights_only=True` on PyTorch versions that support it).
model.load_state_dict(torch.load("mnist_cnn.pt", map_location="cpu"))


def _conv_bias(layer):
    """Return *layer*'s bias as a NumPy array, or ``None`` if it has no bias."""
    bias = getattr(layer, "bias", None)
    return None if bias is None else np.asarray(bias.tolist())


# Export the parameters as plain Python / NumPy data for use inside `f`.
weight1 = model.conv1.weight.tolist()
bias1 = _conv_bias(model.conv1)
weight2 = model.conv2.weight.tolist()
bias2 = _conv_bias(model.conv2)
# The dense weights are transposed so `f` can right-multiply the activations
# (`x @ weight`).
weight3 = model.fc1.weight.T.tolist()
bias3 = model.fc1.bias.tolist()
weight4 = model.fc2.weight.T.tolist()
bias4 = model.fc2.bias.tolist()


def f(x):
    """Run the forward pass using the exported parameters.

    Pipeline: conv -> relu -> conv -> relu -> 2x2 max-pool (stride 2) ->
    flatten -> dense -> relu -> dense.

    Args:
        x: Input batch. The first axis is kept as the batch axis when the
            feature maps are flattened. Presumably shaped
            (batch, channels, height, width) -- TODO confirm against the
            model definition.

    Returns:
        The output of the last dense layer, one row per input sample.
    """
    # The convolution biases are passed along with the weights. Previously
    # they were dropped, which only reproduces the model when conv1/conv2 are
    # bias-free; `bias1`/`bias2` are None in that case, so nothing changes.
    x = fhe.conv(x, weight1, bias=bias1, kernel_shape=(3, 3), strides=(1, 1))
    x = fhe.relu(x)
    x = fhe.conv(x, weight2, bias=bias2, kernel_shape=(3, 3), strides=(1, 1))
    x = fhe.relu(x)
    x = fhe.maxpool(x, kernel_shape=(2, 2), strides=(2, 2))
    # Collapse everything except the batch axis into one feature vector.
    x = x.reshape(x.shape[0], -1)
    x = np.matmul(x, weight3) + bias3
    x = fhe.relu(x)
    x = np.matmul(x, weight4) + bias4
    return x

# Build a one-element batch from the first dataset entry by prepending two
# singleton axes, then print the network output for it.
# NOTE(review): this reads `dataset2.data` directly. If `dataset2` is a
# torchvision-style dataset, any transform (scaling / normalisation) attached
# to it is presumably bypassed here -- confirm this matches the inputs the
# model was trained on.
sample = np.array(dataset2.data[0].tolist())[np.newaxis, np.newaxis, ...]
print(f(sample))
Editor is loading...
Leave a Comment