Untitled
unknown
plain_text
4 years ago
6.8 kB
10
Indexable
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from layers import *
from torchdiffeq import odeint
class Covid19Model(nn.Module):
    """CNN + two stacked GRUs with attention, predicting parameters for ``SEIRD_LVEC``.

    Args:
        initN: Not used by this class; kept so the model variants share one
            constructor signature.
        n_components: ``forward`` keeps ``inputs[..., :n_components]`` and then
            selects channels ``[1, 2, 3, 4, 6, 7]`` from it, so this must be at
            least 8. The CNN is built with ``n_components - 2`` as its first
            argument (presumably its input-channel count), which matches the six
            selected channels only when ``n_components == 8``.
        days: Width of the per-time-step attention summary fed into ``self.fc``;
            the concatenation in ``forward`` only fits if the sequence length
            reaching the GRUs equals ``days`` (TODO confirm against the loader).
        sird_parameters: Number of parameters produced by ``self.fc1``.
        n_dropout: Number of dropout samples averaged in the output head
            (multi-sample dropout). Must be >= 1.

    Raises:
        ValueError: If ``n_dropout < 1``; the head would otherwise have nothing
            to average and ``forward`` would fail with an unbound local.
    """

    def __init__(self, initN, n_components, days, sird_parameters, n_dropout):
        super().__init__()
        if n_dropout < 1:
            raise ValueError(f"n_dropout must be >= 1, got {n_dropout}")
        self.n_components = n_components
        self.cnn_block = CNNBlock(n_components - 2, n_components - 2, 16, norm=True)
        # NOTE: despite the names, both recurrent layers are GRUs; the attribute
        # names are kept because they are the state_dict keys.
        self.lstm_1 = nn.GRU(n_components - 2, 128, 1, batch_first=True)
        self.lstm_2 = nn.GRU(128, 256, 1, batch_first=True)
        self.dropout = nn.Dropout()
        # Input width = 256 (lstm_2 final state) + days (per-step attention sums)
        # + 128 (lstm_1 final state); see the concatenation in forward().
        self.fc = nn.Sequential(
            nn.Linear(256 + days + 128, 512),
            nn.LeakyReLU(inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, 128),
            nn.LeakyReLU(inplace=True),
        )
        self.dropouts = nn.ModuleList([nn.Dropout() for _ in range(n_dropout)])
        self.fc1 = nn.Linear(128, sird_parameters)
        self.tanh = nn.Tanh()
        self.sird = SEIRD_LVEC()
        self.days = days

    def forward(self, inputs):
        """Predict SEIRD parameters from ``inputs`` and run them through ``self.sird``.

        Args:
            inputs: Tensor indexed as ``inputs[..., feature]``; presumably shaped
                (batch, time, features) — TODO confirm. The full tensor is also
                passed unchanged to ``self.sird``.

        Returns:
            ``(y_pred, params)``: ``params`` is the dropout-averaged head output
            mapped through ``|tanh(.)|`` (so each entry lies in [0, 1)) with a
            trailing singleton dimension; ``y_pred`` is whatever
            ``self.sird(inputs, params)`` returns.
        """
        # Keep the first n_components features, then drop channels 0 and 5.
        features = inputs[..., :self.n_components]
        features = features[..., [1, 2, 3, 4, 6, 7]]
        # The CNN block is applied with the last two axes swapped (presumably
        # channels-first, as Conv1d expects) and the result is swapped back.
        features = self.cnn_block(features.permute(0, 2, 1)).permute(0, 2, 1)

        # GRU with batch_first=True: outputs (N, L, H), final state (1, N, H).
        out_lstm_1, state_h1 = self.lstm_1(features)
        out_lstm_2, state_h2 = self.lstm_2(out_lstm_1)

        # Dot-product attention with lstm_2's final state as the query:
        # (1, N, 256) -> (N, 256, 1); scores (N, L, 1), softmax over time (dim 1).
        state_h = state_h2.permute(1, 2, 0)
        score = F.softmax(torch.matmul(out_lstm_2, state_h), dim=1)
        # NOTE: the weighted outputs are summed over the feature axis, giving one
        # value per time step (N, L); self.fc's input width relies on L == days.
        out_lstm = torch.sum(out_lstm_2 * score, dim=-1)

        out_concat = torch.cat(
            (out_lstm, state_h1.squeeze(dim=0), state_h.squeeze(dim=-1)), dim=1
        )
        out_fc = self.fc(self.dropout(out_concat))

        # Multi-sample dropout: average fc1 over independently dropped-out copies
        # (evaluated in the same order as the dropout modules).
        out_fc1 = sum(self.fc1(dropout(out_fc)) for dropout in self.dropouts)
        out_fc1 = out_fc1 / len(self.dropouts)

        params = torch.abs(self.tanh(out_fc1.unsqueeze(-1)))
        y_pred = self.sird(inputs, params)
        return y_pred, params
class Covid19ModelV2(nn.Module):
    """Variant working on first differences of the input, using ``AttnLSTM`` and ``SEIRDTest``.

    Args:
        initN: Forwarded to ``SEIRDTest``.
        n_components: Number of leading features of the differenced input that
            are used; also the input size of ``lstm_1``.
        days: Forwarded to ``SEIRDTest``; ``days - 1`` is passed to both
            ``AttnLSTM`` layers and the ``fc`` input width assumes ``days - 1``
            per-step values (differencing shortens axis 1 by one).
        sird_parameters: Number of parameters produced by ``self.fc1``; the
            constraint mapping in ``forward`` indexes up to parameter 7, so at
            least 8 are required.
        n_dropout: Number of dropout samples averaged in the output head. Must be >= 1.

    Raises:
        ValueError: If ``n_dropout < 1``; the head would otherwise have nothing
            to average and ``forward`` would fail with an unbound local.
    """

    def __init__(self, initN, n_components, days, sird_parameters, n_dropout):
        super().__init__()
        if n_dropout < 1:
            raise ValueError(f"n_dropout must be >= 1, got {n_dropout}")
        self.n_components = n_components
        # NOTE(review): this passes two positional arguments, while the other
        # models in this file call CNNBlock(in, out, k, ...); forward() also does
        # not permute axes before this block. Confirm against layers.CNNBlock.
        self.cnn_block = CNNBlock(n_components, 16)
        self.lstm_1 = AttnLSTM(n_components, 128, days - 1)
        self.lstm_2 = AttnLSTM(128, 256, days - 1)
        self.dropout = nn.Dropout()
        # 255 + 128 + days == (days - 1) per-step sums + 128 + 256 final states,
        # assuming AttnLSTM's hidden sizes are its second argument (TODO confirm).
        self.fc = nn.Sequential(
            nn.Linear(255 + 128 + days, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
        )
        self.dropouts = nn.ModuleList([nn.Dropout() for _ in range(n_dropout)])
        self.fc1 = nn.Linear(128, sird_parameters)
        self.tanh = nn.Tanh()
        self.sird = SEIRDTest(initN, days)

    def forward(self, inputs):
        """Predict constrained SEIRD parameters and run them through ``self.sird``.

        Args:
            inputs: Tensor differenced along axis 1 (presumably time) before
                use; the original, undifferenced tensor is passed to ``self.sird``.

        Returns:
            ``(y_pred, params)`` where ``params`` has a trailing singleton
            dimension and, per parameter index: 0-3 and 5-7 are ``|x|``, 4 is
            ``x + 1/180``, and any index >= 8 is ``x`` (``x`` being the averaged
            per-sample tanh output).
        """
        # First difference along axis 1, restricted to the first n_components features.
        inputs1 = (inputs[:, 1:] - inputs[:, :-1])[..., :self.n_components]
        inputs1 = self.cnn_block(inputs1)
        out_lstm_1, (state_h1, _) = self.lstm_1(inputs1)
        out_lstm_2, (state_h2, _) = self.lstm_2(out_lstm_1)
        # Sum lstm_2 outputs over the last axis -> one value per step.
        out_lstm = torch.sum(out_lstm_2, dim=-1)
        out_concat = torch.cat(
            (out_lstm, state_h1.squeeze(dim=0), state_h2.squeeze(dim=0)), dim=1
        )
        out_fc = self.fc(self.dropout(out_concat))

        # Multi-sample dropout; tanh is applied to each sample before averaging.
        out_fc1 = sum(self.tanh(self.fc1(dropout(out_fc))) for dropout in self.dropouts)
        out_fc1 = out_fc1 / len(self.dropouts)
        params = out_fc1.unsqueeze(-1)

        # Apply the per-parameter constraints out of place rather than via
        # in-place indexed assignment into a tensor that is part of the autograd
        # graph. Indexing up to 7 raises IndexError if fewer than 8 params exist.
        n_params = params.shape[1]
        positive = torch.zeros(n_params, 1, dtype=torch.bool, device=params.device)
        positive[[0, 1, 2, 3, 5, 6, 7]] = True
        # Parameter 4 keeps its sign but is shifted by 1/180; the reason for this
        # constant is not visible in this file.
        offset = torch.zeros(n_params, 1, dtype=params.dtype, device=params.device)
        offset[4] = 1 / 180
        params = torch.where(positive, torch.abs(params), params) + offset

        y_pred = self.sird(inputs, params)
        return y_pred, params
class Covid19ModelV3(nn.Module):
    """Fully convolutional variant: three CNN stages, a flatten, and an MLP head.

    The head output is mapped through ``|tanh(.)|`` and handed to ``SEIRD_LVEC``.

    Args:
        initN: Not used by this class; kept for a shared constructor signature.
        n_components: Number of leading input features, used as the input
            channel count of the first Conv1d.
        days: Not used by this class.
        sird_parameters: Output width of the MLP head.
        n_dropout: Not used by this class.

    NOTE(review): the head's first ``nn.Linear(512, ...)`` assumes the flattened
    ``block3`` output has exactly 512 features; that depends on the input
    sequence length and on CNNBlock's internals — confirm for the ``days`` used.
    """

    def __init__(self, initN, n_components, days, sird_parameters, n_dropout):
        super().__init__()
        self.n_components = n_components

        # Modules are created in the same order as before so parameter
        # initialisation and state_dict keys (block1.0, fc.4, ...) are unchanged.
        stage_one_layers = [
            nn.Conv1d(n_components, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            CNNBlock(32, 32, 32, kernel_size=3, padding=1, dilation=1,
                     residual=True, norm=True),
            CBAMBlock(32, 8),
        ]
        self.block1 = nn.Sequential(*stage_one_layers)

        stage_two_layers = [
            CNNBlock(32, 64, 32, residual=True, norm=True),
            CBAMBlock(64, 8),
        ]
        self.block2 = nn.Sequential(*stage_two_layers)

        self.block3 = nn.Sequential(
            CNNBlock(64, 128, 32, dilation=4, residual=True, norm=True)
        )
        self.flatten = nn.Flatten()

        # Head: 512 -> 256 -> 128 with in-place ReLUs, then 128 -> sird_parameters.
        widths = (512, 256, 128)
        head_layers = []
        for width_in, width_out in zip(widths, widths[1:]):
            head_layers.extend([nn.Linear(width_in, width_out), nn.ReLU(inplace=True)])
        head_layers.append(nn.Linear(widths[-1], sird_parameters))
        self.fc = nn.Sequential(*head_layers)

        self.sird = SEIRD_LVEC()

    def forward(self, inputs):
        """Return ``(self.sird(inputs, params), params)``.

        ``params`` is the head output with a trailing singleton dimension,
        mapped through ``|tanh(.)|`` so each entry lies in [0, 1).
        ``inputs`` must be 3-D (``permute(0, 2, 1)`` is applied); presumably
        (batch, time, features) — TODO confirm.
        """
        # Leading n_components features, last two axes swapped for Conv1d.
        hidden = inputs[..., :self.n_components].permute(0, 2, 1)
        for stage in (self.block1, self.block2, self.block3):
            hidden = stage(hidden)
        logits = self.fc(self.flatten(hidden)).unsqueeze(-1)
        params = torch.tanh(logits).abs()
        return self.sird(inputs, params), params
Editor is loading...