# Untitled

unknown
python
3 years ago
1.8 kB
2
Indexable
Never
```def histnd(points: Tensor, size: Tuple[int], weights: Tensor = None, align_corners=True):
dev = points.device

# Filter points
min_c, max_c = points.min(1)[0], points.max(1)[0]
points = points[(min_c >= 0.0) * (max_c <= 1.0), :]
npoints, ndims = points.shape
if npoints == 0:
points = torch.clamp(points, 0., 1. - 1e-5)

# Align points with grid
grid_size = torch.tensor(size, device=dev)
padded_size = grid_size + 2
grid = torch.zeros(padded_size.tolist(), device=dev)
if align_corners:
points = points - 0.5 / grid_size
points = points * grid_size

# Init weights
if weights is None:
weights = torch.ones(npoints, device=dev)

# Duplicate points along each dimension for N-Linear interpolation
I = torch.floor(points).long()
for dim in range(ndims):
points = torch.cat([points] * 2)
weights = torch.cat([weights] * 2)
deltas = torch.zeros(ndims, device=dev, dtype=torch.long)
deltas[dim] = 1
I = torch.cat([I, I + deltas])

# Compute contributes
D = weights * torch.prod(1 - torch.abs(points - I), dim=1)

# Scatter sum all contributes into grid
# U, inv = torch.unique(I, dim=0, return_inverse=True)
# U += 1
U = torch.stack(torch.meshgrid(*[torch.arange(x) for x in grid.shape]), -1).reshape(-1, ndims)
inv = sum((I[:, i] + 1) * torch.tensor(padded_size[i+1:]).prod() for i in range(ndims)).long()
E = torch.zeros(len(U), device=dev)