Untitled

mail@pastecode.io avatar
unknown
plain_text
7 months ago
632 B
2
Indexable
Never
def rbf(x_1, x_2, sigma=1.):
    '''Computes the RBF (Gaussian) kernel for every pair of rows of two batches.

    K[i, j] = exp(-||x_1[i] - x_2[j]||^2 / (2 * sigma^2))

    Args:
        x_1: torch.tensor shaped `(#samples_1, #features)`
        x_2: torch.tensor shaped `(#samples_2, #features)`
        sigma: kernel bandwidth (length scale); must be non-zero
    Returns:
        kernel function values for all pairs of samples from x_1 and x_2
        torch.tensor of type torch.float32 shaped `(#samples_1, #samples_2)`
    '''
    # Expand ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 so every pair is covered
    # by a single matrix product instead of an explicit pairwise loop.
    sq_norms_1 = torch.sum(x_1 ** 2, dim=1)[:, None]  # (#samples_1, 1)
    sq_norms_2 = torch.sum(x_2 ** 2, dim=1)[None, :]  # (1, #samples_2)
    sq_dists = sq_norms_1 - 2 * x_1 @ x_2.T + sq_norms_2
    # Floating-point cancellation in the expansion can make near-zero squared
    # distances slightly negative, which would push kernel values above 1.
    sq_dists = sq_dists.clamp_min(0)
    kernel = torch.exp(-sq_dists / (2 * sigma ** 2))
    # .to() preserves the device and casts any dtype; the legacy
    # torch.Tensor(...) constructor rejects CUDA and non-float32 tensors.
    return kernel.to(torch.float32)
Leave a Comment