Untitled

mail@pastecode.io avatar
unknown
plain_text
a year ago
2.4 kB
3
Indexable
Never
import spconv
import torch
import torch.nn as nn
import timeit
import open3d as o3d
import numpy as np
import matplotlib.pyplot as plt

def generate_unique_coordinates(voxel_number, spatial_shape):
    """Sample ``voxel_number`` distinct random voxel coordinates inside a grid.

    Every returned coordinate satisfies ``0 <= coord[i] < spatial_shape[i]``
    on each axis independently. Sampling every axis up to
    ``max(spatial_shape)`` would place voxels outside the grid along the
    shorter axes.

    Sampling is vectorised. Candidate flat (row-major) cell indices are drawn
    in batches with ``torch.randint`` and de-duplicated with ``torch.unique``,
    so millions of voxels are generated without a per-voxel Python loop.
    When at least half of the grid is requested, a random permutation of all
    cells is used instead, which avoids slow rejection sampling near
    saturation. All randomness comes from torch's global RNG, so
    ``torch.manual_seed`` makes the result reproducible.

    Args:
        voxel_number: number of distinct coordinates to return.
        spatial_shape: grid extent per axis, e.g. ``(D, H, W)``.

    Returns:
        A ``torch.int32`` tensor of shape ``(voxel_number, len(spatial_shape))``
        whose rows are distinct coordinates in random order.

    Raises:
        ValueError: if ``voxel_number`` is negative or larger than the number
            of cells in the grid. Such a request can never be satisfied.
    """
    dims = [int(extent) for extent in spatial_shape]
    total_cells = 1
    for extent in dims:
        total_cells *= extent

    voxel_number = int(voxel_number)
    if voxel_number < 0:
        raise ValueError(f"voxel_number must be non-negative, got {voxel_number}")
    if voxel_number > total_cells:
        raise ValueError(
            f"cannot place {voxel_number} unique voxels in a grid of "
            f"{total_cells} cells (spatial_shape={tuple(dims)})"
        )
    if voxel_number == 0:
        return torch.empty((0, len(dims)), dtype=torch.int32)

    if 2 * voxel_number >= total_cells:
        # Dense request: rejection sampling would stall near saturation, so
        # take a prefix of a random permutation of every cell instead.
        flat = torch.randperm(total_cells)[:voxel_number]
    else:
        flat = torch.empty(0, dtype=torch.int64)
        while flat.numel() < voxel_number:
            missing = voxel_number - flat.numel()
            # Oversample so that usually a single round is enough.
            candidates = torch.randint(0, total_cells, (2 * missing,), dtype=torch.int64)
            flat = torch.unique(torch.cat((flat, candidates)))
        # torch.unique returns sorted values. Shuffle before trimming so both
        # the kept subset and the row order remain uniformly random.
        flat = flat[torch.randperm(flat.numel())[:voxel_number]]

    # Convert flat row-major indices back into per-axis coordinates.
    coords = torch.empty((voxel_number, len(dims)), dtype=torch.int64)
    remainder = flat
    for axis in range(len(dims) - 1, -1, -1):
        coords[:, axis] = remainder % dims[axis]
        remainder = remainder // dims[axis]
    return coords.to(torch.int32)


def benchmark(voxel_number, kernel_size, spatial_shape, num_runs, warmup=1):
    """Time a 3-D convolution with spconv (sparse) and PyTorch (dense).

    Both layers map 1 input channel to 16 output channels, have no bias and
    use all-ones weights. The spconv layer consumes ``voxel_number`` randomly
    placed active voxels (feature value 1, batch index 0). The PyTorch layer
    consumes an all-ones dense tensor of shape ``(1, 1, *spatial_shape)``.

    Only the convolution calls are timed:
    * the dense input is allocated once, before timing, so the allocation is
      not charged to PyTorch;
    * autograd is disabled (``torch.no_grad``), so neither backend spends time
      recording a backward graph;
    * ``warmup`` untimed calls per backend run first, to absorb one-time
      initialisation costs.

    Args:
        voxel_number: number of active voxels in the sparse input.
        kernel_size: convolution kernel size, passed to both layers.
        spatial_shape: grid extent per axis.
        num_runs: number of timed calls per backend; must be >= 1.
        warmup: number of untimed calls per backend before timing.

    Returns:
        ``(spconv_seconds, pytorch_seconds)``: the mean wall-clock time per
        call, in seconds.

    Raises:
        ValueError: if ``num_runs`` is less than 1, since the mean per call
            would be undefined.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    # Sparse input: each row is [batch_idx, *spatial_coord] with batch index 0.
    features = torch.ones((voxel_number, 1), dtype=torch.float32)
    coordinates = torch.cat(
        (
            torch.zeros((voxel_number, 1), dtype=torch.int32),
            generate_unique_coordinates(voxel_number, spatial_shape),
        ),
        dim=1,
    )
    input_tensor = spconv.SparseConvTensor(features, coordinates, torch.Size(spatial_shape), batch_size=1)

    # Dense input, built once so its allocation is excluded from the timing.
    dense_input = torch.ones((1, 1, *spatial_shape), dtype=torch.float32)

    # Convolution layer with spconv
    conv_layer_spconv = spconv.SparseConv3d(1, 16, kernel_size, bias=False)
    conv_layer_spconv.weight.data.fill_(1)
    # Convolution layer with pytorch
    conv_layer_pytorch = nn.Conv3d(1, 16, kernel_size, bias=False)
    conv_layer_pytorch.weight.data.fill_(1)

    with torch.no_grad():
        for _ in range(warmup):
            conv_layer_spconv(input_tensor)
            conv_layer_pytorch(dense_input)

        time_spconv = timeit.timeit(lambda: conv_layer_spconv(input_tensor), number=num_runs)
        time_pytorch = timeit.timeit(lambda: conv_layer_pytorch(dense_input), number=num_runs)

    return time_spconv / num_runs, time_pytorch / num_runs


def plot_benchmark_time_vs_sparsity(kernel_size, spatial_shape, num_runs, sparsity_level):
    """Benchmark spconv against dense PyTorch at one occupancy level and print the timings.

    Despite its name, this function only prints results; it does not plot.

    ``sparsity_level`` is the fraction of grid cells that are occupied
    (0.03 means 3% active voxels). The voxel count is
    ``int(sparsity_level * prod(spatial_shape))``.

    ``benchmark`` returns mean seconds per call. The values are converted to
    microseconds here so that the printed numbers match the printed unit.
    """
    # Compute the cell count in int64 so large grids cannot overflow a
    # platform-default 32-bit integer.
    total_cells = int(np.prod(spatial_shape, dtype=np.int64))
    voxel_number = int(sparsity_level * total_cells)
    time_spconv, time_pytorch = benchmark(voxel_number, kernel_size, spatial_shape, num_runs)

    # ':g' hides float noise such as 7.000000000000001.
    percent = f"{sparsity_level * 100:g}"
    print(f"Time for sparsity level {percent}% with spconv: {time_spconv * 1e6:.1f} microseconds")
    print(f"Time for sparsity level {percent}% with pytorch: {time_pytorch * 1e6:.1f} microseconds")


if __name__ == '__main__':
    # Single benchmark run: a 3x3x3 kernel on a 1408 x 1600 x 40 grid with
    # 3% of the cells occupied, averaged over 10 timed calls per backend.
    plot_benchmark_time_vs_sparsity(
        kernel_size=(3, 3, 3),
        spatial_shape=(1408, 1600, 40),
        num_runs=10,
        sparsity_level=0.03,
    )