# Untitled
#
# avatar
# user_9286359
# plain_text
# 15 days ago
# 2.2 kB
# 3
# Indexable
import torch
import torch.nn as nn
import torch.onnx

# Define a deep convolutional neural network
class DeepConvNet(nn.Module):
    """Convolutional classifier: five 3x3 convolutions, two 2x2 max-pools, two FC layers.

    Every convolution uses stride 1 and padding 1, so it preserves the spatial
    size. Each ``MaxPool2d(kernel_size=2, stride=2)`` halves it, rounding down.
    The feature extractor therefore maps a ``(N, in_channels, H, W)`` input to
    ``(N, 512, H // 4, W // 4)``, and the classifier flattens that.

    Args:
        in_channels: Number of channels in the input image.
        num_classes: Number of outputs produced by the final ``Linear`` layer.
        input_size: Spatial size of the images the model will receive. Pass an
            ``int`` for square ``H == W`` inputs, or an ``(H, W)`` pair. It
            fixes ``in_features`` of the first ``Linear`` layer, so inputs of
            any other size fail in ``forward`` with a shape mismatch. The
            default of 64 reproduces the original ``512 * 16 * 16``.

    Raises:
        ValueError: If ``input_size`` is smaller than 4 in either dimension.
            Such inputs leave no spatial extent after the two pooling layers.
    """

    def __init__(self, in_channels=3, num_classes=10, input_size=64):
        super().__init__()
        # Compute up front so an invalid input_size fails before any layer is built.
        flat_features = self._flattened_features(input_size)

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),  # conv 1
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),          # conv 2
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),         # conv 3
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                           # pool 1: H, W -> H//2, W//2
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),         # conv 4
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),         # conv 5
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                           # pool 2: -> H//4, W//4
        )
        # NOTE: with the default 64x64 input this first Linear has
        # 512 * 16 * 16 = 131,072 inputs x 1024 outputs (~134M weights).
        # It dominates the parameter count.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 1024),                                  # fully connected
            nn.ReLU(),
            nn.Linear(1024, num_classes),                                    # output layer
        )

    @staticmethod
    def _flattened_features(input_size):
        """Return the flattened feature count produced by ``self.features``.

        ``input_size`` is an ``int`` (square) or an ``(H, W)`` pair. Two
        floor-halving pools give ``H // 4`` by ``W // 4`` over 512 channels.
        """
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        out_h, out_w = height // 4, width // 4
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f"input_size must be at least 4 in each dimension, got {input_size!r}"
            )
        return 512 * out_h * out_w

    def forward(self, x):
        """Run the conv stack, then the classifier.

        Returns a ``(N, num_classes)`` tensor of raw scores (no softmax is applied).
        """
        x = self.features(x)
        x = self.classifier(x)
        return x

# Initialize the model
model = DeepConvNet(in_channels=3, num_classes=10)
# Put the model in inference mode so the export traces the eval-time graph.
model.eval()

# Dummy input for ONNX export (batch size 1, 3 channels, 64x64 image).
# The spatial size must stay 64x64 to match the model's first Linear layer,
# which expects 512 * 16 * 16 features (64 -> 32 -> 16 after the two 2x2 pools).
dummy_input = torch.randn(1, 3, 64, 64)

# Path to save the ONNX model
onnx_path = "deep_conv_net.onnx"

# Export the model to ONNX
torch.onnx.export(
    model,                   # Model to export
    dummy_input,             # Example input used for tracing
    onnx_path,               # Path to save the ONNX file
    export_params=True,      # Store parameter weights inside the model file
    opset_version=11,        # ONNX opset version
    input_names=["input"],   # Input tensor name
    output_names=["output"], # Output tensor name
    # Without this, the exported graph fixes the batch dimension to the dummy
    # input's batch size (1). Marking dim 0 as symbolic lets the ONNX model
    # accept any batch size.
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

print(f"Deep CNN model has been converted to ONNX and saved to {onnx_path}")
# Leave a Comment