# Paste metadata (not code): "Untitled", user_5573880, python, a year ago, 4.2 kB, 14, Indexable
import onnx
# Load the ONNX model
model = onnx.load("bidaf-11-int8.onnx")

# Index the top-level graph edges by tensor name:
#   output_to_node_map: tensor name -> the node that produces it
#   input_to_node_map:  tensor name -> every node that reads it, in graph order
# ONNX uses "" as a placeholder for an omitted optional input/output; such names
# are skipped so that unrelated nodes are not linked through a fake tensor.
# NOTE(review): only model.graph.node is walked, so nodes inside control-flow
# subgraphs (If/Loop/Scan bodies) are not indexed.
output_to_node_map = {}
input_to_node_map = {}
for node in model.graph.node:
    for output in node.output:
        if output:
            output_to_node_map[output] = node
    for input_name in node.input:
        if input_name:
            input_to_node_map.setdefault(input_name, []).append(node)
# Op types treated as element-wise: each output element depends only on the
# input element(s) at the same (broadcast) position.
# Softmax is deliberately NOT listed: it normalises across an axis, so it is
# not element-wise. Sum/Mean/Max/Min ARE listed because in ONNX they are
# variadic element-wise ops (combine N inputs position by position); the
# axis reductions are the separate Reduce* ops.
elementwise_ops = {
    "Add", "Sub", "Mul", "Div", "Pow", "Mod",  # Binary arithmetic
    "Relu", "Sigmoid", "Tanh", "LeakyRelu", "HardSigmoid",  # Activations
    "Elu", "Selu", "PRelu", "Softplus", "Softsign",  # More activations
    "Greater", "Less", "Equal", "GreaterOrEqual", "LessOrEqual",  # Comparisons
    "And", "Or", "Not", "Xor",  # Logical operations
    "Abs", "Neg", "Reciprocal", "Sqrt", "Exp", "Log", "Erf",  # Unary math
    "Sin", "Cos", "Tan", "Ceil", "Floor", "Clip", "Round", "Sign",  # More math
    "Max", "Min", "Mean", "Sum",  # Variadic element-wise ops
    "Square",  # NOTE: not an op in the default ONNX domain; harmless if unused
}
# Function to find subsequent operations (up to 5)
def find_following_ops(output_name, max_ops=5, found_ops=0, *, consumers=None, elementwise=None):
    """Collect the op types downstream of a tensor, up to a total budget.

    Starting at ``output_name``, take the FIRST node listed as reading it,
    record its op type, then recurse into each of that node's outputs in
    order. Fan-out branches (second and later consumers) are not explored,
    so the result is a single forward path.
    NOTE(review): confirm that following only the first consumer is intended.

    Args:
        output_name: Tensor name to start from. An empty name (ONNX's
            placeholder for an omitted optional value) is not a real edge and
            yields nothing.
        max_ops: Stop once this many ops have been collected in total.
        found_ops: Ops already collected by the caller (threaded through the
            recursion so the budget is shared).
        consumers: Mapping tensor name -> list of nodes reading it. Defaults
            to the module-level ``input_to_node_map``.
        elementwise: Collection of op types counted as element-wise. Defaults
            to the module-level ``elementwise_ops``.

    Returns:
        ``(following_ops, elementwise_count, found_ops)``: op types in visit
        order, how many of them are in ``elementwise``, and the updated
        running total of collected ops.
    """
    if consumers is None:
        consumers = input_to_node_map
    if elementwise is None:
        elementwise = elementwise_ops

    following_ops = []
    elementwise_count = 0
    # Skip the "" placeholder and never exceed an already-exhausted budget.
    if not output_name or found_ops >= max_ops:
        return following_ops, elementwise_count, found_ops

    readers = consumers.get(output_name)
    if not readers:
        return following_ops, elementwise_count, found_ops

    node = readers[0]  # only the first consumer is followed (see docstring)
    following_ops.append(node.op_type)
    found_ops += 1
    if node.op_type in elementwise:
        elementwise_count += 1

    for next_output in node.output:
        # Budget is shared across the whole walk; stop as soon as it is spent.
        if found_ops >= max_ops:
            break
        next_ops, next_elementwise_count, found_ops = find_following_ops(
            next_output, max_ops, found_ops, consumers=consumers, elementwise=elementwise
        )
        following_ops.extend(next_ops)
        elementwise_count += next_elementwise_count
    return following_ops, elementwise_count, found_ops
# Function to find preceding operations (up to 5)
def find_previous_ops(input_name, max_ops=5, found_ops=0, *, producers=None):
    """Collect the op types upstream of a tensor, up to a total budget.

    Depth-first walk: record the node that produces ``input_name``, then
    recurse into each of that node's inputs in order. There is no visited
    set, so a producer reachable through several inputs is reported once per
    path. Tensors with no entry in ``producers`` (e.g. graph inputs or
    initializers, which no node outputs) end that branch of the walk.

    Args:
        input_name: Tensor name to start from. An empty name (ONNX's
            placeholder for an omitted optional value) is not a real edge and
            yields nothing.
        max_ops: Stop once this many ops have been collected in total.
        found_ops: Ops already collected by the caller (threaded through the
            recursion so the budget is shared).
        producers: Mapping tensor name -> node producing it. Defaults to the
            module-level ``output_to_node_map``.

    Returns:
        ``(previous_ops, found_ops)``: op types in visit order and the
        updated running total of collected ops.
    """
    if producers is None:
        producers = output_to_node_map

    previous_ops = []
    # Skip the "" placeholder and never exceed an already-exhausted budget.
    if not input_name or found_ops >= max_ops:
        return previous_ops, found_ops

    node = producers.get(input_name)
    if node is None:
        return previous_ops, found_ops

    previous_ops.append(node.op_type)
    found_ops += 1
    # Recursively search for the preceding operations
    for node_input in node.input:
        if found_ops >= max_ops:
            break
        prev_ops, found_ops = find_previous_ops(node_input, max_ops, found_ops, producers=producers)
        previous_ops.extend(prev_ops)
    return previous_ops, found_ops
# Report the ops preceding and following every MatMulInteger node in the
# top-level graph, then print summary counts.
max_context_ops = 5  # how many neighbouring ops to show on each side

matmulinteger_count = 0
for node in model.graph.node:
    if node.op_type != "MatMulInteger":
        continue
    matmulinteger_count += 1
    print(f"MatMulInteger Node: {node.name or '<unnamed>'}")

    # Walk backwards from each input; every input gets its own budget, and
    # the combined list is truncated for display.
    previous_ops = []
    for input_name in node.input:
        prev_ops, _ = find_previous_ops(input_name, max_context_ops)
        previous_ops.extend(prev_ops)
    print(f" Previous operations (up to {max_context_ops}): {previous_ops[:max_context_ops]}")

    # Walk forwards from each output; same per-output budget and truncation.
    following_ops = []
    for output_name in node.output:
        next_ops, _, _ = find_following_ops(output_name, max_context_ops)
        following_ops.extend(next_ops)
    print(f" Following operations (up to {max_context_ops}): {following_ops[:max_context_ops]}\n")

# Print the final statistics
total_op_count = len(model.graph.node)
print(f"Total number of MatMulInteger operations: {matmulinteger_count}")
print(f"Total number of operations in the model: {total_op_count}")
# Paste page footer (not code): "Editor is loading..." / "Leave a Comment"