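# Untitled
# (Paste-site metadata, kept as comments so the file parses:
#  avatar / author: unknown, syntax: c_cpp, posted 25 days ago, size 51 kB, 5, Indexable)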
import onnx
import numpy as np
import torch
import cutlass
from cutlass.epilogue import relu
from cutlass import Tensor as FakeTensor
from cutlass.utils.profiler import CUDAEventProfiler
from transformers import BertTokenizer
import time
from torch.cuda import nvtx
# BERT WordPiece tokenizer, loaded via from_pretrained('bert-base-uncased').
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Single GEMM plan shared by every MatMul node: fp32 operands and fp32
# accumulation, row-major layout, built for compute capability 80.
# ONNXCompiler swaps in a new `epilogue_visitor` before each plan.run call.
plan = cutlass.op.Gemm(
    cc=80,
    layout=cutlass.LayoutType.RowMajor,
    element=torch.float32,
    element_accumulator=torch.float32,
)

def example_epilogue_matmul(accum):
    # Epilogue for a MatMul with no following Add: the output is just the
    # GEMM accumulator (A @ B), i.e. D = accum.
    # NOTE(review): this function is handed to cutlass.epilogue.trace, which
    # appears to parse its Python source. Keep the output variable named `D`
    # (it matches the "D" key in the example/visitor tensor dicts) and avoid
    # adding a docstring or extra statements the tracer may reject — confirm
    # against the CUTLASS EVT documentation before changing this body.
    D = accum
    return D

def example_epilogue(accum, C):
    # Epilogue fusing MatMul with one following Add: D = (A @ B) + C, where C
    # is the Add node's other operand (the bias tensor built by ONNXCompiler).
    # NOTE(review): traced from source by cutlass.epilogue.trace — argument
    # names must match the example tensor keys ("accum", "C") and the output
    # must stay named `D`; no docstring, to keep the traced body minimal.
    D = accum + C
    return D

def example_epilogue2(accum, C, F):
    # Epilogue fusing MatMul with two chained Adds: D = (A @ B) + C + F, where
    # C is the first Add's other operand and F the second Add's other operand.
    # NOTE(review): traced from source by cutlass.epilogue.trace — argument
    # names must match the example tensor keys ("accum", "C", "F") and the
    # output must stay named `D`; no docstring, to keep the traced body minimal.
    D = accum + C + F
    return D

class ONNXCompiler:
    def __init__(self, model_path):
        """Load an ONNX model and set up the state used while executing it.

        Args:
            model_path: Path (or file-like object) accepted by ``onnx.load``.

        Side effects:
            Prints the name and shape of every graph input. Symbolic
            dimensions (``dim_param``) are printed by name rather than as 0;
            dimensions with neither a value nor a name are printed as ``?``.
        """
        # Import the submodule explicitly; `import onnx` alone is not
        # guaranteed to expose `onnx.numpy_helper` on every onnx version.
        from onnx import numpy_helper

        # Load the ONNX model.
        self.model = onnx.load(model_path)
        self.graph = self.model.graph
        self.nodes = self.graph.node  # all graph nodes, scanned for Add nodes that can be fused
        self.processed_nodes = set()  # names of Add nodes already folded into a fused MatMul
        self.matmul_add_fuse_node = set()  # names of MatMul nodes that were fused with Add(s)

        # Accumulated kernel timings in seconds. _execute_node updates these
        # with `+=`, so they must exist before the first node is executed.
        self.total_execution_time = 0.0
        self.total_matmul_add_time = 0.0

        for input_tensor in self.graph.input:
            dims = [
                dim.dim_value if dim.HasField("dim_value") else (dim.dim_param or "?")
                for dim in input_tensor.type.tensor_type.shape.dim
            ]
            print(f"Model Input Name: {input_tensor.name}, Shape: {dims}")

        # Constant tensors (weights, biases, ...) stored as graph initializers,
        # converted from TensorProto to NumPy arrays and keyed by tensor name.
        self.initializers = {
            tensor.name: numpy_helper.to_array(tensor)
            for tensor in self.graph.initializer
        }

        # Working tensor store: seeded with the initializers; node outputs are
        # written into it as the graph executes.
        self.tensors = self.initializers.copy()

    def _execute_node(self, node):
        op_type = node.op_type
        inputs = []

        # 收集输入张量
        for input_name in node.input:
            if input_name in self.tensors:
                inputs.append(self.tensors[input_name])
                # if torch.isnan(torch.tensor(self.tensors[input_name])).any():
                #     print("NAN",input_name)
            else:
                print(f"Warning: Missing input tensor '{input_name}' for node '{node.name}'.")
                return None
            
        # for input in inputs:
        #     if torch.isnan(torch.tensor(input)).any():
        #         print("Inputs contain NaN values!")
        #         print(input)

        
        if op_type == "MatMul":

            # 提取 MatMul 輸入
            A = inputs[0]
            B = inputs[1]
            original_shape_A = A.shape
            original_shape_B = B.shape
            # print("shape: ")
            # print(original_shape_A, original_shape_B)
            # 處理多維情況(批量矩陣乘法)
            batch_dims = None
            if A.ndim > 2 or B.ndim > 2:
                batch_dims = np.prod(original_shape_A[:-2])
                A = A.reshape(batch_dims, original_shape_A[-2], original_shape_A[-1])
                B = B.reshape(batch_dims, original_shape_B[-2], original_shape_B[-1])

            M, K = A.shape[-2], A.shape[-1]
            K, N = B.shape[-2], B.shape[-1]

            print(f"Input size: ({batch_dims}, {M}, {K}, {N})")
    
            tensor_A = torch.tensor(A, dtype=torch.float32, device="cuda").contiguous()
            tensor_B = torch.tensor(B, dtype=torch.float32, device="cuda").contiguous()

            # print("Matmul input: ", tensor_A, tensor_B)

            # 查找與 MatMul 輸出相關的 Add 節點
            matmul_output_name = node.output[0]
            related_add_nodes = [
                n for n in self.nodes if n.op_type == "Add" and matmul_output_name in n.input
            ]

            # related_add_nodes = None # separate->t2.txt

            if not related_add_nodes:
                # 情況 3:MatMul 後無 Add
                print(f"No Add node related to MatMul output: {node.name}. Executing regular MatMul.")

                if batch_dims:
                    tensor_C = torch.zeros_like(torch.empty(size=(batch_dims, M, N), dtype=torch.float32, device="cuda"))
                else:
                    tensor_C = torch.zeros_like(torch.empty(size=(M, N), dtype=torch.float32, device="cuda"))
                
                tensor_D = torch.zeros_like(tensor_C)
                
                examples_tensors = {
                    "accum": FakeTensor(element=torch.float32, shape=(tensor_C.shape), layout_tag=cutlass.LayoutType.RowMajor),
                    "C": tensor_C,
                    "D": tensor_D
                }

                epilogue_visitor = cutlass.epilogue.trace(example_epilogue_matmul, examples_tensors)
                epilogue_visitor.epilogue_stages = 1
                
                visitor_args = {
                    "C": tensor_C,
                    "D": tensor_D
                }

                plan.epilogue_visitor = epilogue_visitor

                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)

                torch.cuda.synchronize()
                # nvtx.range_push(node.op_type + ", " + node.name)
                start_event.record()

                plan.run(
                    tensor_A, tensor_B, tensor_C, tensor_D,
                    visitor_args=visitor_args, print_module=False
                )

                end_event.record()
                # nvtx.range_pop()
                torch.cuda.synchronize()

                # 記錄執行時間
                execution_time = (start_event.elapsed_time(end_event)) / 1000
                self.total_execution_time += execution_time
                self.total_matmul_add_time += execution_time
                print(f"MatMul Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

                result = tensor_D.cpu().numpy()

                if len(original_shape_A) > 2:
                    result = result.reshape(*original_shape_A[:-2], M, N)
                
                # print("Matmul output: ", result)

                return result

            # 情況 1 或 2:MatMul 輸出為 Add,檢查 Add 的輸出是否也為 Add
            self.matmul_add_fuse_node.add(node.name)

            add_node = related_add_nodes[0]  # 第一層 Add
            add_input_name = [name for name in add_node.input if name != matmul_output_name][0]
            add_input = self.tensors.get(add_input_name)

            if add_input is not None:
                tensor_add_input = torch.tensor(add_input, dtype=torch.float32, device="cuda").contiguous()
                tensor_add_input = tensor_add_input.repeat((M, 1))

                # 檢查第一層 Add 的輸出是否也作為另一個 Add 的輸入
                add_output_name = add_node.output[0]
                next_add_nodes = [
                    n for n in self.nodes if n.op_type == "Add" and add_output_name in n.input
                ]

                if next_add_nodes:
                    # 情況 2:兩層 Add 融合
                    next_add_node = next_add_nodes[0]
                    next_add_input_name = [
                        name for name in next_add_node.input if name != add_output_name
                    ][0]
                    next_add_input = self.tensors.get(next_add_input_name)

                    if next_add_input is not None:
                        print(f"Fusing MatMul with 2Add for node: {node.name}")
                        tensor_next_add_input = torch.tensor(next_add_input, dtype=torch.float32, device="cuda").contiguous()
                        print(tensor_add_input.shape, ", ", tensor_next_add_input.shape)

                        # print("fuse 2add input: ", tensor_add_input, tensor_next_add_input)

                        tensor_D = torch.zeros_like(tensor_add_input)

                        examples_tensors = {
                            "accum": FakeTensor(element=torch.float32, shape=(M, N), layout_tag=cutlass.LayoutType.RowMajor),
                            "C": tensor_add_input,
                            "F": tensor_next_add_input,
                            "D": tensor_D
                        }

                        epilogue_visitor = cutlass.epilogue.trace(example_epilogue2, examples_tensors)
                        epilogue_visitor.epilogue_stages = 1
                        
                        visitor_args = {
                            "C": tensor_add_input, "F": tensor_next_add_input, "D": tensor_D
                        }

                        plan.epilogue_visitor = epilogue_visitor

                        start_event = torch.cuda.Event(enable_timing=True)
                        end_event = torch.cuda.Event(enable_timing=True)

                        torch.cuda.synchronize()
                        # nvtx.range_push(node.op_type + ", " + node.name)
                        start_event.record()

                        plan.run(
                            tensor_A, tensor_B, tensor_add_input, tensor_D,
                            visitor_args=visitor_args, print_module=False
                        )
                        # tensor_D = torch.matmul(tensor_A, tensor_B) + tensor_add_input + tensor_next_add_input

                        end_event.record()
                        # nvtx.range_pop()
                        torch.cuda.synchronize()

                        # 記錄執行時間
                        execution_time = (start_event.elapsed_time(end_event)) / 1000
                        self.total_execution_time += execution_time
                        self.total_matmul_add_time += execution_time
                        print(f"MatMul Fuse node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

                        result = tensor_D.cpu().numpy()

                        if len(original_shape_A) > 2:
                            result = result.reshape(*original_shape_A[:-2], M, N)
                        self.processed_nodes.add(next_add_node.name)
                        self.processed_nodes.add(add_node.name)
                        self.tensors[add_node.output[0]] = result
                        self.tensors[next_add_node.output[0]] = result
                        # print("matmul_2add_fuse_node output: ", next_add_node.name, result)

                        return result

                # 情況 1:一層 Add 融合
                
                print(f"Fusing MatMul with Add for node: {node.name}")
                print(tensor_add_input.shape)
                # print("fuse add input: ", tensor_add_input)

                tensor_D = torch.zeros_like(tensor_add_input)

                examples_tensors = {
                    "accum": FakeTensor(element=torch.float32, shape=(M, N), layout_tag=cutlass.LayoutType.RowMajor),
                    "C": tensor_add_input,
                    "D": tensor_D
                }

                epilogue_visitor = cutlass.epilogue.trace(example_epilogue, examples_tensors)
                epilogue_visitor.epilogue_stages = 1
                
                visitor_args = {
                    "C": tensor_add_input, "D": tensor_D
                }

                plan.epilogue_visitor = epilogue_visitor

                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)

                torch.cuda.synchronize()
                # nvtx.range_push(node.op_type + ", " + node.name)
                start_event.record()

                plan.run(
                    tensor_A, tensor_B, tensor_add_input, tensor_D,
                    visitor_args=visitor_args, print_module=False
                )
                # tensor_D = torch.matmul(tensor_A, tensor_B) + tensor_add_input

                end_event.record()
                # nvtx.range_pop()
                torch.cuda.synchronize()

                # 記錄執行時間
                execution_time = (start_event.elapsed_time(end_event)) / 1000
                self.total_execution_time += execution_time
                self.total_matmul_add_time += execution_time
                print(f"MatMul Fuse node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

                result = tensor_D.cpu().numpy()

                if len(original_shape_A) > 2:
                    result = result.reshape(*original_shape_A[:-2], M, N)
                self.processed_nodes.add(add_node.name)
                self.tensors[add_node.output[0]] = result
                # print("matmul_add_fuse_node output: ", add_node.name, result)
                return result
                
        # 運算邏輯
        elif op_type == "Slice":
            # 確保輸入數據轉換為 PyTorch 張量
            data = torch.tensor(inputs[0], dtype=torch.float32, device="cuda")
            starts = torch.tensor(inputs[1])
            ends = torch.tensor(inputs[2])
            axes = torch.tensor(inputs[3])

            # 默認步長為 1
            steps = torch.ones_like(starts, dtype=torch.int64)

            # 處理負軸索引
            axes = torch.where(axes < 0, axes + data.dim(), axes)

            # 初始化切片索引
            slices = [slice(None)] * data.dim()

            # 計時器事件
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()

            # 遍歷每個 axis 並應用 slicing
            for start, end, axis, step in zip(starts, ends, axes, steps):
                axis = int(axis)
                dim = data.size(axis)

                # 處理負數的 starts 和 ends
                start = start.item() + dim if start.item() < 0 else start.item()
                end = end.item() + dim if end.item() < 0 else end.item()

                # clamp 範圍
                start = max(0, min(start, dim))
                end = max(0, min(end, dim)) if step > 0 else max(-1, min(end, dim - 1))

                # 設置切片
                slices[axis] = slice(start, end, step.item())

            # 應用切片
            result = data[tuple(slices)]

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()

            # 記錄執行時間
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

            return result.cpu().numpy()  # 返回 NumPy 格式的結果
        
        elif op_type == "Add":
            A = torch.tensor(inputs[0], dtype=torch.float32, device="cuda").contiguous()
            B = torch.tensor(inputs[1], dtype=torch.float32, device="cuda").contiguous()

            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            result = (A + B).cpu().numpy()

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()

            # 記錄執行時間
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            self.total_matmul_add_time += execution_time
            print(f"Add Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
            
            return result

        elif op_type == "Sub":
            A = torch.tensor(inputs[0], dtype=torch.float32, device="cuda").contiguous()
            B = torch.tensor(inputs[1], dtype=torch.float32, device="cuda").contiguous()

            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            result = (A - B).cpu().numpy()

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()

            # 記錄執行時間
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

            return result
        elif op_type == "Mul":
            # 確保形狀兼容,否則調整形狀
            try:
                A = torch.tensor(inputs[0], dtype=torch.float32, device="cuda").contiguous()
                B = torch.tensor(inputs[1], dtype=torch.float32, device="cuda").contiguous()
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)

                torch.cuda.synchronize()
                # nvtx.range_push(node.op_type + ", " + node.name)
                start_event.record()
                
                result = (A * B).cpu().numpy()

                end_event.record()
                # nvtx.range_pop()
                torch.cuda.synchronize()

                # 記錄執行時間
                execution_time = (start_event.elapsed_time(end_event)) / 1000
                self.total_execution_time += execution_time
                print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

            except ValueError:
                # 嘗試廣播形狀
                print("Broadcasting shapes for Mul operation...")
                inputs[1] = np.broadcast_to(inputs[1], inputs[0].shape)
                A = torch.tensor(inputs[0], dtype=torch.float32, device="cuda").contiguous()
                B = torch.tensor(inputs[1], dtype=torch.float32, device="cuda").contiguous()
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)

                torch.cuda.synchronize()
                # nvtx.range_push(node.op_type + ", " + node.name)
                start_event.record()
                
                result = (A * B).cpu().numpy()

                end_event.record()
                # nvtx.range_pop()
                torch.cuda.synchronize()

                # 記錄執行時間
                execution_time = (start_event.elapsed_time(end_event)) / 1000
                self.total_execution_time += execution_time
                print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
            return result

        elif op_type == "Div":
            A = torch.tensor(inputs[0], dtype=torch.float32, device="cuda").contiguous()
            B = torch.tensor(inputs[1], dtype=torch.float32, device="cuda").contiguous()
            
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            result = (A / B).cpu().numpy()

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()

            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

            return result
        elif op_type == "Sqrt":
            A = torch.tensor(inputs[0], dtype=torch.float32, device="cuda").contiguous()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            result = torch.sqrt(A).cpu().numpy()

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()
            
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
            
            return result
        elif op_type == "Reciprocal":
            A = torch.tensor(inputs[0], dtype=torch.float32, device="cuda").contiguous()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            result = (1 / A).cpu().numpy()

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()
            
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

            return result
        
        elif op_type == "Shape":
            # 確保輸入數據轉換為 PyTorch 張量
            input_data = torch.tensor(inputs[0])
            
            # 計時器事件
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()

            # 提取輸入張量的形狀
            shape = torch.tensor(list(input_data.shape), dtype=torch.int64)

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()

            # 計算執行時間
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

            return shape.numpy()  # 返回 NumPy 格式的結果
        
        elif op_type == "Transpose":
            A = torch.tensor(inputs[0], dtype=torch.float32, device="cuda")
            perm = [attr.ints for attr in node.attribute if attr.name == "perm"]
            if not perm:
                perm = list(range(A.ndim))[::-1]
            else:
                perm = perm[0]
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            result = A.permute(*perm).cpu().numpy()

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()
            
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
            
            return result

        elif op_type == "Reshape":
            # 確保輸入數據轉換為 PyTorch 張量
            input_data = torch.tensor(inputs[0], device="cuda")
            # 提取目標 shape
            if len(inputs) > 1:
                shape = inputs[1].astype(np.int64).tolist()
            else:
                shape = list(node.attribute[0].ints)
            
            # 計時器事件
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()

            # 使用 PyTorch 的 reshape 函數
            result = input_data.reshape(shape)

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()

            # 計算執行時間
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
            return result.cpu().numpy()  # 返回 NumPy 格式的結果

        elif op_type == "Concat":
            # 從屬性或輸入中獲取 axis
            axis = next((attr.i for attr in node.attribute if attr.name == "axis"), 0)
            
            # 確保輸入數據為 PyTorch 張量
            input_tensors = [
                torch.tensor(inp, device="cuda")
                for inp in inputs
            ]
            
            # 計時器事件
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            # 使用 PyTorch 的 cat 函數進行拼接
            result = torch.cat(input_tensors, dim=axis)

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()

            # 計算執行時間
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
            
            # 將結果轉換為 NumPy 格式並返回
            return result.cpu().numpy()
        
        elif op_type == "Squeeze":
            # 檢查並轉換輸入數據
            if inputs[0] is None:
                raise ValueError("Input data for Squeeze operation is None.")
            data = torch.tensor(inputs[0], device="cuda")
            if data is None:
                raise ValueError("Failed to convert input data to a tensor.")

            # 從屬性中提取 axes
            axes = [attr.i for attr in node.attribute if attr.name == "axes"]
            if not axes:
                axes = None  # 如果屬性中未提供 axes,設為 None
            
            # 計時器事件
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()

            # 執行 Squeeze 操作
            if axes is not None:
                # 處理負軸索引並過濾掉非 1 的軸
                axes = [axis if axis >= 0 else axis + data.dim() for axis in axes]
                valid_axes = [axis for axis in axes if data.size(axis) == 1]
                if not valid_axes:
                    raise ValueError("Cannot squeeze axes that do not have size equal to one.")
                for axis in sorted(valid_axes, reverse=True):  # 按降序處理以避免索引錯誤
                    data = torch.squeeze(data, dim=axis)
            else:
                data = torch.squeeze(data)  # 移除所有大小為 1 的維度

            result = data

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()

            # 計算執行時間
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

            return result.cpu().numpy()
        
        elif op_type == "Unsqueeze":
            # 提取輸入數據,將 NumPy 陣列轉為 PyTorch 張量
            data = torch.tensor(inputs[0], device="cuda")
            
            # 從屬性中提取 axes
            axes = [attr.i for attr in node.attribute if attr.name == "axes"]
            if not axes:
                raise ValueError("Attribute 'axes' is missing for Unsqueeze operation.")
            
            # 計時器事件
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            # 將數據執行 unsqueeze 操作
            for axis in sorted(axes):  # 確保按升序插入維度
                data = torch.unsqueeze(data, dim=axis)

            result = data

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()

            # 計算執行時間
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

            return result.cpu().numpy()
        
        elif op_type == "Identity":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            result = torch.tensor(inputs[0], dtype=torch.int64, device="cuda").cpu().numpy()

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()
            
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
            
            return result
        
        elif op_type == "Tanh":
            A = torch.tensor(inputs[0], dtype=torch.float32, device="cuda").contiguous()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            result = torch.tanh(A).cpu().numpy()

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()
            
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
            
            return result
        
        elif op_type == "Sigmoid":
            A = torch.tensor(inputs[0], dtype=torch.float32, device="cuda").contiguous()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            result = torch.sigmoid(A).cpu().numpy()

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()
            
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
            
            return result
        
        elif op_type == "Relu":
            A = torch.tensor(inputs[0], dtype=torch.float32, device="cuda").contiguous()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            result = torch.relu(A).cpu().numpy()

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()
            
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
            
            return result
        
        elif op_type == "Pow":
            A = torch.tensor(inputs[0], dtype=torch.float32, device="cuda").contiguous()
            B = torch.tensor(inputs[1], dtype=torch.float32, device="cuda").contiguous()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            result = torch.pow(A, B).cpu().numpy()

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()
            
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
            
            return result
        
        elif op_type == "Gather":
            # 將 NumPy 輸入數據轉換為 PyTorch 張量
            data = torch.tensor(inputs[0], dtype=torch.float32)
            indices = torch.tensor(inputs[1], dtype=torch.int64)
            
            # 獲取 axis,默認為 0
            axis = self.tensors[node.input[2]] if len(node.input) > 2 and node.input[2] in self.tensors else 0

            # 將 indices 的形狀調整為與 data 匹配
            if axis < 0:
                axis += data.ndim  # 處理負軸索引
            expanded_indices = indices.unsqueeze(-1).expand(*indices.shape, *data.shape[axis + 1 :])

            # 計時器事件
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()

            # 使用 PyTorch 的 gather 操作
            result = torch.gather(data, dim=axis, index=expanded_indices)

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()

            # 計算執行時間
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

            return result.cpu().numpy()
        
        elif op_type == "ReduceMean":
            A = torch.tensor(inputs[0], dtype=torch.float32, device="cuda")
            axes = inputs[1] if len(inputs) > 1 else None
            keepdims = inputs[2] if len(inputs) > 2 else True

            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            result = torch.mean(A, dim=axes, keepdim=keepdims).cpu().numpy()

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()
            
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
            return result
        
        elif op_type == "Cast":
            # 定義類型映射
            dtype_map = {
                1: torch.float32,  # FLOAT
                2: torch.uint8,    # UINT8
                3: torch.int8,     # INT8
                # 4: torch.uint16,   # UINT16
                5: torch.int16,    # INT16
                6: torch.int32,    # INT32
                7: torch.int64,    # INT64
                9: torch.bool,     # BOOL
                10: torch.float32, # FLOAT32
                11: torch.float64, # DOUBLE
                # 12: torch.uint32,  # UINT32
                # 13: torch.uint64,  # UINT64
            }
            
            # 提取目標數據類型
            target_type = node.attribute[0].i if node.attribute else None
            if target_type not in dtype_map:
                raise NotImplementedError(f"Unsupported target type {target_type} for Cast operation.")

            # 計時器事件
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            # 將輸入 NumPy 陣列轉換為 PyTorch 張量
            input_tensor = torch.tensor(inputs[0]) if isinstance(inputs[0], np.ndarray) else inputs[0]
            
            # 進行類型轉換
            result = input_tensor.to(dtype=dtype_map[target_type])

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()

            # 計算執行時間
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

            return result.cpu().numpy()
        
        elif op_type == "ConstantOfShape":
            # 提取輸入數據,將 NumPy 陣列轉為 PyTorch 張量
            shape = torch.tensor(inputs[0], dtype=torch.int64)  # 確保形狀為 int64 張量
            value = node.attribute[0].t if node.attribute else 0  # 默認為 0,若有指定則使用屬性中的值

            # 如果 value 存在,轉換為 PyTorch 張量
            constant_value = torch.tensor(np.frombuffer(value.raw_data, dtype=np.float32)) if value else torch.tensor(0.0, dtype=torch.float32)
            
            # 計時器事件
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()

            # 使用 PyTorch 的 `torch.full` 來生成指定形狀的常數張量
            result = torch.full(shape.tolist(), constant_value.item(), dtype=constant_value.dtype)

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()

            # 計算執行時間
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

            return result.cpu().numpy()
        
        elif op_type == "OneHot":
            # 提取輸入數據,將 NumPy 轉換為 PyTorch 張量
            indices = torch.tensor(inputs[0], dtype=torch.int64)  # 索引數據
            depth = int(inputs[1])  # one-hot 深度
            values = torch.tensor(inputs[2], dtype=torch.float32) if len(inputs) > 2 else torch.tensor([0, 1], dtype=torch.float32)  # one-hot 值
            axis = next((attr.i for attr in node.attribute if attr.name == "axis"), -1)  # 默認 -1(最後一個軸)

            # 計時器事件
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()

            # 驗證索引範圍,超出範圍的值設為全 off_value
            indices = torch.clamp(indices, min=0, max=depth - 1)

            # 建立 one-hot 編碼
            one_hot = torch.nn.functional.one_hot(indices, num_classes=depth).to(dtype=torch.float32)

            # 調整 one-hot 的軸位置
            if axis != -1:
                one_hot = one_hot.permute(*list(range(indices.ndim)), -1, axis)

            # 使用 values 替換默認的 0 和 1
            result = one_hot * (values[1] - values[0]) + values[0]

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()

            # 計算執行時間
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

            return result.cpu().numpy()
        
        elif op_type == "Softmax":
            A = torch.tensor(inputs[0], dtype=torch.float32, device="cuda")
            axis = node.attribute[0].i if node.attribute else -1
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()
            
            result = torch.softmax(A, dim=axis).cpu().numpy()

            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()
            
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

            return result
        
        elif op_type == "Split":
            input_data = torch.tensor(inputs[0])
            
            # 從屬性中提取 axis
            axis = next((attr.i for attr in node.attribute if attr.name == "axis"), 0)

            # 提取可選的 split 張量
            split_tensor = inputs[1] if len(inputs) > 1 else None
            split = split_tensor.tolist() if split_tensor is not None else None
            
            # 計時器事件
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            
            torch.cuda.synchronize()
            # nvtx.range_push(node.op_type + ", " + node.name)
            start_event.record()

            # 處理負軸索引
            if axis < 0:
                axis += input_data.dim()

            # 如果未指定 split,均勻分割
            if split is None:
                num_outputs = len(node.output)  # 獲取輸出的數量
                total_size = input_data.size(axis)
                if total_size % num_outputs != 0:
                    raise ValueError(f"Cannot evenly split axis {axis} into {num_outputs} parts.")
                split_size = total_size // num_outputs
                split = [split_size] * num_outputs
            else:
                # 確保 split 和目標軸的大小一致
                if sum(split) != input_data.size(axis):
                    raise ValueError(f"Sum of split sizes {sum(split)} does not match size of axis {axis} ({input_data.size(axis)}).")

            # 使用 PyTorch 的 split 函數進行分割
            result = torch.split(input_data, split, dim=axis)

            # 將分割結果存儲到對應的輸出名稱中
            for i, output_name in enumerate(node.output):
                self.tensors[output_name] = result[i].cpu().numpy()  # 將結果轉換為 NumPy 格式
            
            end_event.record()
            # nvtx.range_pop()
            torch.cuda.synchronize()
            
            # 計算執行時間
            execution_time = (start_event.elapsed_time(end_event)) / 1000
            self.total_execution_time += execution_time
            print(f"Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")

            return result[0].cpu().numpy()  # 返回所有分割的結果

        else:
            raise NotImplementedError(f"Operation {op_type} is not implemented.")

    def execute(self, inputs):
        # 確保 inputs 被添加到張量字典中
        for input_name, input_value in inputs.items():
            self.tensors[input_name] = input_value
        # 打印所有輸入的詳細資訊
        print("\nInputs Details:")
        for input_name, input_value in inputs.items():
            print(f"Input Name: {input_name}")
            print(f"Shape: {input_value.shape if isinstance(input_value, np.ndarray) else 'N/A'}")
            if isinstance(input_value, np.ndarray):
                print(f"Data (first 10 values): {input_value.flatten()[:10]}...")
            else:
                print(f"Data: {input_value}")
            print("-" * 50)

        execution_order = []
        node_execution_times = {}  # 用於記錄每個節點的執行時間
        self.total_execution_time = 0.0
        total_execution_time = 0.0  # 累計所有節點的執行時間
        total_fuse_execution_time = 0.0
        self.total_matmul_add_time = 0.0
        total_matmul_add_time = 0.0

        for node in self.nodes:
            # print(f"Executing node: {node.name}")
            # 檢查節點的所有輸入是否已準備好
            ready = all(inp in self.tensors for inp in node.input)
            if ready:
                if node.name in self.processed_nodes:
                    print(f"Skipping already processed Node: {node.name}\n")
                else:
                    start_event = torch.cuda.Event(enable_timing=True)
                    end_event = torch.cuda.Event(enable_timing=True)

                    torch.cuda.synchronize()
                    start_event.record()

                    self.tensors[node.output[0]] = self._execute_node(node)

                    end_event.record()
                    torch.cuda.synchronize()

                    # print(node.name, self.tensors[node.output[0]])

                    # 記錄執行時間
                    execution_time = (start_event.elapsed_time(end_event)) / 1000
                    node_execution_times[node.name] = execution_time
                    total_execution_time += execution_time  # 累計執行時間
                    # if node.name in self.matmul_add_fuse_node:
                    #     print(f"MatMul Fuse Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
                    # elif node.op_type == "MatMul":
                    #     print(f"MatMul Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
                    # elif node.op_type == "Add":
                    #     print(f"Add Node: {node.name}, Execution Time: {execution_time:.6f} seconds\n")
                    if node.op_type == "MatMul" or node.op_type == "Add":
                        total_matmul_add_time += execution_time

                execution_order.append(node)
            else:
                print(f"Skipping node '{node.name}' due to missing inputs.")
        # 打印所有節點的執行時間
        print("\nNode Execution Times:")
        # for node_name, exec_time in node_execution_times.items():
        #     if node_name in self.matmul_add_fuse_node:
        #         print(f"Matmul Fuse Node: {node_name}, Execution Time: {exec_time:.6f} seconds")
        #         total_fuse_execution_time += exec_time
        #     else:
        #         print(f"Node: {node_name}, Execution Time: {exec_time:.6f} seconds")

        # 打印所有節點的總執行時間
        # print(f"\nTotal Execution Time: {total_execution_time:.6f} seconds")
        # print(f"\nTotal Matmul Fuse Execution Time: {total_fuse_execution_time:.6f} seconds")
        # print(f"\nTotal Matmul + Add Execution Time: {total_matmul_add_time:.6f} seconds")

        print(f"\nTotal Execution Time: {self.total_execution_time:.6f} seconds")
        print(f"\nTotal Matmul + Add Execution Time: {self.total_matmul_add_time:.6f} seconds")

        # 收集所有輸出的張量
        outputs = {o.name: self.tensors[o.name] for o in self.graph.output if o.name in self.tensors}
        return outputs, execution_order


def _best_span(start_logits, end_logits):
    """Return ``(start, end_exclusive)`` token indices of the highest-scoring answer span.

    The start is the argmax of ``start_logits``. The end is the argmax of
    ``end_logits`` restricted to positions at or after the start, so the span is
    never empty or reversed. Both inputs are flattened, which assumes they hold
    the logits of a single sequence (batch of one, as built by ``main``).
    """
    start_scores = np.asarray(start_logits).reshape(-1)
    end_scores = np.asarray(end_logits).reshape(-1)
    start = int(np.argmax(start_scores))
    end = start + int(np.argmax(end_scores[start:]))
    return start, end + 1


def main():
    """Answer one hard-coded question with BERT-SQuAD and print the result.

    Steps:
    * Load ``model/bertsquad-10_simplified.onnx`` into an ``ONNXCompiler``.
    * Tokenize the (question, context) pair with the module-level ``tokenizer``,
      padded/truncated to 256 tokens.
    * Execute the graph.
    * Decode the best answer span from the start/end logits.

    ``ValueError`` during execution, or missing model outputs, are reported as
    ``Error: ...`` instead of propagating.
    """
    compiler = ONNXCompiler("model/bertsquad-10_simplified.onnx")
    # Example question and context.
    question = "What is the capital of France?"
    context = "The capital of France is Paris."

    # Tokenize the pair into fixed-length NumPy arrays.
    inputs = tokenizer(question, context, return_tensors='np', padding='max_length', max_length=256, truncation=True)

    # The graph expects int64 inputs under the TF-exported tensor names below.
    input_ids = inputs['input_ids'].astype(np.int64)
    segment_ids = inputs['token_type_ids'].astype(np.int64)
    input_mask = inputs['attention_mask'].astype(np.int64)
    unique_ids_raw_output = np.array([0], dtype=np.int64)

    try:
        print("Starting model execution...")
        outputs, _execution_order = compiler.execute({
            "input_ids:0": input_ids,
            "segment_ids:0": segment_ids,
            "input_mask:0": input_mask,
            "unique_ids_raw_output___9:0": unique_ids_raw_output
        })

        print("Execution complete.")
        print("Model outputs:", outputs)

        # The TF BERT SQuAD graph unstacks its logits as (start_logits, end_logits),
        # so "unstack:0" holds the start scores and "unstack:1" the end scores.
        missing = [name for name in ("unstack:0", "unstack:1") if name not in outputs]
        if missing:
            print(f"Error: model did not produce expected output(s): {missing}")
            return
        start_index, end_index = _best_span(outputs['unstack:0'], outputs['unstack:1'])

        # Decode the answer tokens back into text.
        answer = tokenizer.convert_tokens_to_string(
            tokenizer.convert_ids_to_tokens(input_ids[0][start_index:end_index])
        )

        print("\nQuestion:", question)
        print("Context:", context)
        print("Answer:", answer)

    except ValueError as e:
        print(f"Error: {e}")

# Script entry point. The stray "Leave a Comment" text that followed here was
# residue from the paste site and made the module a syntax error.
if __name__ == "__main__":
    main()