# Untitled
# unknown
# python
# 4 years ago
# 4.2 kB
# 8
# Indexable
import os
import sys
import math
import torch
import torch.nn.functional as F
import torch.nn as nn
class input_kernel(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d feature-extraction stage.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded unchanged to ``nn.Conv1d``.
        pool_kernel_size: window size of the trailing ``nn.MaxPool1d``
            (default 8).
        pool_stride: stride of the trailing ``nn.MaxPool1d`` (default 4).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False,
                 pool_kernel_size=8, pool_stride=4):
        super(input_kernel, self).__init__()
        self.kernel = nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding,
                      groups=groups, bias=bias, dilation=dilation),
            nn.BatchNorm1d(out_channels, eps=1e-4, momentum=0.1, affine=True),
            nn.ReLU(inplace=True),
            # Downsamples the time axis after the conv; window/stride are
            # configurable so the stage can be reused at other resolutions.
            nn.MaxPool1d(kernel_size=pool_kernel_size, stride=pool_stride),
        )

    def forward(self, inputs):
        """Apply conv, batch-norm, ReLU and max-pooling to ``inputs``."""
        outputs = self.kernel(inputs)
        return outputs
class conv_kernel(nn.Module):
    """Two-stage 1-D convolution unit.

    Stage 1 is a ``kernel_size`` convolution that keeps the channel count
    (``in_channels -> in_channels``). Stage 2 is a 1x1 (pointwise)
    convolution that maps to ``out_channels``. Each convolution is followed
    by ``BatchNorm1d`` and an in-place ``ReLU``. The layers are stored in
    ``self.kernel`` as: conv, BN, ReLU, conv1x1, BN, ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False):
        super().__init__()
        bn_opts = dict(eps=1e-4, momentum=0.1, affine=True)
        layers = [
            # Spatial conv: channel count preserved; ``groups`` is applied
            # only here, so it has to divide ``in_channels``.
            nn.Conv1d(in_channels=in_channels, out_channels=in_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding,
                      groups=groups, bias=bias, dilation=dilation),
            nn.BatchNorm1d(in_channels, **bn_opts),
            nn.ReLU(inplace=True),
            # Pointwise conv: mixes channels and projects to out_channels.
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=1, stride=1, padding=0, groups=1,
                      bias=bias, dilation=1),
            nn.BatchNorm1d(out_channels, **bn_opts),
            nn.ReLU(inplace=True),
        ]
        self.kernel = nn.Sequential(*layers)

    def forward(self, inputs):
        """Run ``inputs`` through the layer stack and return the result."""
        return self.kernel(inputs)
class BLSTM(nn.Module):
    """Stacked LSTM with an optional linear projection when bidirectional.

    Uses ``nn.LSTM``'s default sequence-first layout, i.e. input of shape
    ``(T, B, dim)``.

    Args:
        dim: LSTM input size and hidden size.
        layers: number of stacked LSTM layers.
        bi: if True, run a bidirectional LSTM and project its ``2 * dim``
            output features down to ``out_dim``.
        out_dim: width of the projection output; only used when ``bi`` is
            True (default 480).
    """

    def __init__(self, dim, layers=2, bi=True, out_dim=480):
        super().__init__()
        self.lstm = nn.LSTM(bidirectional=bi, num_layers=layers,
                            hidden_size=dim, input_size=dim)
        # Only the bidirectional output (2 * dim features) is projected; the
        # unidirectional output keeps its native width ``dim``.
        self.linear = nn.Linear(2 * dim, out_dim) if bi else None

    def forward(self, x, hidden=None):
        """Run the LSTM, then the projection if one exists.

        Args:
            x: sequence-first input ``(T, B, dim)``.
            hidden: optional initial ``(h_0, c_0)`` passed to ``nn.LSTM``.

        Returns:
            ``(output, (h_n, c_n))`` where output is ``(T, B, out_dim)`` when
            bidirectional and ``(T, B, dim)`` otherwise.
        """
        x, hidden = self.lstm(x, hidden)
        # Explicit None check: truthiness of an nn.Module is not a reliable
        # "is configured" test (containers define __len__).
        if self.linear is not None:
            x = self.linear(x)
        return x, hidden
class snr_net(nn.Module):
    """Experimental SNR network.

    Only ``self.lstm`` is used by ``forward``. ``self.encoder``, ``self.rnn``
    and ``self.out`` are constructed (and appear in ``state_dict``) but are
    not wired into the forward pass.

    Args:
        feature_dim: per-time-step feature width consumed by ``self.lstm``
            and ``self.rnn`` (default 16384).
    """

    def __init__(self, feature_dim=16384):
        super(snr_net, self).__init__()
        self.hidden_size = 480
        self.encoder = torch.nn.Sequential(
            input_kernel(1, 4, kernel_size=8, stride=2, dilation=1, padding=0, bias=False),
            input_kernel(4, 8, kernel_size=8, stride=2, dilation=1, padding=0, bias=False),
            input_kernel(8, 16, kernel_size=8, stride=2, dilation=1, padding=0, bias=False),
        )
        # NOTE(review): with the default feature_dim=16384 this is a 2-layer
        # bidirectional LSTM whose hidden size is also 16384. That is roughly
        # 1e10 parameters (~43 GB in float32). Pass a smaller feature_dim to
        # instantiate the model in practice.
        self.lstm = BLSTM(feature_dim, bi=True)
        self.rnn = nn.LSTM(
            input_size=feature_dim,
            hidden_size=self.hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # NOTE(review): self.rnn is bidirectional, so its output has
        # 2 * hidden_size = 960 features, which would not match this
        # Linear(480, ...) if the two were chained.
        self.out = nn.Linear(480, 50)
        # An earlier, commented-out forward path reshaped a (B, 1, L) signal
        # to (B, L/16384, 16384), ran self.rnn and then self.out.

    def forward(self, input):
        """Run the BLSTM over a channels-first sequence.

        Args:
            input: tensor of shape ``(B, feature_dim, T)``.

        Returns:
            Tensor of shape ``(B, 480, T)``.
        """
        # (B, C, T) -> (T, B, C): the BLSTM's nn.LSTM is sequence-first.
        x = input.permute(2, 0, 1)
        x, _ = self.lstm(x)
        # (T, B, 480) -> (B, 480, T)
        return x.permute(1, 2, 0)
if __name__ == "__main__":
    # Smoke test. snr_net.forward permutes (B, C, T) into a sequence-first
    # LSTM whose input size is 16384, so the input must be (B, 16384, T);
    # a (B, 1, 16384) tensor fails the LSTM input-size check.
    # NOTE(review): snr_net() with default sizes builds a BLSTM(16384) with
    # roughly 1e10 parameters; expect this to need tens of GB of memory.
    m = snr_net()
    x = torch.rand(5, 16384, 1)
    print(m(x).size())