Untitled
unknown
plain_text
2 years ago
632 B
9
Indexable
def rbf(x_1, x_2, sigma=1.):
    """Compute the RBF (Gaussian) kernel for two batches of objects.

    For every pair ``(i, j)`` the value is
    ``exp(-||x_1[i] - x_2[j]||^2 / (2 * sigma^2))``.

    Args:
        x_1: torch.tensor shaped `(#samples_1, #features)` of type torch.float32
        x_2: torch.tensor shaped `(#samples_2, #features)` of type torch.float32
        sigma: kernel bandwidth; larger values make the kernel decay more
            slowly with distance.
    Returns:
        kernel function values for all pairs of samples from x_1 and x_2,
        a torch.tensor of type torch.float32 shaped `(#samples_1, #samples_2)`
        on the same device as the inputs.
    """
    # Expand ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 so that all pairwise
    # squared distances come from a single matrix product.
    sq_norms_1 = torch.sum(x_1 ** 2, dim=1)
    sq_norms_2 = torch.sum(x_2 ** 2, dim=1)
    sq_dists = sq_norms_1[:, None] - 2 * (x_1 @ x_2.T) + sq_norms_2[None, :]
    # Floating-point cancellation can make the expansion slightly negative for
    # (near-)identical points, which would push kernel values above 1.
    sq_dists = sq_dists.clamp(min=0)
    kernel = torch.exp(-sq_dists / (2 * sigma ** 2))
    # .to() keeps the device and accepts any floating dtype; the legacy
    # torch.Tensor(...) constructor rejects float64 and CUDA tensors.
    return kernel.to(torch.float32)
Leave a Comment