Untitled
user_5573880
python
12 days ago
4.2 kB
4
Indexable
Never
import sys

import onnx

# Report, for every MatMulInteger node in an ONNX model, the op types of its nearby
# upstream producers and downstream consumers, then print summary counts.

# Model to inspect; the first command-line argument overrides the default path.
model_path = sys.argv[1] if len(sys.argv) > 1 else "bidaf-11-int8.onnx"

# Load the ONNX model
model = onnx.load(model_path)

# Maximum number of neighbouring ops reported on each side of a MatMulInteger node.
MAX_REPORTED_OPS = 5

# Map each tensor name to the node that produces it, and to the list of nodes that
# consume it.  ONNX writes an omitted optional input/output as the empty string "",
# so those placeholders are skipped: otherwise all nodes with an omitted optional
# input would appear to share one "" tensor.
output_to_node_map = {}
input_to_node_map = {}
for node in model.graph.node:
    for output in node.output:
        if output:
            output_to_node_map[output] = node
    for input_name in node.input:
        if input_name:
            input_to_node_map.setdefault(input_name, []).append(node)

# Op types counted as elementwise in the downstream chain.  Softmax is deliberately
# not included: each of its outputs depends on a whole axis, not on one element.
elementwise_ops = {
    "Add", "Sub", "Mul", "Div",  # Basic arithmetic
    "Relu", "Sigmoid", "Tanh", "LeakyRelu", "HardSigmoid",  # Activation functions
    "Greater", "Less", "Equal", "And", "Or", "Not", "Xor",  # Logical operations
    "Abs", "Neg", "Reciprocal", "Sqrt", "Exp", "Log", "Pow",  # Unary math functions
    "Sin", "Cos", "Tan", "Ceil", "Floor", "Clip", "Round", "Sign",  # More math functions
    "Max", "Min", "Mean", "Sum", "Square", "Erf",  # Variadic elementwise and other ops
}


def find_following_ops(output_name, max_ops=5, found_ops=0):
    """Walk downstream from tensor ``output_name`` and collect consumer op types.

    Only the FIRST consumer of each tensor is followed, so the result describes a
    single depth-first downstream chain, not the full fan-out of the tensor.

    Args:
        output_name: Name of the tensor to start from.
        max_ops: Stop once this many ops have been collected in total.
        found_ops: Ops already collected by the caller; threaded through the
            recursion so ``max_ops`` bounds the whole walk.

    Returns:
        ``(following_ops, elementwise_count, found_ops)``: op types found by this
        call in visit order, how many of them are in ``elementwise_ops``, and the
        updated running total.
    """
    following_ops = []
    elementwise_count = 0
    consumers = input_to_node_map.get(output_name)
    if not consumers or found_ops >= max_ops:
        return following_ops, elementwise_count, found_ops

    node = consumers[0]  # follow only the first consumer (single chain)
    following_ops.append(node.op_type)
    found_ops += 1
    if node.op_type in elementwise_ops:
        elementwise_count += 1
    if found_ops >= max_ops:
        return following_ops, elementwise_count, found_ops

    # Recurse into every output of the consumer until the budget is exhausted.
    for next_output in node.output:
        next_ops, next_elementwise_count, found_ops = find_following_ops(
            next_output, max_ops, found_ops
        )
        following_ops.extend(next_ops)
        elementwise_count += next_elementwise_count
        if found_ops >= max_ops:
            break
    return following_ops, elementwise_count, found_ops


def find_previous_ops(input_name, max_ops=5, found_ops=0, visited=None):
    """Walk upstream from tensor ``input_name`` and collect producer op types.

    Depth-first: the producer of ``input_name`` is recorded, then each of that
    node's inputs is explored in order.  Tensors with no producing node (graph
    inputs, initializers, omitted optional inputs) contribute nothing.

    Args:
        input_name: Name of the tensor to start from.
        max_ops: Stop once this many ops have been collected in total.
        found_ops: Ops already collected by the caller; threaded through the
            recursion so ``max_ops`` bounds the whole walk.
        visited: Optional set of ``id()`` values of nodes already reported.  Pass
            the same set across several calls so that a shared ancestor is listed
            only once; a fresh set is used when ``None``.

    Returns:
        ``(previous_ops, found_ops)``: op types found by this call in visit order
        and the updated running total.
    """
    if visited is None:
        visited = set()
    previous_ops = []
    node = output_to_node_map.get(input_name)
    if node is None or found_ops >= max_ops:
        return previous_ops, found_ops
    # protobuf NodeProto messages are unhashable, so nodes are tracked by id();
    # every node reached here is the object stored in output_to_node_map, which
    # keeps it alive, so its id stays unique for the duration of the walk.
    if id(node) in visited:
        return previous_ops, found_ops
    visited.add(id(node))

    previous_ops.append(node.op_type)
    found_ops += 1
    for node_input in node.input:
        if found_ops >= max_ops:
            break
        prev_ops, found_ops = find_previous_ops(node_input, max_ops, found_ops, visited)
        previous_ops.extend(prev_ops)
    return previous_ops, found_ops


# Count MatMulInteger nodes; the total covers top-level graph nodes only (nodes
# nested in control-flow subgraphs, if any, are not included).
matmulinteger_count = 0
total_op_count = len(model.graph.node)

# Print the neighbouring ops of every MatMulInteger node.
for node in model.graph.node:
    if node.op_type != "MatMulInteger":
        continue
    matmulinteger_count += 1
    print(f"MatMulInteger Node: {node.name}")

    # Share one budget and one visited set across all inputs, so a producer that
    # feeds several inputs of this node is listed only once.
    previous_ops = []
    found = 0
    visited = set()
    for input_name in node.input:
        prev_ops, found = find_previous_ops(input_name, MAX_REPORTED_OPS, found, visited)
        previous_ops.extend(prev_ops)
    print(f" Previous operations (up to {MAX_REPORTED_OPS}): {previous_ops[:MAX_REPORTED_OPS]}")

    following_ops = []
    elementwise_count = 0
    for output_name in node.output:
        next_ops, next_elementwise_count, _ = find_following_ops(output_name, MAX_REPORTED_OPS)
        following_ops.extend(next_ops)
        elementwise_count += next_elementwise_count
    print(f" Following operations (up to {MAX_REPORTED_OPS}): {following_ops[:MAX_REPORTED_OPS]}\n")
    # Elementwise-count reporting is currently disabled; uncomment to print it.
    # print(f" Number of elementwise operations: {elementwise_count}")

# Print the final statistics
print(f"Total number of MatMulInteger operations: {matmulinteger_count}")
print(f"Total number of operations in the model: {total_op_count}")
Leave a Comment