Untitled
unknown
plain_text
2 years ago
2.4 kB
6
Indexable
import spconv
import torch
import torch.nn as nn
import timeit
import open3d as o3d
import numpy as np
import matplotlib.pyplot as plt
def generate_unique_coordinates(voxel_number, spatial_shape):
    """Sample ``voxel_number`` distinct 3D voxel coordinates inside ``spatial_shape``.

    Component ``i`` of every returned row lies in ``[0, spatial_shape[i])``, so
    each row is a valid index into a grid of that shape, and no two rows are
    equal. Randomness comes from torch's global RNG.

    Args:
        voxel_number: How many coordinates to draw; must satisfy
            ``0 <= voxel_number <= prod(spatial_shape)``.
        spatial_shape: Three non-negative grid extents.

    Returns:
        A ``torch.int32`` tensor of shape ``(voxel_number, 3)``.

    Raises:
        ValueError: If ``spatial_shape`` does not have exactly three
            non-negative entries, or ``voxel_number`` is negative or larger
            than the number of grid cells (the request could never be met).
    """
    dims = [int(d) for d in spatial_shape]
    if len(dims) != 3 or min(dims) < 0:
        raise ValueError(f"spatial_shape must be three non-negative ints, got {spatial_shape!r}")
    d0, d1, d2 = dims
    total = d0 * d1 * d2
    n = int(voxel_number)
    if not 0 <= n <= total:
        raise ValueError(f"voxel_number must be in [0, {total}], got {voxel_number}")

    if n * 2 >= total:
        # Dense request: take a prefix of a random permutation. Rejection
        # sampling would slow down sharply as the grid fills up.
        linear = torch.randperm(total)[:n]
    else:
        # Sparse request: draw flat indices in bulk and keep the distinct ones.
        # Fewer than half the cells are ever taken, so on average more than
        # half of each round's draws are new and few rounds are needed.
        linear = torch.empty(0, dtype=torch.int64)
        while linear.numel() < n:
            draws = torch.randint(0, total, (n - linear.numel(),), dtype=torch.int64)
            linear = torch.unique(torch.cat((linear, draws)))

    # Unflatten the row-major flat indices into one column per axis.
    coords = torch.stack((linear // (d1 * d2), (linear // d2) % d1, linear % d2), dim=1)
    return coords.to(torch.int32)
def _mean_seconds_per_call(fn, num_runs, device):
    """Return the mean wall-clock seconds of ``fn()`` over ``num_runs`` calls.

    One untimed warm-up call runs first so one-time costs (e.g. lazy
    initialisation) are not averaged into the result. On CUDA devices the
    device is synchronised after every call: kernel launches are asynchronous,
    so without it the timer would stop before the work had finished.
    """
    if torch.device(device).type == "cuda":
        def run():
            fn()
            torch.cuda.synchronize(device)
    else:
        run = fn
    run()  # warm-up, not timed
    return timeit.timeit(run, number=num_runs) / num_runs


def benchmark(voxel_number, kernel_size, spatial_shape, num_runs, device="cpu"):
    """Time a forward pass of a sparse (spconv) vs. a dense (PyTorch) 3D convolution.

    Both layers map 1 input channel to 16 output channels with the given
    ``kernel_size``, no bias, library-default stride and padding, and every
    weight set to 1. The sparse input holds ``voxel_number`` distinct active
    voxels (feature value 1) in batch 0. The dense input is an all-ones tensor
    of shape ``(1, 1, *spatial_shape)``. The timings run under
    ``torch.no_grad()``.

    Args:
        voxel_number: Number of active voxels in the sparse input.
        kernel_size: Convolution kernel size, passed to both layers.
        spatial_shape: Three grid extents.
        num_runs: Number of timed forward passes to average over (>= 1).
        device: Device to run on. Defaults to ``"cpu"``, as before.

    Returns:
        ``(seconds_per_call_spconv, seconds_per_call_pytorch)``: the mean
        wall-clock seconds per forward pass.

    Raises:
        ValueError: If ``num_runs`` is less than 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")
    device = torch.device(device)

    # Sparse input: one feature channel of ones; indices are [batch, c0, c1, c2].
    features = torch.ones((voxel_number, 1), dtype=torch.float32, device=device)
    batch_column = torch.zeros((voxel_number, 1), dtype=torch.int32)
    coordinates = torch.cat(
        (batch_column, generate_unique_coordinates(voxel_number, spatial_shape)), dim=1
    ).to(device)
    input_tensor = spconv.SparseConvTensor(features, coordinates, torch.Size(spatial_shape), batch_size=1)

    conv_layer_spconv = spconv.SparseConv3d(1, 16, kernel_size, bias=False).to(device)
    conv_layer_spconv.weight.data.fill_(1)
    conv_layer_pytorch = nn.Conv3d(1, 16, kernel_size, bias=False).to(device)
    conv_layer_pytorch.weight.data.fill_(1)

    # Build the dense grid once, outside the timed region, so the measurement
    # covers only the convolution and not the allocation and fill of the input.
    dense_input = torch.ones((1, 1, *spatial_shape), dtype=torch.float32, device=device)

    # Inference-only timing: skip building autograd graphs that are never used.
    with torch.no_grad():
        time_spconv = _mean_seconds_per_call(lambda: conv_layer_spconv(input_tensor), num_runs, device)
        time_pytorch = _mean_seconds_per_call(lambda: conv_layer_pytorch(dense_input), num_runs, device)
    return time_spconv, time_pytorch
def plot_benchmark_time_vs_sparsity(kernel_size, spatial_shape, num_runs, sparsity_level):
    """Benchmark sparse vs. dense 3D convolution at one occupancy level and print the timings.

    Despite the name, nothing is plotted. ``benchmark`` is run once and both
    mean times are printed in microseconds.

    Args:
        kernel_size: Forwarded to ``benchmark``.
        spatial_shape: Forwarded to ``benchmark``.
        num_runs: Forwarded to ``benchmark``.
        sparsity_level: Fraction of grid voxels that are active, e.g. ``0.03``
            for 3% (this is really the occupancy, despite the name).

    Returns:
        ``(time_spconv, time_pytorch)`` in seconds per forward pass, as
        returned by ``benchmark``.
    """
    # Use int64 explicitly: numpy's default integer is 32-bit on some
    # platforms, and large grids would silently overflow.
    total_voxels = int(np.prod(spatial_shape, dtype=np.int64))
    voxel_number = int(sparsity_level * total_voxels)
    time_spconv, time_pytorch = benchmark(voxel_number, kernel_size, spatial_shape, num_runs)
    # benchmark() returns seconds; convert so the printed unit is correct.
    print(f"Time for sparsity level {sparsity_level * 100}% with spconv: {time_spconv * 1e6:.1f} microseconds")
    print(f"Time for sparsity level {sparsity_level * 100}% with pytorch: {time_pytorch * 1e6:.1f} microseconds")
    return time_spconv, time_pytorch
if __name__ == "__main__":
    # Run a 3x3x3 kernel over a 1408 x 1600 x 40 voxel grid with 3% of the
    # voxels occupied, averaging each timing over 10 forward passes.
    kernel_size, spatial_shape = (3, 3, 3), (1408, 1600, 40)
    num_runs, sparsity_level = 10, 0.03
    plot_benchmark_time_vs_sparsity(
        kernel_size=kernel_size,
        spatial_shape=spatial_shape,
        num_runs=num_runs,
        sparsity_level=sparsity_level,
    )
Editor is loading...