Untitled
unknown
python
4 years ago
1.8 kB
8
Indexable
def histnd(points: Tensor, size: Tuple[int, ...], weights: Tensor = None, align_corners=True):
    """Accumulate points into an N-dimensional histogram with N-linear splatting.

    Every point spreads its weight over the ``2**ndims`` bins surrounding it
    using N-linear (bilinear, trilinear, ...) interpolation coefficients, so the
    result is differentiable with respect to both ``points`` and ``weights``.

    Args:
        points: Tensor of shape ``(npoints, ndims)`` with coordinates expected
            in the unit cube. Rows with any coordinate outside ``[0, 1]`` are
            dropped.
        size: Number of bins along each of the ``ndims`` dimensions.
        weights: Optional per-point weights of shape ``(npoints,)``; defaults
            to ones. The same rows dropped from ``points`` are dropped here.
        align_corners: If True, bin ``k`` is centred at ``(k + 0.5) / size``,
            so 0 and 1 lie on the outer edges of the first and last bins. If
            False, bin ``k`` is centred at ``k / size``.

    Returns:
        Tensor of shape ``size`` holding the accumulated contributions. Mass
        that lands outside the ``size`` bins (near the borders) is discarded.
        When no point survives the filter, a zero tensor with
        ``requires_grad=True`` is returned.

    Raises:
        ValueError: If ``points.shape[1]`` differs from ``len(size)``.
    """
    dev = points.device

    # Keep only rows inside the unit cube. ``weights`` must be filtered with
    # the same mask, otherwise it no longer lines up with ``points``.
    inside = ((points >= 0.0) & (points <= 1.0)).all(dim=1)
    points = points[inside]
    if weights is not None:
        weights = weights[inside]

    npoints, ndims = points.shape
    if npoints == 0:
        return torch.zeros(size, device=dev, requires_grad=True)
    if len(size) != ndims:
        raise ValueError(
            f"points have {ndims} dimensions but size has {len(size)}")

    # Clamp just below 1 so the upper neighbour of every point still fits
    # inside the one-cell padding added around the grid.
    points = torch.clamp(points, 0.0, 1.0 - 1e-5)

    # Convert unit-cube coordinates to fractional bin coordinates.
    grid_size = torch.tensor(size, device=dev)
    if align_corners:
        points = points - 0.5 / grid_size
    points = points * grid_size

    if weights is None:
        weights = torch.ones(npoints, device=dev)

    # Enumerate the 2**ndims corner bins around each point: each pass doubles
    # the rows and shifts the new half by +1 along ``dim``, so row r of
    # ``corners`` is always paired with row r of ``points``/``weights``.
    corners = torch.floor(points).long()
    for dim in range(ndims):
        points = torch.cat([points] * 2)
        weights = torch.cat([weights] * 2)
        offset = torch.zeros(ndims, device=dev, dtype=torch.long)
        offset[dim] = 1
        corners = torch.cat([corners, corners + offset])

    # N-linear interpolation coefficient of each (point, corner) pair.
    contrib = weights * torch.prod(1 - torch.abs(points - corners), dim=1)

    # Row-major flat index into the padded grid; the +1 shift maps corner -1
    # (and corner ``size``) onto the padding cells.
    padded = [int(s) + 2 for s in size]
    flat_index = torch.zeros(len(corners), device=dev, dtype=torch.long)
    stride = 1
    for dim in reversed(range(ndims)):
        flat_index += (corners[:, dim] + 1) * stride
        stride *= padded[dim]

    # Out-of-place scatter-add keeps autograd simple; allocating in
    # ``contrib.dtype`` lets float64 inputs work.
    grid = torch.zeros(stride, device=dev, dtype=contrib.dtype)
    grid = grid.index_add(0, flat_index, contrib).reshape(padded)

    # Strip the one-cell padding from every dimension.
    return grid[(slice(1, -1),) * ndims]
Editor is loading...