Untitled

mail@pastecode.io avatar
unknown
plain_text
10 days ago
3.3 kB
2
Indexable
Never
import torch
import torch.nn as nn
import torch.nn.functional as F

class CAE(nn.Module):
    """Convolutional 1-D autoencoder.

    The encoder is a stack of ``Conv1d -> ReLU -> Dropout`` stages mapping
    ``list_of_sizes[0] -> list_of_sizes[1] -> ... -> list_of_sizes[-1]``
    channels. The decoder mirrors it with ``ConvTranspose1d`` layers back to
    ``list_of_sizes[0]`` channels. Every convolution uses kernel 3, stride 1
    and padding 1, so the sequence length is preserved end to end.

    Args:
        list_of_sizes: Channel counts, input channels first. With fewer than
            two entries both encoder and decoder are empty (identity).
        dropout_rate: Dropout probability applied after every encoder stage
            and after every hidden (non-final) decoder stage.
    """

    def __init__(self, list_of_sizes, dropout_rate=0.5):
        super().__init__()
        self.encoder = self._build_encoder(list_of_sizes, dropout_rate)
        self.decoder = self._build_decoder(list_of_sizes, dropout_rate)

    def _build_encoder(self, sizes, dropout):
        """Return ``Conv1d -> ReLU -> Dropout`` stages for each adjacent size pair."""
        layers = []
        for c_in, c_out in zip(sizes[:-1], sizes[1:]):
            layers += [
                nn.Conv1d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
        return nn.Sequential(*layers)

    def _build_decoder(self, sizes, dropout):
        """Return the mirrored ``ConvTranspose1d`` stack back to ``sizes[0]`` channels.

        The final stage is a bare ``ConvTranspose1d``. A trailing ReLU would
        clamp the reconstruction to be non-negative, so signed inputs could
        never be reproduced. A trailing Dropout would randomly zero
        reconstructed values during training. Only parameter-free modules are
        omitted, so the indices of the ``ConvTranspose1d`` layers (and hence
        the ``state_dict`` keys) stay the same.
        """
        layers = []
        for i in reversed(range(1, len(sizes))):
            layers.append(
                nn.ConvTranspose1d(sizes[i], sizes[i - 1], kernel_size=3, stride=1, padding=1)
            )
            if i > 1:  # hidden stage: keep the non-linearity and regularisation
                layers += [nn.ReLU(), nn.Dropout(dropout)]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode then decode ``x``; the output has the same shape as ``x``."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

class BasicBlock(nn.Module):
    """Conv1d + BatchNorm + ReLU + Dropout, then a BiLSTM + ReLU and temporal max-pool.

    Input shape:  ``(batch, in_channels, seq)``.
    Output shape: ``(batch, 2 * out_lstms, seq // 2)``. The LSTM is
    bidirectional, so it emits ``2 * out_lstms`` features. The max-pool
    (kernel 2, stride 2) downsamples the sequence axis, which requires
    ``seq >= 2``.

    Args:
        in_channels: Channels of the incoming tensor.
        out_cnns: Output channels of the 1-D convolution.
        out_lstms: Hidden size of each LSTM direction.
        dropout_rate: Dropout probability applied after the conv stage.
    """

    def __init__(self, in_channels, out_cnns, out_lstms, dropout_rate=0.5):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_cnns, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm1d(out_cnns)
        self.lstm = nn.LSTM(out_cnns, out_lstms, batch_first=True, bidirectional=True)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = x.permute(0, 2, 1)  # (batch, feature, seq) -> (batch, seq, feature) for batch_first LSTM
        x, _ = self.lstm(x)
        x = self.relu(x)
        # Restore (batch, feature, seq) BEFORE pooling: MaxPool1d pools over the
        # last axis, so pooling earlier would halve the LSTM features instead of
        # downsampling time (and break the 2 * out_lstms channel contract).
        x = x.permute(0, 2, 1)
        x = self.pool(x)
        return x

class CAEL(nn.Module):
    """Composite model: CAE encoder -> three blocks -> mean over time -> linear head.

    Only ``self.autoencoder.encoder`` is used in ``forward``. The decoder is
    still constructed (and owns parameters), but it is not part of this
    forward pass.

    Args:
        block: Block class, called as
            ``block(in_channels, out_cnns, out_lstms, dropout_rate=...)``.
            The wiring assumes each block emits ``2 * out_lstms`` channels
            (bidirectional LSTM).
        in_channel: Number of input channels.
        in_caes: Encoder channel sizes following ``in_channel``. May be empty,
            in which case the encoder is an identity.
        in_cnns: Conv widths of the three blocks; the first three entries are used.
        in_lstms: LSTM hidden sizes of the three blocks; the first three entries are used.
        num_classes: Output size of the final linear layer. Defaults to
            ``in_channel``, matching the original behaviour.
        cae_dropout: Dropout probability inside the autoencoder.
        block_dropout: Dropout probability passed to each block.
    """

    def __init__(self, block, in_channel, in_caes, in_cnns, in_lstms, *,
                 num_classes=None, cae_dropout=0.1, block_dropout=0.3):
        super().__init__()
        if num_classes is None:
            num_classes = in_channel
        cae_sizes = [in_channel, *in_caes]
        self.autoencoder = CAE(cae_sizes, dropout_rate=cae_dropout)
        # The encoder outputs cae_sizes[-1] channels (in_channel if in_caes is empty).
        self.layer1 = block(cae_sizes[-1], in_cnns[0], in_lstms[0], dropout_rate=block_dropout)
        self.layer2 = block(in_lstms[0] * 2, in_cnns[1], in_lstms[1], dropout_rate=block_dropout)
        self.layer3 = block(in_lstms[1] * 2, in_cnns[2], in_lstms[2], dropout_rate=block_dropout)
        # Sized from in_lstms[2] (what layer3 was built with), not in_lstms[-1],
        # so extra trailing entries in in_lstms cannot desynchronise the head.
        self.fc = nn.Linear(in_lstms[2] * 2, num_classes)  # *2: bidirectional output

    def forward(self, x):
        """Map ``x`` of shape ``(batch, in_channel, seq)`` to ``(batch, num_classes)``."""
        x = self.autoencoder.encoder(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = torch.mean(x, dim=2)  # average over the (remaining) sequence axis
        out = self.fc(x)
        return out

# Example of initializing and using the model
if __name__ == "__main__":
    model = CAEL(BasicBlock, 3, [32, 64, 128], [64, 128, 256], [128, 256, 512])
    model.eval()  # inference mode: disable Dropout, use BatchNorm running statistics
    input_tensor = torch.randn(10, 3, 100)  # (batch_size, channels, length)
    with torch.no_grad():  # forward-only demo; no autograd graph needed
        output = model(input_tensor)
    print("Output shape:", output.shape)
Leave a Comment