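"""Analyze an ONNX model around its MatMul / MatMulInteger nodes.

For every MatMul or MatMulInteger node, the script traces the preceding and
following operator paths (stopping at ops such as Reshape, Transpose, or
another MatMul), counts the elementwise operations found along the way, and
records the related node names. It then derives the node execution order and
groups it so that each MatMul and its surrounding nodes form one group.
Results are written to grouped_execution_order.txt and
matmul_related_nodes.txt.
"""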
import onnx
from collections import defaultdict
import onnxruntime as ort
import numpy as np
import json

# Load the ONNX model
# model = onnx.load("bidaf-11-int8.onnx")
# model = onnx.load("/home/pei1005/former/gtransformer.onnx")
model = onnx.load("imdb_sentiment_model.onnx")
onnx_model_path = "imdb_sentiment_model.onnx"

# Create a dictionary to track which node produces each output tensor
output_to_node_map = {}
input_to_node_map = {}
for node in model.graph.node:
    for output in node.output:
        output_to_node_map[output] = node
    for input_name in node.input:
        if input_name not in input_to_node_map:
            input_to_node_map[input_name] = []
        input_to_node_map[input_name].append(node)

# Define the stopping operations, including MatMulInteger
stopping_ops = {"Transpose", "Unsqueeze", "Reshape", "Squeeze", "CategoryMapper", "MatMulInteger", "MatMul", "Shape"}

# Define the types of elementwise operations
elementwise_ops = {
    "Add", "Sub", "Mul", "Div",            # Basic arithmetic
    "Relu", "Sigmoid", "Tanh", "LeakyRelu", "Softmax", "HardSigmoid",  # Activation functions
    "Greater", "Less", "Equal", "And", "Or", "Not", "Xor",             # Logical operations
    "Abs", "Neg", "Reciprocal", "Sqrt", "Exp", "Log", "Pow",           # Unary math functions
    "Sin", "Cos", "Tan", "Ceil", "Floor", "Clip", "Round", "Sign",     # More math functions
    "Max", "Min", "Mean", "Sum", "Square", "Erf"                       # Aggregation and other ops
}

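# Maps each MatMul/MatMulInteger node name to the set of related node names found around it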
matmul_related_nodes = {}

# Function to find previous operations along each input path until a stopping operation is encountered
def find_previous_ops(input_name, visited_nodes, elementwise_visited, skip_node_names, related_nodes):
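    """Walk backwards from `input_name` until a stopping op is reached.

    Returns (all_paths, elementwise_count): the op-type paths found on the
    producing side of `input_name` and the number of distinct elementwise
    ops encountered along them. `visited_nodes`, `elementwise_visited`, and
    `related_nodes` are updated in place; nodes whose names appear in
    `skip_node_names` are not traversed.
    """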
    all_paths = []
    current_path = []
    elementwise_count = 0

    if input_name in output_to_node_map:
        node = output_to_node_map[input_name]
        if node.name in visited_nodes or node.name in skip_node_names:
            return [], elementwise_count  # Skip already visited nodes or nodes to be avoided
        visited_nodes.add(node.name)

        # Stop if this node is a stopping operation
        if node.op_type in stopping_ops:
            return [], elementwise_count  # Do not add stopping op to path

        current_path.append(node.op_type)
        related_nodes.add(node.name)
        
        # Only count elementwise operations if they haven't been visited in this path yet
        if node.op_type in elementwise_ops and node.name not in elementwise_visited:
            elementwise_visited.add(node.name)
            elementwise_count += 1

        # Recursively search for each input path, fully traversing each path until stopping
        for node_input in node.input:
            if node_input in output_to_node_map:
                sub_paths, sub_elementwise_count = find_previous_ops(node_input, visited_nodes, elementwise_visited, skip_node_names, related_nodes)
                for path in sub_paths:
                    all_paths.append(current_path + path)
                elementwise_count += sub_elementwise_count

    if not all_paths:  # If no sub-paths were found, the current path itself is a valid path
        all_paths.append(current_path)

    return all_paths, elementwise_count

# Function to remove duplicate paths and filter out empty paths
def filter_duplicate_and_empty_paths(paths):
    filtered_paths = []
    seen_paths = set()

    for path in paths:
        path_tuple = tuple(path)  # Convert list to tuple so it can be added to set
        if path and path_tuple not in seen_paths:  # Check if path is not empty and not seen
            seen_paths.add(path_tuple)
            filtered_paths.append(path)

    return filtered_paths

# Function to find following operations until a stopping operation is encountered
# Stops when node has multiple output paths (even if only one output node)
# Each node's inputs are also traversed to find their respective input paths until stopping, avoiding specific nodes
def find_following_ops(output_name, visited_nodes, elementwise_visited, skip_node_names, related_nodes):
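    """Walk forwards from `output_name` until a stopping op is reached.

    For every consumer node, the node's other inputs are also traced back
    to their stopping ops and printed. Returns (all_following_paths,
    elementwise_count); the shared tracking sets are updated in place.
    """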
    all_following_paths = []
    current_following_path = []
    elementwise_count = 0

    # Find the next node that uses this output as input
    if output_name in input_to_node_map:
        for node in input_to_node_map[output_name]:
            if node.name in visited_nodes:
                continue  # Skip already visited nodes to avoid duplication
            visited_nodes.add(node.name)

            # Stop if this node is a stopping operation (including MatMulInteger)
            if node.op_type in stopping_ops:
                return [], elementwise_count  # Do not add stopping op to path

            current_following_path.append(node.op_type)

            related_nodes.add(node.name)
            
            # Only count elementwise operations if they haven't been visited yet
            if node.op_type in elementwise_ops and node.name not in elementwise_visited:
                elementwise_visited.add(node.name)
                elementwise_count += 1

            # Traverse the current node's input paths until stopping, avoiding specific previous nodes
            for node_input in node.input:
                # Skip the node we just came from in the following path
                skip_node_names.add(node.name)  # Avoid revisiting the current node
                input_paths, input_elementwise_count = find_previous_ops(node_input, visited_nodes, elementwise_visited, skip_node_names, related_nodes)
                filtered_input_paths = filter_duplicate_and_empty_paths(input_paths)

                # Count this input's elementwise ops once, then print any valid input paths
                elementwise_count += input_elementwise_count
                for j, input_path in enumerate(filtered_input_paths, 1):
                    print(f"  Input Path {j} for Node ({node.op_type}): {input_path}")

            # Check if the node has multiple output paths (even if only one output node)
            # total_output_connections = sum(len(input_to_node_map.get(output, [])) for output in node.output)
            # if total_output_connections > 1:
            #     return [current_following_path], elementwise_count
            
            # Recursively search for the next operations
            for next_output in node.output:
                sub_paths, sub_elementwise_count = find_following_ops(next_output, visited_nodes, elementwise_visited, skip_node_names, related_nodes)
                for path in sub_paths:
                    all_following_paths.append(current_following_path + path)
                elementwise_count += sub_elementwise_count

    if not all_following_paths:  # If no sub-paths were found, the current path itself is a valid path
        all_following_paths.append(current_following_path)

    return all_following_paths, elementwise_count

def get_execution_order(onnx_model_path):
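    """Return a list of {'order', 'op_type', 'name'} dicts for the graph's nodes, in the order they appear in the graph."""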
    # Load the ONNX model
    model = onnx.load(onnx_model_path)
    graph = model.graph

    # Collect all nodes in the graph
    nodes = graph.node
    execution_order = []

    # Iterate over the nodes and record them in execution order
    for idx, node in enumerate(nodes):
        node_info = {
            'order': idx,
            'op_type': node.op_type,
            'name': node.name if node.name else f"Node_{idx}"
        }
        execution_order.append(node_info)

    return execution_order

# def get_execution_order(onnx_model_path):
#     # Create the inference session options
#     options = ort.SessionOptions()

#     # Enable profiling and specify the output file
#     options.enable_profiling = True

#     # Create the ONNX Runtime InferenceSession
#     session = ort.InferenceSession(onnx_model_path, options, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
    
#     # Build the execution-order list
#     execution_order = []
#     node_name_to_order = {}  # Records the order in which nodes execute

#     # Construct the input data
#     input_name = session.get_inputs()[0].name
#     input_shape = session.get_inputs()[0].shape

#     # Replace dynamic (None) dimensions with a concrete value, e.g. 1
#     input_shape = [dim if isinstance(dim, int) else 1 for dim in input_shape]

#     # Generate dummy int64 data matching the model's input requirements
#     input_data = np.random.randint(0, 100, size=input_shape).astype(np.int64)

#     # Run model inference
#     outputs = session.run(None, {input_name: input_data})

#     # Retrieve the profiling output file
#     profile_file = session.end_profiling()
    
#     with open(profile_file, "r") as f:
#         profile_data = json.load(f)
#         count = 0
#         for i, item in enumerate(profile_data):
#             node_name = item.get("name", f"Node_{i}")
#             op_type = item.get("args", {}).get("op_name", "Unknown")
#             args = item.get("args", {})
#             if item.get("cat") == "Node" and "thread_scheduling_stats" in item.get("args", {}):
#                 execution_order.append({
#                     "order": count,
#                     "name": node_name,
#                     "op_type": op_type
#                 })
#                 count+=1

#     return execution_order

def group_execution_order_by_matmul(execution_order, matmul_related_nodes):
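    """Group the execution order so that each MatMul/MatMulInteger node and
    its related nodes form one group; nodes outside any group are emitted
    as single-node groups.
    """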
    # Store the grouped results
    grouped_order = []
    current_group = []
    visited_nodes = set()  # Nodes that have already been processed
    related_group = None
    matmul_found = False  # Whether the current group contains a MatMul

    # Walk through the nodes in their original execution order
    for node in execution_order:
        node_name = node['name']

        # Skip nodes that have already been processed
        if node_name in visited_nodes:
            continue

        # Find which MatMul-related group, if any, this node belongs to
        new_related_group = None
        for matmul_node, related_nodes in matmul_related_nodes.items():
            if node_name in related_nodes:
                new_related_group = related_nodes
                break

        # The node belongs to a MatMul-related group: either the current group or the start of a new one
        if new_related_group:
            if related_group is None or new_related_group == related_group:
                related_group = new_related_group  # 更新相關群組
                current_group.append(node)
                visited_nodes.add(node_name)

                # Check whether a MatMul or MatMulInteger has been found
                if node['op_type'] in {"MatMul", "MatMulInteger"}:
                    matmul_found = True
            else:
                # A new related group begins: close out the current group and reset
                if matmul_found:
                    grouped_order.append(current_group)
                else:
                    for n in current_group:
                        grouped_order.append([n])

                # Start a new group
                current_group = [node]
                visited_nodes.add(node_name)
                related_group = new_related_group
                matmul_found = node['op_type'] in {"MatMul", "MatMulInteger"}
        else:
            # The node does not belong to any MatMul group; flush the current group first
            if current_group and matmul_found:
                grouped_order.append(current_group)
            elif current_group:
                for n in current_group:
                    grouped_order.append([n])

            # Store this node as its own group
            current_group = []
            grouped_order.append([node])
            visited_nodes.add(node_name)
            related_group = None
            matmul_found = False

    # Flush the last group: keep it intact if a MatMul was found, otherwise emit its nodes individually
    if current_group and matmul_found:
        grouped_order.append(current_group)
    elif current_group:
        for n in current_group:
            grouped_order.append([n])

    return grouped_order

# Count MatMul/MatMulInteger operations and total operations
matmulinteger_count = 0
total_op_count = 0
visited_nodes_for_input = set()  # Visited set shared across all input-path traversals
# Traverse the model nodes, find MatMul/MatMulInteger nodes, and print their preceding and following operations
for node in model.graph.node:
    total_op_count += 1
    if node.op_type in {"MatMul", "MatMulInteger"}:
        matmulinteger_count += 1
        print(f"\n{node.op_type} Node: {node.name}")

        related_nodes = set()
        related_nodes.add(node.name)
        
        # Find previous operations for each input path, traversing fully along each path
        all_previous_ops = []
        total_elementwise_count = 0
        path_counter = 1
        elementwise_visited = set()  # Track visited elementwise nodes to avoid counting duplicates
        previous_op_names = {node.name}  # Start with this MatMul/MatMulInteger node's own name
        for input_name in node.input:
            
            previous_paths, prev_elementwise_count = find_previous_ops(input_name, visited_nodes_for_input, elementwise_visited, set(), related_nodes)
            all_previous_ops.extend(previous_paths)
            total_elementwise_count += prev_elementwise_count
            # Record node names from previous paths
            for visited_node in visited_nodes_for_input:
                previous_op_names.add(visited_node)

        # Filter out empty and duplicate paths
        filtered_paths = filter_duplicate_and_empty_paths(all_previous_ops)
        
        # Print each path separately after filtering
        for i, path in enumerate(filtered_paths, 1):
            print(f"  Previous Path {i}: {path}")
        print(f"------------------------------------------------------------------------------------------------------------------")
        # Find following operations, traversing fully along each path until stopping
        all_following_ops = []
        visited_nodes = set()  # Reset visited nodes for forward search
        following_elementwise_visited = set()  # Track visited elementwise nodes in following ops
        for output_name in node.output:
            following_paths, next_elementwise_count = find_following_ops(output_name, visited_nodes, following_elementwise_visited, previous_op_names, related_nodes)
            all_following_ops.extend(following_paths)
            total_elementwise_count += next_elementwise_count

        # Filter out empty and duplicate following paths
        filtered_following_paths = filter_duplicate_and_empty_paths(all_following_ops)
        
        # Print each following path separately
        for i, path in enumerate(filtered_following_paths, 1):
            print(f"  Following Path {i}: {path}")
        
        # Print elementwise operation count
        print(f"------------------------------------------------------------------------------------------------------------------")
        print(f"  Number of elementwise operations: {total_elementwise_count}")
        print(f"\n")
        
        # Store the collected node names in the dictionary
        matmul_related_nodes[node.name] = related_nodes
        
# Print the final statistics
print(f"\nTotal number of MatMulInteger operations: {matmulinteger_count}")
print(f"Total number of operations in the model: {total_op_count}")

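# Derive the node execution order and group it around the MatMul-related node sets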
execution_order = get_execution_order(onnx_model_path)
grouped_order = group_execution_order_by_matmul(execution_order, matmul_related_nodes)

with open("grouped_execution_order.txt", "w") as f:
    for i, group in enumerate(grouped_order, 1):
        f.write(f"\nGroup {i}:\n")
        for node in group:
            if isinstance(node, dict):  # Make sure node is a dictionary
                f.write(f"  Order: {node['order']}, OP Type: {node['op_type']}, Name: {node['name']}\n")
            else:
                f.write(f"  Unexpected node format: {node}\n")
            
# Save the results to a file
with open("matmul_related_nodes.txt", "w") as f:
    for matmul_node, related_nodes in matmul_related_nodes.items():
        f.write(f"MatMul/MatMulInteger Node: {matmul_node}\n")
        f.write(f"Related Nodes: {', '.join(related_nodes)}\n")
        f.write("---------------------------------------------------\n")