Untitled
user_5573880
python
6 months ago
8.0 kB
2
Indexable
import onnx

# Report, for every MatMulInteger node in a quantized ONNX model, the chains of
# operators feeding into and flowing out of it (up to a "stopping" op), plus how
# many distinct elementwise ops those chains contain.
#
# Only the top-level `model.graph.node` list is walked; nodes inside
# control-flow subgraphs (If/Loop attributes) are not visited.

# Path of the model to inspect.
MODEL_PATH = "bidaf-11-int8.onnx"

# Load the ONNX model
model = onnx.load(MODEL_PATH)

# output_to_node_map: tensor name -> node that produces it.
# input_to_node_map:  tensor name -> list of nodes that consume it (a node that
#                     reads the same tensor twice appears twice).
# Empty names are skipped: ONNX encodes an omitted optional input/output as "",
# and treating "" as a real tensor would wire together every node that omits one.
output_to_node_map = {}
input_to_node_map = {}
for node in model.graph.node:
    for output in node.output:
        if output:
            output_to_node_map[output] = node
    for input_name in node.input:
        if input_name:
            input_to_node_map.setdefault(input_name, []).append(node)

# Operators at which a traversal stops (the stopping node itself is still
# included in the reported path).
stopping_ops = {"Transpose", "Unsqueeze", "Gather", "Reshape", "Squeeze", "MatMulInteger"}

# Operator types counted as "elementwise".
# NOTE(review): "Square" does not appear to be a standard ai.onnx operator, and
# Softmax normalizes along an axis rather than per element — confirm whether
# they belong in this set.
elementwise_ops = {
    "Add", "Sub", "Mul", "Div",  # Basic arithmetic
    "Relu", "Sigmoid", "Tanh", "LeakyRelu", "Softmax", "HardSigmoid",  # Activation functions
    "Greater", "Less", "Equal", "And", "Or", "Not", "Xor",  # Logical operations
    "Abs", "Neg", "Reciprocal", "Sqrt", "Exp", "Log", "Pow",  # Unary math functions
    "Sin", "Cos", "Tan", "Ceil", "Floor", "Clip", "Round", "Sign",  # More math functions
    "Max", "Min", "Mean", "Sum", "Square", "Erf",  # Aggregation and other ops
}


def _node_key(node):
    """Return a hashable identity for an ONNX node.

    ``node.name`` alone is not reliable: ONNX does not require node names to be
    set or unique, so unnamed nodes would all share the key "". Pairing the name
    with the node's output names keeps distinct nodes distinct, since a value
    name is assigned by at most one node in a well-formed ONNX graph.
    """
    return (node.name, tuple(node.output))


def find_previous_ops(input_name, visited_nodes, elementwise_visited):
    """Walk upstream from tensor ``input_name`` until a stopping op is reached.

    Args:
        input_name: Name of the tensor whose producer chain should be followed.
        visited_nodes: Set of node keys already expanded during this search.
            Mutated. A node already in it is not expanded again, which prunes
            diamonds and shared sub-graphs.
        elementwise_visited: Set of node keys already counted as elementwise.
            Mutated; share it across calls so each elementwise node counts once.

    Returns:
        ``(paths, count)``. ``paths`` is a list of op-type lists. Each list starts
        at the producer of ``input_name`` and ends at a stopping op, or where no
        further unvisited producer exists. It is ``[[]]`` when ``input_name`` has
        no producer node (e.g. a graph input or initializer) and ``[]`` when the
        producer was already visited. ``count`` is the number of elementwise
        nodes newly counted by this call.
    """
    node = output_to_node_map.get(input_name)
    if node is None:
        return [[]], 0

    key = _node_key(node)
    if key in visited_nodes:
        return [], 0  # Skip already visited nodes
    visited_nodes.add(key)

    current_path = [node.op_type]
    elementwise_count = 0
    if node.op_type in elementwise_ops and key not in elementwise_visited:
        elementwise_visited.add(key)
        elementwise_count += 1

    if node.op_type in stopping_ops:
        return [current_path], elementwise_count

    # Fully traverse each produced input until a stopping op is encountered.
    all_paths = []
    for node_input in node.input:
        if node_input not in output_to_node_map:
            continue
        sub_paths, sub_elementwise_count = find_previous_ops(
            node_input, visited_nodes, elementwise_visited
        )
        all_paths.extend(current_path + path for path in sub_paths)
        elementwise_count += sub_elementwise_count

    # If no sub-paths were found, the current node on its own is a valid path.
    return all_paths or [current_path], elementwise_count


def filter_duplicate_and_empty_paths(paths):
    """Return ``paths`` without empty lists or repeats, keeping first-seen order."""
    filtered_paths = []
    seen_paths = set()
    for path in paths:
        path_tuple = tuple(path)  # lists are unhashable; tuples can go in a set
        if path and path_tuple not in seen_paths:
            seen_paths.add(path_tuple)
            filtered_paths.append(path)
    return filtered_paths


def find_following_ops(output_name, visited_nodes, elementwise_visited):
    """Walk downstream from tensor ``output_name`` until a stopping point.

    Every consumer of ``output_name`` starts its own branch. A branch stops at
    a node that is a stopping op, or whose outputs feed more than one consumer
    in total (fan-out); that node is still included in the branch.

    Args:
        output_name: Name of the tensor whose consumers should be followed.
        visited_nodes: Set of node keys already expanded during this search.
            Mutated. Already-visited consumers are skipped.
        elementwise_visited: Set of node keys already counted as elementwise.
            Mutated; share it across calls so each elementwise node counts once.

    Returns:
        ``(paths, count)``. ``paths`` is a list of op-type lists, one or more per
        unvisited consumer; it is ``[[]]`` when there is no unvisited consumer.
        ``count`` is the number of elementwise nodes newly counted by this call.
    """
    all_following_paths = []
    elementwise_count = 0

    for node in input_to_node_map.get(output_name, []):
        key = _node_key(node)
        if key in visited_nodes:
            continue  # Skip already visited nodes to avoid duplication
        visited_nodes.add(key)

        # A fresh list per consumer so sibling consumers never share a path.
        branch = [node.op_type]
        if node.op_type in elementwise_ops and key not in elementwise_visited:
            elementwise_visited.add(key)
            elementwise_count += 1

        # Total number of consumers across all of this node's outputs.
        total_output_connections = sum(
            len(input_to_node_map.get(output, [])) for output in node.output
        )
        if node.op_type in stopping_ops or total_output_connections > 1:
            # End this branch only; remaining sibling consumers are still explored.
            all_following_paths.append(branch)
            continue

        branch_paths = []
        for next_output in node.output:
            sub_paths, sub_elementwise_count = find_following_ops(
                next_output, visited_nodes, elementwise_visited
            )
            branch_paths.extend(branch + path for path in sub_paths)
            elementwise_count += sub_elementwise_count
        all_following_paths.extend(branch_paths or [branch])

    if not all_following_paths:
        all_following_paths.append([])
    return all_following_paths, elementwise_count


# Count the number of MatMulInteger operations and total operations
matmulinteger_count = 0
total_op_count = 0

# Report preceding and following operator chains for every MatMulInteger node.
for node in model.graph.node:
    total_op_count += 1
    if node.op_type != "MatMulInteger":
        continue
    matmulinteger_count += 1
    print(f"\nMatMulInteger Node: {node.name}")

    # Upstream search. Each input gets its own visited set so every operand's
    # chain is listed in full. The elementwise set is shared so a node reachable
    # from several operands is counted once.
    all_previous_ops = []
    total_elementwise_count = 0
    elementwise_visited = set()
    for input_name in node.input:
        visited_nodes_for_input = set()
        previous_paths, prev_elementwise_count = find_previous_ops(
            input_name, visited_nodes_for_input, elementwise_visited
        )
        all_previous_ops.extend(previous_paths)
        total_elementwise_count += prev_elementwise_count

    filtered_paths = filter_duplicate_and_empty_paths(all_previous_ops)
    for i, path in enumerate(filtered_paths, 1):
        print(f" Previous Path {i}: {path}")

    # Downstream search, with its own visited/elementwise bookkeeping.
    all_following_ops = []
    visited_nodes = set()
    following_elementwise_visited = set()
    for output_name in node.output:
        following_paths, next_elementwise_count = find_following_ops(
            output_name, visited_nodes, following_elementwise_visited
        )
        all_following_ops.extend(following_paths)
        total_elementwise_count += next_elementwise_count

    filtered_following_paths = filter_duplicate_and_empty_paths(all_following_ops)
    for i, path in enumerate(filtered_following_paths, 1):
        print(f" Following Path {i}: {path}")

    print(f" Number of elementwise operations: {total_elementwise_count}")

# Print the final statistics
print(f"\nTotal number of MatMulInteger operations: {matmulinteger_count}")
print(f"Total number of operations in the model: {total_op_count}")
Editor is loading...
Leave a Comment