Untitled

 avatar
user_5573880
python
12 days ago
4.2 kB
4
Indexable
Never
import onnx

# Load the ONNX model
model = onnx.load("bidaf-11-int8.onnx")

# Index the top-level graph by tensor name:
#   output_to_node_map[tensor] -> the node that produces `tensor`
#   input_to_node_map[tensor]  -> every node that consumes `tensor`, in graph order
# ONNX uses the empty string "" as a stand-in for an omitted optional
# input/output, so "" is not a real tensor and is skipped here; otherwise
# unrelated nodes that merely omit an optional argument would appear linked.
# NOTE(review): only model.graph.node is indexed; nodes nested in control-flow
# subgraphs (stored in node attributes) are not visited.
output_to_node_map = {}
input_to_node_map = {}
for node in model.graph.node:
    for output in node.output:
        if output:
            output_to_node_map[output] = node
    for input_name in node.input:
        if input_name:
            input_to_node_map.setdefault(input_name, []).append(node)

# Op types counted as element-wise when tallying what follows a MatMulInteger.
# Max/Min/Mean/Sum are ONNX's variadic element-wise ops (they combine several
# input tensors position by position); they are not axis reductions.
# Softmax is intentionally excluded: each of its outputs depends on every
# element along the softmax axis, so it is a normalization, not element-wise.
# NOTE(review): "Square" is not a standard ONNX op name; harmless, but it will
# likely never match a node.
elementwise_ops = frozenset({
    "Add", "Sub", "Mul", "Div",                                     # Basic arithmetic
    "Relu", "Sigmoid", "Tanh", "LeakyRelu", "HardSigmoid",          # Activation functions
    "Greater", "Less", "Equal", "And", "Or", "Not", "Xor",          # Logical operations
    "Abs", "Neg", "Reciprocal", "Sqrt", "Exp", "Log", "Pow",        # Unary / binary math functions
    "Sin", "Cos", "Tan", "Ceil", "Floor", "Clip", "Round", "Sign",  # More math functions
    "Max", "Min", "Mean", "Sum", "Square", "Erf",                   # Variadic element-wise and other ops
})

# Function to find subsequent operations (up to max_ops, default 5)
def find_following_ops(output_name, max_ops=5, found_ops=0):
    """Walk downstream from a tensor along first consumers, collecting op types.

    Starting at ``output_name``, take the first node (in graph order) that
    consumes it and record its op type. Then recurse into each of that node's
    outputs in order. The walk stops as soon as the running count
    ``found_ops`` reaches ``max_ops``.

    Args:
        output_name: Tensor name to start from. An empty name (ONNX's marker
            for an omitted optional output) is treated as having no consumers.
        max_ops: Upper bound on the running count of collected ops.
        found_ops: Ops already collected by the caller; counts toward max_ops.

    Returns:
        A tuple ``(following_ops, elementwise_count, found_ops)``:
        - the op types collected by this call, in visit order;
        - how many of them are in ``elementwise_ops``;
        - the updated running count.
    """
    following_ops = []
    elementwise_count = 0

    # Budget already spent, or the "" placeholder: it would otherwise match
    # unrelated nodes that also omit an optional input.
    if found_ops >= max_ops or not output_name:
        return following_ops, elementwise_count, found_ops

    consumers = input_to_node_map.get(output_name)
    if not consumers:
        return following_ops, elementwise_count, found_ops

    # Only the first consumer is followed; sibling consumers are not explored.
    node = consumers[0]
    following_ops.append(node.op_type)
    found_ops += 1
    if node.op_type in elementwise_ops:
        elementwise_count += 1

    for next_output in node.output:
        if found_ops >= max_ops:
            break
        next_ops, next_elementwise_count, found_ops = find_following_ops(
            next_output, max_ops, found_ops
        )
        following_ops.extend(next_ops)
        elementwise_count += next_elementwise_count

    return following_ops, elementwise_count, found_ops

# Function to find preceding operations (up to max_ops, default 5)
def find_previous_ops(input_name, max_ops=5, found_ops=0):
    """Walk upstream from a tensor depth-first, collecting producer op types.

    Look up the node that produces ``input_name`` and record its op type.
    Then recurse into each of that node's inputs in order (pre-order,
    depth-first). The walk stops as soon as the running count ``found_ops``
    reaches ``max_ops``.

    Args:
        input_name: Tensor name to start from. An empty name (ONNX's marker
            for an omitted optional input) is treated as having no producer.
        max_ops: Upper bound on the running count of collected ops.
        found_ops: Ops already collected by the caller; counts toward max_ops.

    Returns:
        A tuple ``(previous_ops, found_ops)``: the op types collected by this
        call in visit order, and the updated running count.

    NOTE(review): there is no visited set, so a producer reachable along two
    paths may be recorded twice (still bounded by max_ops).
    """
    previous_ops = []

    # Budget already spent, or the "" placeholder (not a real tensor).
    if found_ops >= max_ops or not input_name:
        return previous_ops, found_ops

    # Tensors with no producing node (presumably graph inputs or
    # initializers) end the walk on this branch.
    node = output_to_node_map.get(input_name)
    if node is None:
        return previous_ops, found_ops

    previous_ops.append(node.op_type)
    found_ops += 1

    for node_input in node.input:
        if found_ops >= max_ops:
            break
        prev_ops, found_ops = find_previous_ops(node_input, max_ops, found_ops)
        previous_ops.extend(prev_ops)

    return previous_ops, found_ops

# Report every MatMulInteger node with a short trail of the ops that feed it
# and the ops that consume it, then print summary counts for the graph.
matmulinteger_count = 0
total_op_count = len(model.graph.node)

for node in model.graph.node:
    if node.op_type != "MatMulInteger":
        continue

    matmulinteger_count += 1
    print(f"MatMulInteger Node: {node.name}")

    # Producers: each input is walked upstream with its own budget, and only
    # the first 5 op types of the concatenated result are shown.
    previous_ops = [
        op_type
        for input_name in node.input
        for op_type in find_previous_ops(input_name)[0]
    ]
    print(f"  Previous operations (up to 5): {previous_ops[:5]}")

    # Consumers: each output is walked downstream. The element-wise tally is
    # accumulated but is not currently printed.
    following_ops = []
    elementwise_count = 0
    for output_name in node.output:
        found, found_elementwise, _ = find_following_ops(output_name)
        following_ops += found
        elementwise_count += found_elementwise
    print(f"  Following operations (up to 5): {following_ops[:5]}\n")

# Print the final statistics
print(f"Total number of MatMulInteger operations: {matmulinteger_count}")
print(f"Total number of operations in the model: {total_op_count}")
Leave a Comment