# Untitled
#
# mail@pastecode.io avatar
# unknown
# python
# 3 years ago
# 1.8 kB
# 2
# Indexable
# Never
def histnd(points: Tensor, size: Tuple[int, ...], weights: Tensor = None, align_corners=True):
    """Build an N-dimensional histogram of ``points`` using N-linear splatting.

    Every retained point distributes its weight over the ``2 ** ndims`` grid
    cells surrounding it; each cell receives
    ``weight * prod(1 - |point - cell|)`` measured in grid coordinates.

    Args:
        points: ``(npoints, ndims)`` tensor of coordinates. Rows with any
            coordinate outside ``[0, 1]`` are discarded.
        size: Number of bins along each dimension; needs ``ndims`` entries.
        weights: Optional ``(npoints,)`` tensor with one weight per row of
            ``points`` as passed in. Defaults to a weight of 1 per point.
        align_corners: When true, grid coordinates are ``points * size - 0.5``;
            otherwise ``points * size``.

    Returns:
        A tensor of shape ``size`` on ``points.device``. Contributions landing
        on the one-cell border around the grid are dropped. When no point
        survives the range filter, a zero tensor created with
        ``requires_grad=True`` (default dtype) is returned.

    Raises:
        ValueError: If ``points`` is not 2-D, ``size`` does not have one entry
            per point dimension, or ``weights`` does not have one entry per
            point.
    """
    dev = points.device
    size = tuple(size)

    if points.dim() != 2 or points.shape[1] != len(size):
        raise ValueError(
            f"points must have shape (npoints, {len(size)}) to match size={size}, "
            f"got {tuple(points.shape)}"
        )

    # Materialise the default weights before filtering so that points and
    # weights always go through the same mask.
    if weights is None:
        weights = torch.ones(points.shape[0], device=dev)
    elif weights.dim() != 1 or weights.shape[0] != points.shape[0]:
        raise ValueError(
            f"weights must have shape ({points.shape[0]},), got {tuple(weights.shape)}"
        )

    # Keep only the points lying inside the unit hypercube, together with
    # the weights that belong to them.
    inside = (points.min(1)[0] >= 0.0) & (points.max(1)[0] <= 1.0)
    points = points[inside]
    weights = weights[inside]
    npoints, ndims = points.shape
    if npoints == 0:
        return torch.zeros(size, device=dev, requires_grad=True)

    # Pull 1.0 slightly inside so floor() never selects a cell past the end.
    points = torch.clamp(points, 0., 1. - 1e-5)

    # Convert unit coordinates to grid coordinates.
    grid_size = torch.tensor(size, device=dev)
    if align_corners:
        points = points - 0.5 / grid_size
    points = points * grid_size

    # Enumerate the 2**ndims cells around each point: start from the lower
    # corner and, per dimension, append a copy shifted by one cell.
    cells = torch.floor(points).long()
    for dim in range(ndims):
        step = torch.zeros(ndims, device=dev, dtype=torch.long)
        step[dim] = 1
        cells = torch.cat([cells, cells + step])
    copies = 2 ** ndims
    points = points.repeat(copies, 1)
    weights = weights.repeat(copies)

    # N-linear interpolation weight of every (point, cell) pair.
    contrib = weights * torch.prod(1 - torch.abs(points - cells), dim=1)

    # The grid is padded by one cell per side so neighbours that fall just
    # outside can be accumulated and then cropped away. Compute row-major
    # strides of the padded grid to turn cell coordinates into flat offsets.
    padded = [s + 2 for s in size]
    strides = [1] * ndims
    for d in range(ndims - 2, -1, -1):
        strides[d] = strides[d + 1] * padded[d + 1]
    flat_index = ((cells + 1) * torch.tensor(strides, device=dev)).sum(dim=1)

    # Scatter-add all contributions; the buffer follows the dtype of the
    # contributions so non-default float dtypes accumulate without error.
    flat = torch.zeros(strides[0] * padded[0], dtype=contrib.dtype, device=dev)
    grid = flat.index_add(0, flat_index, contrib).view(padded)

    # Remove the padding border.
    return grid[tuple(slice(1, -1) for _ in range(ndims))]