1003_matmul

 avatar
user_3093867
python
5 months ago
2.1 kB
2
Indexable
import onnxruntime as ort
import numpy as np
import time

def benchmark_onnxruntime(left, right, num_iterations=1,
                          model_path="./saved_model/matmul_model.onnx"):
    """Return the mean wall-clock seconds per ``sess.run`` call on the model.

    A session is created for ``model_path`` with the CUDA execution provider
    requested, one untimed warm-up run is performed, and then
    ``num_iterations`` runs are timed back to back.

    Args:
        left: Value fed to the model's first input.
        right: Value fed to the model's second input.
        num_iterations: Number of timed runs; must be at least 1.
        model_path: Path of the ONNX model file to load.

    Returns:
        Total elapsed time of the timed runs divided by ``num_iterations``,
        in seconds.

    Raises:
        ValueError: If ``num_iterations`` is less than 1.
    """
    # Fail fast: otherwise a zero count would only surface as a
    # ZeroDivisionError after the session was built and warmed up.
    if num_iterations < 1:
        raise ValueError(
            f"num_iterations must be >= 1, got {num_iterations}"
        )

    # Create an ONNX Runtime session
    providers = ['CUDAExecutionProvider']
    sess = ort.InferenceSession(model_path, providers=providers)
    # Verify the provider being used: the session reports the providers it
    # actually ended up with, which may differ from what was requested.
    print("Using provider:", sess.get_providers())
    print(f"Using device: {ort.get_device()}")

    # Resolve input/output names once so the timed loop does no extra work.
    model_inputs = sess.get_inputs()
    feeds = {model_inputs[0].name: left, model_inputs[1].name: right}
    output_names = [sess.get_outputs()[0].name]

    # Warm-up run (excluded from timing so one-time setup cost is not counted)
    sess.run(output_names, feeds)

    # Benchmark with perf_counter: it is monotonic and high resolution,
    # unlike time.time(), which can jump when the system clock is adjusted.
    # NOTE(review): the measured time presumably includes any host<->device
    # copies sess.run performs for the fed arrays - confirm if kernel-only
    # timing is wanted.
    start_time = time.perf_counter()
    for _ in range(num_iterations):
        sess.run(output_names, feeds)
    elapsed = time.perf_counter() - start_time

    return elapsed / num_iterations

# Build the two random float32 operands that are fed to the model.
left = np.random.randn(1, 1024, 1024).astype(np.float32)
right = np.random.randn(1024, 1024).astype(np.float32)

# Time the model on those operands and report the per-run average.
avg_time = benchmark_onnxruntime(left, right)
print(f"ONNX Runtime average execution time: {avg_time:.6f} seconds")

# Show the operand shapes so they can be checked by eye.
for heading, operand in (
    ("Left input shape:", left),
    ("Right input shape:", right),
):
    print(heading, operand.shape)

# Show the first ten flattened values of each operand.
for heading, operand in (
    ("\nLeft input (first few values):", left),
    ("\nRight input (first few values):", right),
):
    print(heading)
    print(operand.flatten()[:10])

# Perform a single run to get the output shape and values.
# Load the same model file that benchmark_onnxruntime() loads, so the output
# printed here comes from the model that was actually timed.
sess = ort.InferenceSession("./saved_model/matmul_model.onnx", providers=['CUDAExecutionProvider'])
print("Current provider:", sess.get_providers())
input_name_0 = sess.get_inputs()[0].name
input_name_1 = sess.get_inputs()[1].name
output_name = sess.get_outputs()[0].name
# sess.run returns one entry per requested output; only the first is needed.
result = sess.run([output_name], {input_name_0: left, input_name_1: right})[0]

print("\nOutput shape:", result.shape)
print("\nOutput (first few values):")
print(result.flatten()[:10])
Editor is loading...
Leave a Comment