Untitled
user_9286359
plain_text
5 months ago
1.9 kB
3
Indexable
# Build a single-layer LSTM model whose head reads a fixed time step, and
# (when run as a script) export a freshly initialised, untrained instance to
# ONNX at a relative path in the current working directory.

import torch
import torch.nn as nn
import torch.onnx


class SingleLayerLSTMStaticShape(nn.Module):
    """Single-layer LSTM followed by a linear layer applied to the last time step.

    The last time step is selected with the constant index
    ``sequence_length - 1`` rather than ``-1`` (the "static shape" design of
    this model), presumably so the exported ONNX graph uses a fixed index —
    confirm against the target runtime if that matters.  Because the index is
    fixed, every input must have exactly ``sequence_length`` time steps.

    Args:
        input_size: Number of features per time step.
        hidden_size: Number of hidden units in the LSTM.
        output_size: Number of features produced by the linear layer.
        sequence_length: Number of time steps every input must contain.
    """

    def __init__(self, input_size, hidden_size, output_size, sequence_length):
        super().__init__()
        # batch_first=True: the LSTM consumes (batch, seq, feature) tensors.
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.sequence_length = sequence_length
        self.hidden_size = hidden_size
        self.fc = nn.Linear(hidden_size, output_size)  # Output projection

    def forward(self, x):
        """Run the LSTM and project its final time step.

        Args:
            x: Tensor of shape ``(batch, sequence_length, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)``.

        Raises:
            ValueError: If ``x`` is not 3-D or its time dimension differs from
                ``sequence_length`` (checked only outside of tracing).
        """
        # The fixed index below would silently pick a middle step for longer
        # inputs, so reject mismatched shapes eagerly. The check is skipped
        # while tracing (e.g. legacy ONNX export), where shape reads become
        # traced values and a Python `if` on them would emit TracerWarnings.
        if not torch.jit.is_tracing() and (
            x.dim() != 3 or x.shape[1] != self.sequence_length
        ):
            raise ValueError(
                f"expected input of shape (batch, {self.sequence_length}, features), "
                f"got {tuple(x.shape)}"
            )
        lstm_out, _ = self.lstm(x)
        # Use a static index to select the last time step.
        return self.fc(lstm_out[:, self.sequence_length - 1, :])


# Hyperparameters used by the script entry point.
input_size = 10  # Number of input features
hidden_size = 20  # Number of hidden units in the LSTM
output_size = 5  # Number of output features
sequence_length = 15  # Fixed sequence length

# Path to save the ONNX model
onnx_path = "single_layer_lstm_static_shape.onnx"


def main():
    """Build the model from the module-level hyperparameters and export it to ``onnx_path``."""
    model = SingleLayerLSTMStaticShape(
        input_size, hidden_size, output_size, sequence_length
    )
    model.eval()  # Set to evaluation mode

    # Dummy input: batch size 1, `sequence_length` steps, `input_size` features.
    dummy_input = torch.randn(1, sequence_length, input_size)

    torch.onnx.export(
        model,  # Model to export
        dummy_input,  # Example input
        onnx_path,  # Path to save the ONNX file
        export_params=True,  # Store trained parameter weights in the model
        opset_version=11,  # ONNX opset version
        input_names=["input"],  # Input tensor name
        output_names=["output"],  # Output tensor name
    )
    print(
        f"Single-layer LSTM model (static shape) has been converted to ONNX and saved to {onnx_path}"
    )


if __name__ == "__main__":
    main()
Editor is loading...
Leave a Comment