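"""
Analyze an ONNX graph around its MatMul / MatMulInteger nodes.

The script builds producer/consumer maps for every tensor in the graph, then for
each MatMul or MatMulInteger node walks the graph backward and forward until a
"stopping" op (Transpose, Reshape, Shape, another MatMul, ...) is reached,
printing the op-type paths and counting elementwise operations along the way.
It then groups the model's execution order by these MatMul-related node sets and
writes the results to grouped_execution_order.txt and matmul_related_nodes.txt.
"""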
import onnx
from collections import defaultdict
import onnxruntime as ort
import numpy as np
import json

# Load the ONNX model
# model = onnx.load("bidaf-11-int8.onnx")
# model = onnx.load("/home/pei1005/former/gtransformer.onnx")
model = onnx.load("imdb_sentiment_model.onnx")
onnx_model_path = "imdb_sentiment_model.onnx"

# Create dictionaries that track which node produces each output tensor
# and which nodes consume each input tensor
output_to_node_map = {}
input_to_node_map = {}
for node in model.graph.node:
    for output in node.output:
        output_to_node_map[output] = node
    for input_name in node.input:
        if input_name not in input_to_node_map:
            input_to_node_map[input_name] = []
        input_to_node_map[input_name].append(node)

# Define the stopping operations, including MatMulInteger
stopping_ops = {"Transpose", "Unsqueeze", "Reshape", "Squeeze", "CategoryMapper",
                "MatMulInteger", "MatMul", "Shape"}

# Define the types of elementwise operations
elementwise_ops = {
    "Add", "Sub", "Mul", "Div",                                         # Basic arithmetic
    "Relu", "Sigmoid", "Tanh", "LeakyRelu", "Softmax", "HardSigmoid",   # Activation functions
    "Greater", "Less", "Equal", "And", "Or", "Not", "Xor",              # Logical operations
    "Abs", "Neg", "Reciprocal", "Sqrt", "Exp", "Log", "Pow",            # Unary math functions
    "Sin", "Cos", "Tan", "Ceil", "Floor", "Clip", "Round", "Sign",      # More math functions
    "Max", "Min", "Mean", "Sum", "Square", "Erf"                        # Aggregation and other ops
}

matmul_related_nodes = {}


# Function to find previous operations along each input path until a stopping operation is encountered
def find_previous_ops(input_name, visited_nodes, elementwise_visited, skip_node_names, related_nodes):
    all_paths = []
    current_path = []
    elementwise_count = 0

    if input_name in output_to_node_map:
        node = output_to_node_map[input_name]

        if node.name in visited_nodes or node.name in skip_node_names:
            return [], elementwise_count  # Skip already visited nodes or nodes to be avoided
        visited_nodes.add(node.name)

        # Stop if this node is a stopping operation
        if node.op_type in stopping_ops:
            return [], elementwise_count  # Do not add the stopping op to the path

        current_path.append(node.op_type)
        related_nodes.add(node.name)

        # Only count elementwise operations if they haven't been visited in this path yet
        if node.op_type in elementwise_ops and node.name not in elementwise_visited:
            elementwise_visited.add(node.name)
            elementwise_count += 1

        # Recursively search each input path, fully traversing it until a stopping op
        for node_input in node.input:
            if node_input in output_to_node_map:
                sub_paths, sub_elementwise_count = find_previous_ops(
                    node_input, visited_nodes, elementwise_visited, skip_node_names, related_nodes)
                for path in sub_paths:
                    all_paths.append(current_path + path)
                elementwise_count += sub_elementwise_count

    if not all_paths:
        # If no sub-paths were found, the current path itself is a valid path
        all_paths.append(current_path)

    return all_paths, elementwise_count


# Function to remove duplicate paths and filter out empty paths
def filter_duplicate_and_empty_paths(paths):
    filtered_paths = []
    seen_paths = set()
    for path in paths:
        path_tuple = tuple(path)  # Convert the list to a tuple so it can be added to a set
        if path and path_tuple not in seen_paths:  # Keep paths that are non-empty and unseen
            seen_paths.add(path_tuple)
            filtered_paths.append(path)
    return filtered_paths


# Function to find following operations until a stopping operation is encountered.
# It can optionally stop when a node has multiple output paths (even with only one
# output tensor); that check is currently commented out below.
# Each node's inputs are also traversed to find their respective input paths until
# a stopping op, avoiding specific nodes.
def find_following_ops(output_name, visited_nodes, elementwise_visited, skip_node_names, related_nodes):
    all_following_paths = []
    current_following_path = []
    elementwise_count = 0

    # Find the next nodes that use this output as input
    if output_name in input_to_node_map:
        for node in input_to_node_map[output_name]:
            if node.name in visited_nodes:
                continue  # Skip already visited nodes to avoid duplication
            visited_nodes.add(node.name)

            # Stop if this node is a stopping operation (including MatMulInteger)
            if node.op_type in stopping_ops:
                return [], elementwise_count  # Do not add the stopping op to the path

            current_following_path.append(node.op_type)
            related_nodes.add(node.name)

            # Only count elementwise operations if they haven't been visited yet
            if node.op_type in elementwise_ops and node.name not in elementwise_visited:
                elementwise_visited.add(node.name)
                elementwise_count += 1

            # Traverse the current node's input paths until stopping, avoiding specific previous nodes
            input_paths_found = False
            for node_input in node.input:
                input_visited_nodes = set()
                # Skip the node from which we just came in the following path and MatMulInteger nodes
                skip_node_names.add(node.name)  # Avoid the current node in the following path
                input_paths, input_elementwise_count = find_previous_ops(
                    node_input, visited_nodes, elementwise_visited, skip_node_names, related_nodes)
                filtered_input_paths = filter_duplicate_and_empty_paths(input_paths)

                # Only print valid input paths and count their elementwise operations
                for j, input_path in enumerate(filtered_input_paths, 1):
                    input_paths_found = True
                    print(f"    Input Path {j} for Node ({node.op_type}): {input_path}")
                elementwise_count += input_elementwise_count

            # Check whether the node has multiple output paths (even if only one output node)
            # total_output_connections = sum(len(input_to_node_map.get(output, [])) for output in node.output)
            # if total_output_connections > 1:
            #     return [current_following_path], elementwise_count

            # Recursively search for the next operations
            for next_output in node.output:
                sub_paths, sub_elementwise_count = find_following_ops(
                    next_output, visited_nodes, elementwise_visited, skip_node_names, related_nodes)
                for path in sub_paths:
                    all_following_paths.append(current_following_path + path)
                elementwise_count += sub_elementwise_count

    if not all_following_paths:
        # If no sub-paths were found, the current path itself is a valid path
        all_following_paths.append(current_following_path)

    return all_following_paths, elementwise_count


def get_execution_order(onnx_model_path):
    # Load the ONNX model
    model = onnx.load(onnx_model_path)
    graph = model.graph

    # Record all nodes and their inputs/outputs
    nodes = graph.node
    execution_order = []

    # Iterate over all nodes and record them in execution order
    for idx, node in enumerate(nodes):
        node_info = {
            'order': idx,
            'op_type': node.op_type,
            'name': node.name if node.name else f"Node_{idx}"
        }
        execution_order.append(node_info)

    return execution_order


# def get_execution_order(onnx_model_path):
#     # Create an inference session
#     options = ort.SessionOptions()
#     # Enable profiling and specify the output file
#     options.enable_profiling = True
#     # Create the ONNX Runtime InferenceSession
#     session = ort.InferenceSession(onnx_model_path, options,
#                                    providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
#     # Build the execution-order list
#     execution_order = []
#     node_name_to_order = {}  # Records the execution order of each node
#     # Construct input data
#     input_name = session.get_inputs()[0].name
#     input_shape = session.get_inputs()[0].shape
#     # Replace dynamic (None) dimensions with a concrete value, e.g. 1
#     input_shape = [dim if isinstance(dim, int) else 1 for dim in input_shape]
#     # Generate dummy int64 data to match the model's input requirements
#     input_data = np.random.randint(0, 100, size=input_shape).astype(np.int64)
#     # Run model inference
#     outputs = session.run(None, {input_name: input_data})
#     # Retrieve the profiling output file
#     profile_file = session.end_profiling()
#     with open(profile_file, "r") as f:
#         profile_data = json.load(f)
#     count = 0
#     for i, item in enumerate(profile_data):
#         node_name = item.get("name", f"Node_{i}")
#         op_type = item.get("args", {}).get("op_name", "Unknown")
#         args = item.get("args", {})
#         if item.get("cat") == "Node" and "thread_scheduling_stats" in item.get("args", {}):
#             execution_order.append({
#                 "order": count,
#                 "name": node_name,
#                 "op_type": op_type
#             })
#             count += 1
#     return execution_order


def group_execution_order_by_matmul(execution_order, matmul_related_nodes):
    # Stores the grouped results
    grouped_order = []
    current_group = []
    visited_nodes = set()       # Nodes already processed, to avoid duplicates
    related_group = None
    matmul_found = False        # Whether the current group contains a MatMul

    # Traverse nodes in their original execution order
    for node in execution_order:
        node_name = node['name']

        # Skip nodes that have already been processed
        if node_name in visited_nodes:
            continue

        # If there is no current related group, look for a new MatMul-related group
        new_related_group = None
        for matmul_node, related_nodes in matmul_related_nodes.items():
            if node_name in related_nodes:
                new_related_group = related_nodes
                break

        # The node belongs to a MatMul-related group, and it is either the same
        # group or the start of a new one
        if new_related_group:
            if related_group is None or new_related_group == related_group:
                related_group = new_related_group  # Update the related group
                current_group.append(node)
                visited_nodes.add(node_name)

                # Check whether a MatMul or MatMulInteger has been found
                if node['op_type'] in {"MatMul", "MatMulInteger"}:
                    matmul_found = True
            else:
                # A new related group was encountered: close the current group and reset
                if matmul_found:
                    grouped_order.append(current_group)
                else:
                    for n in current_group:
                        grouped_order.append([n])

                # Start a new group
                current_group = [node]
                visited_nodes.add(node_name)
                related_group = new_related_group
                matmul_found = node['op_type'] in {"MatMul", "MatMulInteger"}
        else:
            # The node does not belong to any MatMul group: store it on its own
            if current_group and matmul_found:
                grouped_order.append(current_group)
            elif current_group:
                for n in current_group:
                    grouped_order.append([n])

            # Store this node as its own group
            current_group = []
            grouped_order.append([node])
            visited_nodes.add(node_name)
            related_group = None
            matmul_found = False

    # Handle the final group
    if current_group and matmul_found:
        grouped_order.append(current_group)
    elif current_group:
        # Leftover nodes without a MatMul are stored individually, as above
        for n in current_group:
            grouped_order.append([n])

    return grouped_order


# Count the number of MatMul/MatMulInteger operations and total operations
matmulinteger_count = 0
total_op_count = 0
visited_nodes_for_input = set()  # Visited set shared by the backward (input-side) searches

# Traverse the model nodes, find MatMul/MatMulInteger, and print preceding and following operations
for node in model.graph.node:
    total_op_count += 1
    if node.op_type == "MatMul" or node.op_type == "MatMulInteger":
        matmulinteger_count += 1
        print(f"\nMatMul/MatMulInteger Node: {node.name}")

        related_nodes = set()
        related_nodes.add(node.name)

        # Find previous operations for each input path, traversing fully along each path
        all_previous_ops = []
        total_elementwise_count = 0
        path_counter = 1
        elementwise_visited = set()      # Track visited elementwise nodes to avoid counting duplicates
        previous_op_names = {node.name}  # Initialize with the MatMul node's own name

        for input_name in node.input:
            previous_paths, prev_elementwise_count = find_previous_ops(
                input_name, visited_nodes_for_input, elementwise_visited, set(), related_nodes)
            all_previous_ops.extend(previous_paths)
            total_elementwise_count += prev_elementwise_count

            # Record node names from previous paths
            for visited_node in visited_nodes_for_input:
                previous_op_names.add(visited_node)

        # Filter out empty and duplicate paths
        filtered_paths = filter_duplicate_and_empty_paths(all_previous_ops)

        # Print each path separately after filtering
        for i, path in enumerate(filtered_paths, 1):
            print(f"    Previous Path {i}: {path}")
        print(f"------------------------------------------------------------------------------------------------------------------")

        # Find following operations, traversing fully along each path until stopping
        all_following_ops = []
        visited_nodes = set()                 # Reset visited nodes for the forward search
        following_elementwise_visited = set()  # Track visited elementwise nodes in following ops

        for output_name in node.output:
            following_paths, next_elementwise_count = find_following_ops(
                output_name, visited_nodes, following_elementwise_visited, previous_op_names, related_nodes)
            all_following_ops.extend(following_paths)
            total_elementwise_count += next_elementwise_count

        # Filter out empty and duplicate following paths
        filtered_following_paths = filter_duplicate_and_empty_paths(all_following_ops)

        # Print each following path separately
        for i, path in enumerate(filtered_following_paths, 1):
            print(f"    Following Path {i}: {path}")

        # Print elementwise operation count
        print(f"------------------------------------------------------------------------------------------------------------------")
        print(f"    Number of elementwise operations: {total_elementwise_count}")
        print(f"\n")

        # Store the collected node names in the dictionary
        matmul_related_nodes[node.name] = related_nodes

# Print the final statistics
print(f"\nTotal number of MatMul/MatMulInteger operations: {matmulinteger_count}")
print(f"Total number of operations in the model: {total_op_count}")

execution_order = get_execution_order(onnx_model_path)
grouped_order = group_execution_order_by_matmul(execution_order, matmul_related_nodes)

with open("grouped_execution_order.txt", "w") as f:
    for i, group in enumerate(grouped_order, 1):
        f.write(f"\nGroup {i}:\n")
        for node in group:
            if isinstance(node, dict):  # Ensure the node is a dict
                f.write(f"  Order: {node['order']}, OP Type: {node['op_type']}, Name: {node['name']}\n")
            else:
                f.write(f"  Unexpected node format: {node}\n")

# Save the results to a file
with open("matmul_related_nodes.txt", "w") as f:
    for matmul_node, related_nodes in matmul_related_nodes.items():
        f.write(f"MatMul/MatMulInteger Node: {matmul_node}\n")
        f.write(f"Related Nodes: {', '.join(related_nodes)}\n")
        f.write("---------------------------------------------------\n")