Untitled

mail@pastecode.io avatar
unknown
python
2 years ago
4.2 kB
1
Indexable
Never
import os
import sys
import math

import torch
import torch.nn.functional as F
import torch.nn as nn

class input_kernel(nn.Module):
    """1-D conv stage: Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded unchanged to ``nn.Conv1d``.
        pool_kernel_size: window of the trailing ``nn.MaxPool1d``
            (default 8, the previously hard-coded value).
        pool_stride: stride of the trailing ``nn.MaxPool1d``
            (default 4, the previously hard-coded value).

    Expects input laid out as ``nn.Conv1d`` does, i.e.
    (batch, in_channels, length); returns (batch, out_channels, pooled_length).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False,
                 pool_kernel_size=8, pool_stride=4):
        super(input_kernel, self).__init__()
        self.kernel = nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, groups=groups, bias=bias, dilation=dilation),
            nn.BatchNorm1d(out_channels, eps=1e-4, momentum=0.1, affine=True),
            # In-place is safe here: the ReLU only overwrites the BatchNorm output.
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=pool_kernel_size, stride=pool_stride),
        )

    def forward(self, inputs):
        """Apply the conv / batch-norm / ReLU / max-pool stack to ``inputs``."""
        return self.kernel(inputs)

class conv_kernel(nn.Module):
    """Two-stage 1-D conv block.

    Stage 1 is a Conv1d that keeps ``in_channels`` channels and uses the
    caller's ``kernel_size`` / ``stride`` / ``padding`` / ``dilation`` /
    ``groups``. Stage 2 is a pointwise (kernel_size=1) Conv1d that maps to
    ``out_channels``. Each conv is followed by BatchNorm1d and an in-place ReLU.
    ``bias`` applies to both convs.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False):
        super().__init__()
        norm_opts = dict(eps=1e-4, momentum=0.1, affine=True)

        # Stage 1: filtering along the length axis; channel count unchanged.
        spatial_stage = [
            nn.Conv1d(in_channels, in_channels, kernel_size,
                      stride=stride, padding=padding, dilation=dilation,
                      groups=groups, bias=bias),
            nn.BatchNorm1d(in_channels, **norm_opts),
            nn.ReLU(inplace=True),
        ]
        # Stage 2: pointwise projection to the requested channel count.
        pointwise_stage = [
            nn.Conv1d(in_channels, out_channels, 1,
                      stride=1, padding=0, dilation=1, groups=1, bias=bias),
            nn.BatchNorm1d(out_channels, **norm_opts),
            nn.ReLU(inplace=True),
        ]
        self.kernel = nn.Sequential(*spatial_stage, *pointwise_stage)

    def forward(self, inputs):
        """Run ``inputs`` through both stages and return the activations."""
        return self.kernel(inputs)


class BLSTM(nn.Module):
    """Stacked LSTM over sequence-first input, with an optional projection.

    Args:
        dim: input feature size and LSTM hidden size.
        layers: number of stacked LSTM layers.
        bi: if True, run bidirectionally and project the concatenated
            ``2 * dim`` features to ``out_dim``; if False, no projection is
            applied and the raw ``dim``-wide LSTM output is returned.
        out_dim: width of the bidirectional projection (default 480, the
            previously hard-coded value).
    """

    def __init__(self, dim, layers=2, bi=True, out_dim=480):
        super().__init__()
        self.lstm = nn.LSTM(bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim)
        # A bidirectional LSTM emits 2 * dim features per step; fold them down.
        self.linear = nn.Linear(2 * dim, out_dim) if bi else None

    def forward(self, x, hidden=None):
        """Run the LSTM, then the projection if one was built.

        Args:
            x: input of shape (seq_len, batch, dim); the LSTM is not batch_first.
            hidden: optional initial ``(h_0, c_0)`` forwarded to ``nn.LSTM``.

        Returns:
            ``(output, (h_n, c_n))``. ``output`` is (seq_len, batch, out_dim)
            when bidirectional, else (seq_len, batch, dim).
        """
        x, hidden = self.lstm(x, hidden)
        if self.linear is not None:
            x = self.linear(x)
        return x, hidden
    
class snr_net(nn.Module):
    """Network whose forward pass is a bidirectional LSTM over the last axis.

    Only ``self.lstm`` participates in ``forward``. ``encoder``, ``rnn`` and
    ``out`` are constructed, and therefore appear in ``state_dict``, but are
    not used yet. They are kept so existing checkpoints still load.

    Args:
        feature_dim: size of axis 1 of the input, i.e. the per-step feature
            width fed to the BLSTM (and to the unused ``rnn``). Defaults to
            16384, the previously hard-coded value.

    NOTE(review): with the default ``feature_dim`` the BLSTM's two
    bidirectional layers hold on the order of 1e10 parameters (about 40 GB
    in float32). Pass a smaller ``feature_dim`` to build a usable model.
    """

    def __init__(self, feature_dim=16384):
        super(snr_net, self).__init__()
        self.hidden_size = 480
        # Strided conv front end -- currently unused by forward().
        self.encoder = torch.nn.Sequential(
            input_kernel(1, 4, kernel_size=8, stride=2, dilation=1, padding=0, bias=False),
            input_kernel(4, 8, kernel_size=8, stride=2, dilation=1, padding=0, bias=False),
            input_kernel(8, 16, kernel_size=8, stride=2, dilation=1, padding=0, bias=False),
        )
        self.lstm = BLSTM(feature_dim, bi=True)
        # Batch-first alternative recurrent head -- currently unused by forward().
        self.rnn = nn.LSTM(
            input_size=feature_dim,
            hidden_size=self.hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # Output projection -- currently unused by forward().
        self.out = nn.Linear(480, 50)

    def forward(self, input):
        """Apply the BLSTM along axis 2 of ``input``.

        Args:
            input: tensor of shape (B, feature_dim, T). Axis 1 must equal
                ``feature_dim`` because it becomes the LSTM's feature axis.

        Returns:
            Tensor of shape (B, C_out, T), where C_out is the width of
            BLSTM's projection (480 in this file).
        """
        # (B, C, T) -> (T, B, C): the BLSTM's nn.LSTM is sequence-first.
        x = input.permute(2, 0, 1)
        x, _ = self.lstm(x)
        # (T, B, C_out) -> (B, C_out, T)
        return x.permute(1, 2, 0)
    

# 
if __name__ == "__main__":
    # Smoke test, guarded so that importing this module does not build the model.
    # NOTE(review): snr_net() with its default width builds a BLSTM(16384) with
    # on the order of 1e10 parameters (~40 GB float32); expect this to need a
    # very large machine.
    m = snr_net()
    # snr_net.forward feeds axis 1 into an LSTM with input_size=16384, so the
    # input must be (batch, 16384, steps). The previous (5, 1, 16384) tensor
    # was rejected by nn.LSTM's input-size check. If a raw (B, 1, T) waveform
    # is the intended input, forward() needs a reshape instead.
    x = torch.rand(5, 16384, 1)
    print(m(x).size())