Untitled

 avatar
user_5573880
python
6 months ago
8.0 kB
2
Indexable
import onnx

# Load the ONNX model
model = onnx.load("bidaf-11-int8.onnx")

# Index the graph: which node produces each tensor, and which nodes consume it.
# ONNX writes the empty string "" in place of an omitted optional input or
# output, so "" is not a real tensor name.  Indexing it would wrongly link any
# node with an omitted optional output to every node with an omitted optional
# input, so empty names are skipped.
output_to_node_map = {}  # tensor name -> node that produces it
input_to_node_map = {}   # tensor name -> list of nodes that consume it, in graph order
for node in model.graph.node:
    for output in node.output:
        if output:
            output_to_node_map[output] = node
    for input_name in node.input:
        if input_name:
            input_to_node_map.setdefault(input_name, []).append(node)

# Define the stopping operations, including MatMulInteger: the backward and
# forward traversals end a path at (and including) any node of these types.
stopping_ops = {"Transpose", "Unsqueeze", "Gather", "Reshape", "Squeeze", "MatMulInteger"}

# Define the op types counted as elementwise operations by the traversals.
# NOTE(review): "Square" is not a standard ONNX operator, so it never matches,
# and Softmax normalizes along an axis rather than acting purely elementwise.
# Both are kept so the reported counts are unchanged.
elementwise_ops = {
    "Add", "Sub", "Mul", "Div",            # Basic arithmetic
    "Relu", "Sigmoid", "Tanh", "LeakyRelu", "Softmax", "HardSigmoid",  # Activation functions
    "Greater", "Less", "Equal", "And", "Or", "Not", "Xor",             # Logical operations
    "Abs", "Neg", "Reciprocal", "Sqrt", "Exp", "Log", "Pow",           # Unary math functions
    "Sin", "Cos", "Tan", "Ceil", "Floor", "Clip", "Round", "Sign",     # More math functions
    "Max", "Min", "Mean", "Sum", "Square", "Erf"                       # Aggregation and other ops
}

# Function to find previous operations along each input path until a stopping operation is encountered
def find_previous_ops(input_name, visited_nodes, elementwise_visited):
    """Walk backwards from tensor ``input_name`` until a stopping op is reached.

    Starts at the node that produces ``input_name``, then recursively follows
    the producers of that node's inputs, branching once per input.  A branch
    ends at:

    * a node whose op_type is in ``stopping_ops`` (that node is included);
    * a tensor with no producer (a graph input or initializer); or
    * a node that is already in ``visited_nodes``.

    Args:
        input_name: Name of the tensor to start from.  The empty string
            (ONNX's placeholder for an omitted optional input) has no producer.
        visited_nodes: Set of ``id(node)`` values already walked; mutated.
            The set is shared across branches, so a node reachable along
            several branches is reported only on the first branch explored.
        elementwise_visited: Set of ``id(node)`` values of elementwise nodes
            already counted; mutated.  A caller can share it between calls so
            that no node is counted twice.

    Returns:
        ``(paths, count)``.  ``paths`` is a list of op_type lists, each
        ordered from the producer of ``input_name`` backwards.  ``count`` is
        the number of elementwise nodes newly counted by this call.  Returns
        ``([], 0)`` if the producer was already visited, and ``([[]], 0)`` if
        ``input_name`` has no producer.
    """
    all_paths = []
    current_path = []
    elementwise_count = 0

    # "" marks an omitted optional input; it must never be looked up as a tensor.
    node = output_to_node_map.get(input_name) if input_name else None
    if node is not None:
        # Key by object identity rather than node.name: ONNX node names are
        # optional and may be empty, which would merge distinct nodes.  The ids
        # are stable because the global maps keep the node objects alive.
        node_key = id(node)
        if node_key in visited_nodes:
            return [], elementwise_count  # Skip already visited nodes
        visited_nodes.add(node_key)
        current_path.append(node.op_type)

        # Count each elementwise node once across everything sharing elementwise_visited.
        if node.op_type in elementwise_ops and node_key not in elementwise_visited:
            elementwise_visited.add(node_key)
            elementwise_count += 1

        # Stop if this node is a stopping operation
        if node.op_type in stopping_ops:
            return [current_path], elementwise_count

        # Recursively extend the path through each input that has a producer.
        # Omitted optional inputs ("") and graph inputs/initializers are skipped.
        for node_input in node.input:
            if node_input and node_input in output_to_node_map:
                sub_paths, sub_elementwise_count = find_previous_ops(
                    node_input, visited_nodes, elementwise_visited
                )
                all_paths.extend(current_path + path for path in sub_paths)
                elementwise_count += sub_elementwise_count

    if not all_paths:  # No sub-paths found: the current path itself is the result
        all_paths.append(current_path)

    return all_paths, elementwise_count

# Function to remove duplicate paths and filter out empty paths
def filter_duplicate_and_empty_paths(paths):
    """Return the non-empty entries of ``paths`` with repeated ones removed.

    Paths are compared by content.  The first occurrence of each distinct
    path is kept (the original list object, not a copy), and the order of
    first occurrence is preserved.
    """
    # Maps tuple(path) to the first list with that content; dicts preserve
    # insertion order, so the values come out in first-seen order.
    first_seen = {}
    for candidate in paths:
        signature = tuple(candidate)  # lists are unhashable, so key on a tuple copy
        if candidate and signature not in first_seen:
            first_seen[signature] = candidate
    return list(first_seen.values())

# Function to find following operations until a stopping operation or a fan-out is encountered
def find_following_ops(output_name, visited_nodes, elementwise_visited):
    """Walk forwards from tensor ``output_name`` until a stopping op is reached.

    Every consumer of ``output_name`` not already visited starts its own
    branch.  A branch ends, including the node itself, at:

    * a node whose op_type is in ``stopping_ops``; or
    * a node whose outputs have more than one consumer in total (fan-out).

    Otherwise the branch continues through each of the node's outputs.

    Args:
        output_name: Name of the tensor to start from.  The empty string
            (ONNX's placeholder for an omitted optional output) has no
            consumers.
        visited_nodes: Set of ``id(node)`` values already walked; mutated,
            so a node reachable along several branches is reported only once.
        elementwise_visited: Set of ``id(node)`` values of elementwise nodes
            already counted; mutated.

    Returns:
        ``(paths, count)``.  ``paths`` is a list of op_type lists, each
        starting at a consumer of ``output_name``; it is ``[[]]`` when nothing
        new follows.  ``count`` is the number of elementwise nodes newly
        counted by this call.
    """
    all_following_paths = []
    elementwise_count = 0

    # "" marks an omitted optional output; it feeds nothing.
    consumers = input_to_node_map.get(output_name, []) if output_name else []
    for node in consumers:
        # Key by object identity rather than node.name: ONNX node names are
        # optional and may be empty, which would merge distinct nodes.  The ids
        # are stable because the global maps keep the node objects alive.
        node_key = id(node)
        if node_key in visited_nodes:
            continue  # Skip already visited nodes to avoid duplication
        visited_nodes.add(node_key)

        # Each consumer gets its own path; sibling consumers must not share one.
        branch = [node.op_type]

        if node.op_type in elementwise_ops and node_key not in elementwise_visited:
            elementwise_visited.add(node_key)
            elementwise_count += 1

        # Stopping op (including MatMulInteger): end this branch, but keep
        # exploring the remaining consumers of output_name.
        if node.op_type in stopping_ops:
            all_following_paths.append(branch)
            continue

        live_outputs = [output for output in node.output if output]

        # Fan-out: this node's results go to more than one place, so end the
        # branch here (even if all consumers are the same node).
        total_output_connections = sum(len(input_to_node_map.get(output, [])) for output in live_outputs)
        if total_output_connections > 1:
            all_following_paths.append(branch)
            continue

        # Single (or no) consumer downstream: keep following it.
        branch_paths = []
        for next_output in live_outputs:
            sub_paths, sub_elementwise_count = find_following_ops(
                next_output, visited_nodes, elementwise_visited
            )
            branch_paths.extend(branch + path for path in sub_paths)
            elementwise_count += sub_elementwise_count
        all_following_paths.extend(branch_paths or [branch])

    if not all_following_paths:  # Nothing new follows this tensor
        all_following_paths.append([])

    return all_following_paths, elementwise_count

# Count the number of MatMulInteger operations and total operations
matmulinteger_count = 0
total_op_count = len(model.graph.node)

# For every MatMulInteger node, print the op paths leading into and out of it,
# plus how many elementwise ops were found along those paths.
for node in model.graph.node:
    if node.op_type != "MatMulInteger":
        continue
    matmulinteger_count += 1
    print(f"\nMatMulInteger Node: {node.name}")

    # Backward search: a fresh visited set per input, so each input's paths
    # are reported independently, but one shared elementwise set, so an
    # elementwise node reachable from several inputs is counted only once.
    all_previous_ops = []
    total_elementwise_count = 0
    elementwise_visited = set()
    for input_name in node.input:
        # "" marks an omitted optional input (MatMulInteger's zero points are optional).
        if not input_name:
            continue
        visited_nodes_for_input = set()
        previous_paths, prev_elementwise_count = find_previous_ops(
            input_name, visited_nodes_for_input, elementwise_visited
        )
        all_previous_ops.extend(previous_paths)
        total_elementwise_count += prev_elementwise_count

    # Filter out empty and duplicate paths, then print each one
    filtered_paths = filter_duplicate_and_empty_paths(all_previous_ops)
    for i, path in enumerate(filtered_paths, 1):
        print(f"  Previous Path {i}: {path}")

    # Forward search: the visited and elementwise sets are shared across all
    # outputs of this node.
    all_following_ops = []
    visited_nodes = set()
    following_elementwise_visited = set()
    for output_name in node.output:
        following_paths, next_elementwise_count = find_following_ops(
            output_name, visited_nodes, following_elementwise_visited
        )
        all_following_ops.extend(following_paths)
        total_elementwise_count += next_elementwise_count

    # Filter out empty and duplicate following paths, then print each one
    filtered_following_paths = filter_duplicate_and_empty_paths(all_following_ops)
    for i, path in enumerate(filtered_following_paths, 1):
        print(f"  Following Path {i}: {path}")

    # Print elementwise operation count (backward + forward)
    print(f"  Number of elementwise operations: {total_elementwise_count}")

# Print the final statistics
print(f"\nTotal number of MatMulInteger operations: {matmulinteger_count}")
print(f"Total number of operations in the model: {total_op_count}")
Editor is loading...
Leave a Comment