Untitled
user_5573880
python
a year ago
9.9 kB
4
Indexable
import onnx
# Load the ONNX model
model = onnx.load("bidaf-11-int8.onnx")
# Index the top-level graph: tensor name -> producing node, and
# tensor name -> list of consuming nodes (one entry per use, so a node that
# reads the same tensor twice appears twice).
# ONNX uses the empty string "" to denote an omitted optional input/output.
# Such names are not real tensors and are skipped; otherwise every node with an
# omitted optional input would look like a consumer of one phantom tensor "",
# and "" would map to an arbitrary producer in output_to_node_map.
output_to_node_map = {}
input_to_node_map = {}
for node in model.graph.node:
    for output in node.output:
        if output:
            output_to_node_map[output] = node
    for input_name in node.input:
        if input_name:
            input_to_node_map.setdefault(input_name, []).append(node)
# Op types at which both the backward and forward searches stop; the stopping
# op itself is never recorded in a path. MatMulInteger is included so a search
# never runs through a neighbouring quantized matmul.
stopping_ops = {"Transpose", "Unsqueeze", "Reshape", "Squeeze", "CategoryMapper", "MatMulInteger"}
# Op types counted as "elementwise" when they appear on a path.
# NOTE(review): membership is kept exactly as originally written; "Square" is
# presumably not a standard ai.onnx op, and Softmax is not strictly
# elementwise — confirm the intended set before relying on the counts.
elementwise_ops = {
    "Add", "Sub", "Mul", "Div",  # Basic arithmetic
    "Relu", "Sigmoid", "Tanh", "LeakyRelu", "Softmax", "HardSigmoid",  # Activation functions
    "Greater", "Less", "Equal", "And", "Or", "Not", "Xor",  # Logical operations
    "Abs", "Neg", "Reciprocal", "Sqrt", "Exp", "Log", "Pow",  # Unary math functions
    "Sin", "Cos", "Tan", "Ceil", "Floor", "Clip", "Round", "Sign",  # More math functions
    "Max", "Min", "Mean", "Sum", "Square", "Erf",  # Aggregation and other ops
}
# Function to find previous operations along each input path until a stopping operation is encountered
def find_previous_ops(input_name, visited_nodes, elementwise_visited, skip_node_names):
    """Trace backwards from tensor ``input_name`` and return op-type paths.

    Starting at the node that produces ``input_name``, each producer chain is
    followed until one of these is met:

    * a node whose op type is in ``stopping_ops`` (that node is not recorded);
    * a node already in ``visited_nodes`` or listed in ``skip_node_names``;
    * a tensor that no node in the graph produces.

    Args:
        input_name: tensor name to start from.
        visited_nodes: names of nodes already walked. Updated in place;
            stopping nodes are added too.
        elementwise_visited: names of elementwise nodes already counted.
            Updated in place, so each node is counted at most once.
        skip_node_names: names of nodes that must not be entered. Only read.

    Returns:
        ``(paths, elementwise_count)``.

        * ``paths`` is a list of op-type lists. Each list is ordered from the
          producer of ``input_name`` outwards.
        * ``elementwise_count`` is the number of elementwise nodes newly
          counted by this call.
        * A tensor with no producer yields ``([[]], 0)``. A visited, skipped
          or stopping producer yields ``([], 0)``.
    """
    if input_name not in output_to_node_map:
        # No node produces this tensor (e.g. a graph input or initializer).
        return [[]], 0
    producer = output_to_node_map[input_name]
    if producer.name in visited_nodes or producer.name in skip_node_names:
        return [], 0
    visited_nodes.add(producer.name)
    if producer.op_type in stopping_ops:
        return [], 0  # the stopping op itself is deliberately left out of the path

    prefix = [producer.op_type]
    count = 0
    if producer.op_type in elementwise_ops and producer.name not in elementwise_visited:
        elementwise_visited.add(producer.name)
        count = 1

    collected = []
    for upstream in producer.input:
        if upstream not in output_to_node_map:
            continue  # nothing produces this input, so there is nothing to walk into
        branch_paths, branch_count = find_previous_ops(upstream, visited_nodes, elementwise_visited, skip_node_names)
        collected.extend(prefix + tail for tail in branch_paths)
        count += branch_count

    # A producer with no walkable upstream branch is itself a complete path.
    return (collected or [prefix]), count
# Function to remove duplicate paths and filter out empty paths
def filter_duplicate_and_empty_paths(paths):
    """Return the non-empty paths of ``paths`` with later duplicates dropped.

    Paths are compared by their contents. The first occurrence of each
    distinct path is kept, as the original list object and in its original
    position relative to the other kept paths.
    """
    first_seen = {}  # tuple(path) -> first path object with those contents
    for candidate in paths:
        if not candidate:
            continue  # empty paths carry no information
        first_seen.setdefault(tuple(candidate), candidate)
    # dicts preserve insertion order, so the original ordering is kept.
    return list(first_seen.values())
# Function to find following operations until a stopping operation is encountered.
# A path stops when a node's outputs have more than one consumer connection in
# total, even if those connections all go to a single consumer node.
# Each node placed on a path also has its inputs traced backwards (avoiding
# specific nodes), and those input paths are printed.
def find_following_ops(output_name, visited_nodes, elementwise_visited, skip_node_names):
    """Walk forward from tensor ``output_name`` and collect op-type paths.

    Each consumer of ``output_name`` starts its own path. A path ends when:

    * a consumer's op type is in ``stopping_ops`` (the stopping op is not
      recorded);
    * a node's outputs fan out to more than one connection;
    * there are no further consumers.

    For every node placed on a path, the tensors feeding its inputs are traced
    backwards with ``find_previous_ops``, and the resulting paths are printed.

    Args:
        output_name: tensor name whose consumers are followed. ``""`` (an
            omitted optional ONNX output) has no consumers.
        visited_nodes: names of nodes already walked in this forward search.
            Updated in place.
        elementwise_visited: names of elementwise nodes already counted.
            Updated in place.
        skip_node_names: names of nodes the backward input searches must not
            enter. Mutated in place: every node placed on a path is added, so
            its input searches cannot walk back into the path itself.

    Returns:
        ``(paths, elementwise_count)``.

        * ``paths`` is a list of op-type lists, or ``[[]]`` when nothing
          follows.
        * ``elementwise_count`` counts newly seen elementwise nodes, both on
          the paths and on their traced input paths.
    """
    all_following_paths = []
    elementwise_count = 0
    consumers = input_to_node_map.get(output_name, []) if output_name else []
    for node in consumers:
        if node.name in visited_nodes:
            continue  # already walked (e.g. a node that reads this tensor twice)
        visited_nodes.add(node.name)
        if node.op_type in stopping_ops:
            # End only this branch; sibling consumers are still explored.
            continue
        # Each consumer starts its own path; siblings must not be concatenated.
        current_following_path = [node.op_type]
        # Only count elementwise operations if they haven't been visited yet
        if node.op_type in elementwise_ops and node.name not in elementwise_visited:
            elementwise_visited.add(node.name)
            elementwise_count += 1

        # Avoid the current node when tracing its inputs backwards.
        skip_node_names.add(node.name)
        for node_input in node.input:
            input_paths, input_elementwise_count = find_previous_ops(node_input, set(), elementwise_visited, skip_node_names)
            for j, input_path in enumerate(filter_duplicate_and_empty_paths(input_paths), 1):
                print(f" Input Path {j} for Node ({node.op_type}): {input_path}")
            elementwise_count += input_elementwise_count

        real_outputs = [out for out in node.output if out]
        total_output_connections = sum(len(input_to_node_map.get(out, [])) for out in real_outputs)
        if total_output_connections > 1:
            # Fan-out: end this path here instead of picking one branch.
            all_following_paths.append(current_following_path)
            continue

        branch_paths = []
        for next_output in real_outputs:
            sub_paths, sub_elementwise_count = find_following_ops(next_output, visited_nodes, elementwise_visited, skip_node_names)
            branch_paths.extend(current_following_path + path for path in sub_paths)
            elementwise_count += sub_elementwise_count
        # If no sub-paths were found, the current path itself is a valid path.
        all_following_paths.extend(branch_paths or [current_following_path])

    return (all_following_paths or [[]]), elementwise_count
# Count the number of MatMulInteger operations and total operations
matmulinteger_count = 0
total_op_count = len(model.graph.node)
# Traverse the model nodes, find each MatMulInteger, and print the operations
# preceding and following it (up to the nearest stopping op on each path).
for node in model.graph.node:
    if node.op_type != "MatMulInteger":
        continue
    matmulinteger_count += 1
    print(f"\nMatMulInteger Node: {node.name}")

    # Find previous operations for each input path, traversing fully along each path
    all_previous_ops = []
    total_elementwise_count = 0
    elementwise_visited = set()  # elementwise nodes already counted in the backward search
    # Names of every node reached by the backward search, plus the MatMulInteger
    # itself. They are passed to the forward search as nodes its input
    # searches must not re-enter.
    previous_op_names = {node.name}
    for input_name in node.input:
        if not input_name:
            # "" marks an omitted optional input (e.g. a zero point) in ONNX;
            # it is not a real tensor and has no meaningful producer.
            continue
        visited_nodes_for_input = set()  # Separate visited set for each input path
        previous_paths, prev_elementwise_count = find_previous_ops(input_name, visited_nodes_for_input, elementwise_visited, set())
        all_previous_ops.extend(previous_paths)
        total_elementwise_count += prev_elementwise_count
        previous_op_names.update(visited_nodes_for_input)

    # Print each previous path after removing empty and duplicate paths
    for i, path in enumerate(filter_duplicate_and_empty_paths(all_previous_ops), 1):
        print(f" Previous Path {i}: {path}")
    print("------------------------------------------------------------------------------------------------------------------")

    # Find following operations, traversing fully along each path until stopping
    all_following_ops = []
    visited_nodes = set()  # Fresh visited set for the forward search
    following_elementwise_visited = set()  # elementwise nodes already counted in the forward search
    for output_name in node.output:
        following_paths, next_elementwise_count = find_following_ops(output_name, visited_nodes, following_elementwise_visited, previous_op_names)
        all_following_ops.extend(following_paths)
        total_elementwise_count += next_elementwise_count

    # Print each following path after removing empty and duplicate paths
    for i, path in enumerate(filter_duplicate_and_empty_paths(all_following_ops), 1):
        print(f" Following Path {i}: {path}")

    # Print elementwise operation count
    print("------------------------------------------------------------------------------------------------------------------")
    print(f" Number of elementwise operations: {total_elementwise_count}")
    print("\n")
# Print the final statistics
print(f"\nTotal number of MatMulInteger operations: {matmulinteger_count}")
print(f"Total number of operations in the model: {total_op_count}")
Editor is loading...
Leave a Comment