Untitled
unknown
plain_text
4 years ago
6.8 kB
5
Indexable
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from layers import *
from torchdiffeq import odeint


def _multi_sample_dropout_head(head, features, dropouts, activation=None):
    """Average ``head`` over several independent dropout masks.

    Each module in ``dropouts`` is applied to ``features`` in order, the result
    is passed through ``head`` (and then ``activation`` when given), and the
    outputs are summed and divided by the number of masks.

    When ``dropouts`` is empty the head is applied once without dropout. The
    previous inline loops left their accumulator unbound (and divided by zero)
    in that case.

    Args:
        head: Module/callable applied to each dropped-out copy of ``features``.
        features: Tensor fed to the head.
        dropouts: Sequence (e.g. ``nn.ModuleList``) of dropout modules.
        activation: Optional callable applied to each head output before
            averaging.

    Returns:
        The averaged head output.
    """

    def apply_head(x):
        out = head(x)
        return out if activation is None else activation(out)

    if not dropouts:
        return apply_head(features)

    total = None
    for dropout in dropouts:
        out = apply_head(dropout(features))
        total = out if total is None else total + out
    return total / len(dropouts)


class Covid19Model(nn.Module):
    """CNN + stacked-GRU encoder that predicts parameters for ``SEIRD_LVEC``.

    The encoder output is mapped to ``sird_parameters`` values squashed into
    [0, 1) with ``abs(tanh(.))``. Those values are passed, together with the
    raw ``inputs``, to ``SEIRD_LVEC`` to produce the prediction.

    Args:
        initN: Not used by this model. All three models in this module accept
            the same constructor arguments, presumably so they are
            interchangeable.
        n_components: Number of leading input features considered. The encoder
            is sized for ``n_components - 2`` channels, while ``forward``
            selects six fixed columns. The two only agree when
            ``n_components == 8``.
        days: Sequence length the fully connected head is sized for (see
            ``forward``). Also stored as ``self.days``.
        sird_parameters: Number of parameters emitted for the SEIRD layer.
        n_dropout: Number of dropout masks averaged in the multi-sample dropout
            head. 0 means the head is applied once without dropout.
    """

    def __init__(self, initN, n_components, days, sird_parameters, n_dropout):
        super().__init__()
        self.n_components = n_components
        self.cnn_block = CNNBlock(n_components - 2, n_components - 2, 16, norm=True)
        self.lstm_1 = nn.GRU(n_components - 2, 128, 1, batch_first=True)
        self.lstm_2 = nn.GRU(128, 256, 1, batch_first=True)
        self.dropout = nn.Dropout()
        # Input width: pooled sequence (one value per step, `days` steps)
        # + final hidden state of each GRU (128 + 256).
        self.fc = nn.Sequential(
            nn.Linear(256 + days + 128, 512),
            nn.LeakyReLU(inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, 128),
            nn.LeakyReLU(inplace=True),
        )
        self.dropouts = nn.ModuleList([nn.Dropout() for _ in range(n_dropout)])
        self.fc1 = nn.Linear(128, sird_parameters)
        self.tanh = nn.Tanh()
        self.sird = SEIRD_LVEC()
        self.days = days

    def forward(self, inputs):
        """Encode ``inputs`` and run the SEIRD layer on the predicted parameters.

        Args:
            inputs: Rank-3 tensor with features on the last axis. The
                ``permute(0, 2, 1)`` around ``cnn_block`` and the
                ``batch_first`` GRUs imply an ``(N, L, F)`` layout. The fc head
                requires the sequence length coming out of ``cnn_block`` to
                equal ``days`` (TODO confirm ``CNNBlock`` preserves length).

        Returns:
            ``(y_pred, params)``. ``params`` has shape
            ``(N, sird_parameters, 1)`` with values in [0, 1). ``y_pred`` is
            whatever ``self.sird(inputs, params)`` returns.
        """
        # Six hard-coded feature columns taken from the first n_components.
        features = inputs[..., :self.n_components][..., [1, 2, 3, 4, 6, 7]]
        # cnn_block is applied channel-first, then the layout is restored
        # for the batch_first GRUs.
        features = self.cnn_block(features.permute(0, 2, 1)).permute(0, 2, 1)

        out_lstm_1, state_h1 = self.lstm_1(features)
        out_lstm_2, state_h2 = self.lstm_2(out_lstm_1)

        # Score every step against the last hidden state:
        # (1, N, 256) -> (N, 256, 1); (N, L, 256) @ (N, 256, 1) -> (N, L, 1),
        # softmax over the step axis.
        query = state_h2.permute(1, 2, 0)
        score = F.softmax(torch.matmul(out_lstm_2, query), dim=1)
        # NOTE(review): the weighted outputs are summed over the *feature*
        # axis, giving (N, L) rather than the usual (N, 256) context vector.
        # The fc input width (256 + days + 128) is sized for this, so it is
        # kept as-is.
        out_lstm = torch.sum(out_lstm_2 * score, dim=-1)

        out_concat = torch.cat(
            (out_lstm, state_h1.squeeze(dim=0), query.squeeze(dim=-1)), dim=1
        )
        out_fc = self.fc(self.dropout(out_concat))

        out_fc1 = _multi_sample_dropout_head(self.fc1, out_fc, self.dropouts)
        # abs(tanh(.)) keeps every parameter in [0, 1).
        params = torch.abs(self.tanh(out_fc1.unsqueeze(-1)))
        y_pred = self.sird(inputs, params)
        return y_pred, params


class Covid19ModelV2(nn.Module):
    """Variant using first-differenced inputs, ``AttnLSTM`` encoders and ``SEIRDTest``.

    Args:
        initN: Forwarded to ``SEIRDTest`` together with ``days``.
        n_components: Number of leading input features fed to the encoder.
        days: Input sequence length. Both ``AttnLSTM`` layers are built with
            ``days - 1``, presumably the length after the first difference
            taken in ``forward``.
        sird_parameters: Number of parameters emitted for the SEIRD layer.
            ``forward`` indexes parameters 0-7, so at least 8 are required.
        n_dropout: Number of dropout masks averaged in the multi-sample dropout
            head. 0 means the head is applied once without dropout.
    """

    def __init__(self, initN, n_components, days, sird_parameters, n_dropout):
        super().__init__()
        self.n_components = n_components
        # NOTE(review): CNNBlock is called here with only two positional
        # arguments, unlike the other models. forward also does not permute to
        # channel-first before calling it. Confirm against layers.CNNBlock.
        self.cnn_block = CNNBlock(n_components, 16)
        self.lstm_1 = AttnLSTM(n_components, 128, days - 1)
        self.lstm_2 = AttnLSTM(128, 256, days - 1)
        self.dropout = nn.Dropout()
        # Input width: summed encoder output (days - 1 values) + final hidden
        # states of both encoders (128 + 256).
        # 255 + 128 + days == (days - 1) + 128 + 256.
        self.fc = nn.Sequential(
            nn.Linear(255 + 128 + days, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
        )
        self.dropouts = nn.ModuleList([nn.Dropout() for _ in range(n_dropout)])
        self.fc1 = nn.Linear(128, sird_parameters)
        self.tanh = nn.Tanh()
        self.sird = SEIRDTest(initN, days)

    def forward(self, inputs):
        """Encode first differences of ``inputs`` and run ``SEIRDTest``.

        Args:
            inputs: Tensor differenced along dim 1. Its last axis holds at least
                ``n_components`` features.

        Returns:
            ``(y_pred, params)``. ``params`` has shape
            ``(N, sird_parameters, 1)``. ``y_pred`` is whatever
            ``self.sird(inputs, params)`` returns.
        """
        # First difference along dim 1, keeping the first n_components features.
        deltas = (inputs[:, 1:] - inputs[:, :-1])[..., :self.n_components]
        deltas = self.cnn_block(deltas)

        out_lstm_1, (state_h1, _) = self.lstm_1(deltas)
        out_lstm_2, (state_h2, _) = self.lstm_2(out_lstm_1)

        # No attention weighting here (unlike Covid19Model): a plain sum over
        # the last axis.
        out_lstm = torch.sum(out_lstm_2, dim=-1)
        out_concat = torch.cat(
            (out_lstm, state_h1.squeeze(dim=0), state_h2.squeeze(dim=0)), dim=1
        )
        out_fc = self.fc(self.dropout(out_concat))

        # Each head output is tanh-squashed *before* averaging (Covid19Model
        # applies tanh after averaging).
        out_fc1 = _multi_sample_dropout_head(
            self.fc1, out_fc, self.dropouts, activation=self.tanh
        )
        # clone() so the in-place writes below never touch a tensor autograd
        # saved for backward (with no dropouts, out_fc1 is a raw tanh output).
        params = out_fc1.unsqueeze(-1).clone()
        # Parameters 0-3 and 5-7 are made non-negative. Parameter 4 keeps its
        # sign but is shifted by 1/180; the offset's meaning is not documented
        # here (confirm).
        params[:, [0, 1, 2, 3, 5, 6, 7]] = torch.abs(params[:, [0, 1, 2, 3, 5, 6, 7]])
        params[:, 4] = 1 / 180 + params[:, 4]
        y_pred = self.sird(inputs, params)
        return y_pred, params


class Covid19ModelV3(nn.Module):
    """Purely convolutional variant: three CNN/CBAM stages feeding an MLP head.

    Args:
        initN: Not used by this model (same constructor arguments as the other
            models).
        n_components: Number of leading input features, used as input channels.
        days: Not used by this model.
        sird_parameters: Number of parameters emitted for ``SEIRD_LVEC``.
        n_dropout: Not used by this model.
    """

    def __init__(self, initN, n_components, days, sird_parameters, n_dropout):
        super().__init__()
        self.n_components = n_components
        self.block1 = nn.Sequential(
            nn.Conv1d(n_components, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            CNNBlock(32, 32, 32, kernel_size=3, padding=1, dilation=1,
                     residual=True, norm=True),
            CBAMBlock(32, 8),
        )
        self.block2 = nn.Sequential(
            CNNBlock(32, 64, 32, residual=True, norm=True),
            CBAMBlock(64, 8),
        )
        self.block3 = nn.Sequential(
            CNNBlock(64, 128, 32, dilation=4, residual=True, norm=True),
        )
        self.flatten = nn.Flatten()
        # NOTE(review): Linear(512, ...) requires block3's output to flatten to
        # exactly 512 values per sample. That depends on the input length and
        # on CNNBlock's output shape; confirm.
        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, sird_parameters),
        )
        self.sird = SEIRD_LVEC()

    def forward(self, inputs):
        """Run the convolutional encoder and ``SEIRD_LVEC``.

        Args:
            inputs: Rank-3 tensor with features on the last axis. It is permuted
                ``(0, 2, 1)`` so features become Conv1d channels.

        Returns:
            ``(y_pred, params)``. ``params`` has shape
            ``(N, sird_parameters, 1)`` with values in [0, 1).
        """
        features = inputs[..., :self.n_components].permute(0, 2, 1)
        encoded = self.block3(self.block2(self.block1(features)))
        params = self.fc(self.flatten(encoded)).unsqueeze(-1)
        # abs(tanh(.)) keeps every parameter in [0, 1).
        params = torch.abs(torch.tanh(params))
        y_pred = self.sird(inputs, params)
        return y_pred, params
Editor is loading...