Untitled
unknown
python
3 years ago
4.2 kB
4
Indexable
import os
import sys
import math

import torch
import torch.nn.functional as F
import torch.nn as nn


class input_kernel(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d front-end stage.

    Tensors use the ``nn.Conv1d`` layout ``(batch, channels, length)``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded to ``nn.Conv1d``.
        pool_kernel_size: window of the trailing ``nn.MaxPool1d``
            (default 8, the value that used to be hard-coded).
        pool_stride: stride of the trailing ``nn.MaxPool1d`` (default 4).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False,
                 pool_kernel_size=8, pool_stride=4):
        super(input_kernel, self).__init__()
        self.kernel = nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding,
                      groups=groups, bias=bias, dilation=dilation),
            nn.BatchNorm1d(out_channels, eps=1e-4, momentum=0.1, affine=True),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=pool_kernel_size, stride=pool_stride),
        )

    def forward(self, inputs):
        """Apply the conv/BN/ReLU/pool stack to ``inputs``."""
        return self.kernel(inputs)


class conv_kernel(nn.Module):
    """Two-stage conv block, each stage followed by BatchNorm1d and ReLU.

    Stage 1 is a ``kernel_size`` conv that keeps ``in_channels`` channels.
    Stage 2 is a 1x1 conv that projects to ``out_channels``. Passing
    ``groups=in_channels`` makes stage 1 depthwise, which turns the pair into
    a depthwise-separable convolution. With the default ``groups=1`` it is a
    regular conv.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False):
        super(conv_kernel, self).__init__()
        self.kernel = nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=in_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding,
                      groups=groups, bias=bias, dilation=dilation),
            nn.BatchNorm1d(in_channels, eps=1e-4, momentum=0.1, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=1, stride=1, padding=0, groups=1, bias=bias,
                      dilation=1),
            nn.BatchNorm1d(out_channels, eps=1e-4, momentum=0.1, affine=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, inputs):
        """Apply both conv stages to ``inputs`` of shape (batch, C, length)."""
        return self.kernel(inputs)


class BLSTM(nn.Module):
    """Stacked (optionally bidirectional) LSTM over time-major input.

    Input ``x`` is ``(seq_len, batch, dim)``, the ``nn.LSTM`` default with
    ``batch_first=False``. When ``bi`` is true, the concatenated
    forward/backward features (``2 * dim``) are projected to ``out_dim`` by a
    linear layer. Otherwise the LSTM output (``dim`` features) is returned
    unchanged.

    Args:
        dim: input feature size; also used as the LSTM hidden size.
        layers: number of stacked LSTM layers.
        bi: use a bidirectional LSTM followed by the projection.
        out_dim: projection width when ``bi`` is true (default 480, the
            value that used to be hard-coded).
    """

    def __init__(self, dim, layers=2, bi=True, out_dim=480):
        super().__init__()
        self.lstm = nn.LSTM(bidirectional=bi, num_layers=layers,
                            hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, out_dim) if bi else None

    def forward(self, x, hidden=None):
        """Run the LSTM (and projection, if any); return ``(output, hidden)``."""
        x, hidden = self.lstm(x, hidden)
        if self.linear is not None:
            x = self.linear(x)
        return x, hidden


class snr_net(nn.Module):
    """Conv encoder followed by a bidirectional LSTM over the encoded frames.

    ``forward`` takes ``(batch, 1, samples)``. Each of the three
    ``input_kernel`` stages downsamples in time (conv stride 2, then max-pool
    8/4). The last stage outputs 16 channels. The BLSTM runs over the
    resulting frames and projects each frame to ``hidden_size`` (480)
    features. For 16384 input samples the output is ``(batch, 480, 30)``.
    """

    def __init__(self):
        super(snr_net, self).__init__()
        self.hidden_size = 480
        # Must equal the out_channels of the last encoder stage below: it is
        # the per-frame feature size fed to the BLSTM.
        enc_channels = 16
        self.encoder = torch.nn.Sequential(
            input_kernel(1, 4, kernel_size=8, stride=2, dilation=1,
                         padding=0, bias=False),
            input_kernel(4, 8, kernel_size=8, stride=2, dilation=1,
                         padding=0, bias=False),
            input_kernel(8, enc_channels, kernel_size=8, stride=2,
                         dilation=1, padding=0, bias=False),
        )
        # Previously BLSTM(16384): the LSTM feature size must be the encoder's
        # channel count, not the raw sample count (which made the model
        # ~10B parameters and made forward() fail on a size mismatch).
        self.lstm = BLSTM(enc_channels, bi=True, out_dim=self.hidden_size)
        # NOTE(review): `rnn` and `out` are not used by forward(); they are
        # kept so the module's attributes / state_dict keys stay unchanged.
        self.rnn = nn.LSTM(
            input_size=16384,
            hidden_size=self.hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        self.out = nn.Linear(480, 50)

    def forward(self, input):
        """Encode ``input`` and run the BLSTM over the encoded frames.

        Args:
            input: tensor of shape ``(batch, 1, samples)``.

        Returns:
            Tensor of shape ``(batch, hidden_size, frames)``.
        """
        x = self.encoder(input)    # (B, 16, frames)
        x = x.permute(2, 0, 1)     # (frames, B, 16): time-major for nn.LSTM
        x, _ = self.lstm(x)        # (frames, B, 480)
        return x.permute(1, 2, 0)  # (B, 480, frames)


if __name__ == "__main__":
    # Smoke test (kept out of import time): 16384 samples -> 30 frames.
    m = snr_net()
    x = torch.rand(5, 1, 16384)
    print(m(x).size())  # torch.Size([5, 480, 30])
Editor is loading...