# Untitled — exported from pastecode.io (mail@pastecode.io).
# Paste metadata as scraped: unknown | plain_text | 2 years ago | 6.8 kB | 3 | Indexable | Never

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from layers import *
from torchdiffeq import odeint

class Covid19Model(nn.Module):
    """CNN + two stacked GRUs that predict parameters for ``SEIRD_LVEC``.

    Six selected input features go through ``CNNBlock`` and two GRUs. A
    dot-product score against the second GRU's final hidden state weights
    that GRU's outputs. The pooled sequence and both final hidden states are
    concatenated and fed through an MLP. The MLP output goes through a
    multi-sample-dropout head (``fc1``) and is squashed into ``[0, 1)`` with
    ``abs(tanh(.))``. Finally it is passed with the raw inputs to
    ``SEIRD_LVEC``.

    Args:
        initN: Unused by this model. Kept so the constructor signature
            matches the other model variants in this file.
        n_components: Number of leading input features considered. The
            forward pass keeps 6 of them (indices 1, 2, 3, 4, 6, 7), while the
            CNN/GRU are sized for ``n_components - 2`` channels, so this
            presumably has to be 8 — TODO confirm.
        days: Part of the MLP input size (``256 + days + 128``).
        sird_parameters: Number of parameters produced per sample.
        n_dropout: Number of dropout samples averaged in the output head.
            ``0`` applies ``fc1`` once, without dropout.
    """

    def __init__(self, initN, n_components, days, sird_parameters, n_dropout):
        super().__init__()
        self.n_components = n_components
        # Module construction order is significant: it fixes parameter order
        # and the RNG draws used for initialisation.
        self.cnn_block = CNNBlock(n_components - 2, n_components - 2, 16, norm=True)
        self.lstm_1 = nn.GRU(n_components - 2, 128, 1, batch_first=True)
        self.lstm_2 = nn.GRU(128, 256, 1, batch_first=True)
        self.dropout = nn.Dropout()
        self.fc = nn.Sequential(
            nn.Linear(256 + days + 128, 512),
            nn.LeakyReLU(inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, 128),
            nn.LeakyReLU(inplace=True),
        )
        self.dropouts = nn.ModuleList([nn.Dropout() for _ in range(n_dropout)])
        self.fc1 = nn.Linear(128, sird_parameters)
        self.tanh = nn.Tanh()
        self.sird = SEIRD_LVEC()
        self.days = days

    def forward(self, inputs):
        """Predict SEIRD parameters and evaluate ``self.sird`` with them.

        Args:
            inputs: Tensor laid out as ``(batch, time, features)`` (the GRUs
                use ``batch_first=True``). The time length presumably has to
                equal ``days`` for the MLP input size to match — TODO confirm.

        Returns:
            ``(y_pred, params)``. ``params`` has shape
            ``(batch, sird_parameters, 1)`` with values in ``[0, 1)``.
            ``y_pred`` is whatever ``self.sird(inputs, params)`` returns.
        """
        inputs1 = inputs[..., :self.n_components]
        inputs1 = inputs1[..., [1, 2, 3, 4, 6, 7]]
        # CNNBlock presumably expects (batch, channels, time) — TODO confirm.
        inputs1 = self.cnn_block(inputs1.permute(0, 2, 1)).permute(0, 2, 1)

        out_lstm_1, state_h1 = self.lstm_1(inputs1)
        out_lstm_2, state_h2 = self.lstm_2(out_lstm_1)
        # (1, N, D) -> (N, D, 1)
        state_h = state_h2.permute(1, 2, 0)
        # (N, L, D) x (N, D, 1) -> (N, L, 1), normalised over time steps.
        score = F.softmax(torch.matmul(out_lstm_2, state_h), dim=1)
        # NOTE(review): this sums over the feature axis, yielding (N, L)
        # instead of the usual (N, D) context vector. The MLP input size
        # (256 + days + 128) is built around this, so changing the axis also
        # requires resizing self.fc.
        out_lstm = torch.sum(out_lstm_2 * score, dim=-1)

        out_concat = torch.cat(
            (out_lstm, state_h1.squeeze(dim=0), state_h.squeeze(dim=-1)), dim=1
        )
        out_concat = self.dropout(out_concat)
        out_fc = self.fc(out_concat)
        params = self._multi_sample_head(out_fc).unsqueeze(-1)
        params = torch.abs(self.tanh(params))
        y_pred = self.sird(inputs, params)
        return y_pred, params

    def _multi_sample_head(self, out_fc):
        """Average ``fc1`` over every dropout sample; plain ``fc1`` if none."""
        if len(self.dropouts) == 0:
            # Without this guard an empty ModuleList left the result unbound
            # and divided by zero.
            return self.fc1(out_fc)
        heads = [self.fc1(dropout(out_fc)) for dropout in self.dropouts]
        # sum() accumulates left to right, matching a running-total loop.
        return sum(heads) / len(heads)


class Covid19ModelV2(nn.Module):
    """Variant that encodes day-over-day differences of the inputs.

    The first differences (along time) of the first ``n_components`` features
    go through ``CNNBlock`` and two ``AttnLSTM`` layers. The second layer's
    output summed over features is concatenated with each layer's state and
    fed through an MLP. The MLP output goes through a multi-sample-dropout
    head of ``tanh(fc1(.))``. Post-processing then happens in two steps:
    parameters 0-3 and 5-7 are made non-negative, and parameter 4 is shifted
    by 1/180. The result is passed with the raw inputs to ``SEIRDTest``.

    Args:
        initN: Forwarded to ``SEIRDTest``.
        n_components: Number of leading input features used.
        days: Input sequence length; the encoders are built for ``days - 1``
            (differenced) steps and the MLP input size depends on it.
        sird_parameters: Number of parameters per sample; must be at least 8
            because indices up to 7 are post-processed.
        n_dropout: Number of dropout samples averaged in the output head.
            ``0`` applies ``tanh(fc1(.))`` once, without dropout.
    """

    def __init__(self, initN, n_components, days, sird_parameters, n_dropout):
        super().__init__()
        self.n_components = n_components
        # Module construction order is significant: it fixes parameter order
        # and the RNG draws used for initialisation.
        self.cnn_block = CNNBlock(n_components, 16)
        self.lstm_1 = AttnLSTM(n_components, 128, days - 1)
        self.lstm_2 = AttnLSTM(128, 256, days - 1)
        self.dropout = nn.Dropout()
        self.fc = nn.Sequential(
            # 255 + 128 + days == (days - 1) + 128 + 256: the summed sequence
            # plus both layer states.
            nn.Linear(255 + 128 + days, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
        )
        self.dropouts = nn.ModuleList([nn.Dropout() for _ in range(n_dropout)])
        self.fc1 = nn.Linear(128, sird_parameters)
        self.tanh = nn.Tanh()
        self.sird = SEIRDTest(initN, days)

    def forward(self, inputs):
        """Predict SEIRD parameters and evaluate ``self.sird`` with them.

        Args:
            inputs: Tensor indexed as ``(batch, time, features)``; the
                differencing below runs along dim 1. The time length
                presumably has to equal ``days`` — TODO confirm.

        Returns:
            ``(y_pred, params)``. ``params`` has shape
            ``(batch, sird_parameters, 1)``. ``y_pred`` is whatever
            ``self.sird(inputs, params)`` returns.
        """
        # First difference along time: one fewer step than the input.
        inputs1 = (inputs[:, 1:] - inputs[:, :-1])[..., :self.n_components]
        inputs1 = self.cnn_block(inputs1)
        out_lstm_1, (state_h1, _) = self.lstm_1(inputs1)
        out_lstm_2, (state_h2, _) = self.lstm_2(out_lstm_1)
        # Sum over the feature axis -> one value per time step.
        out_lstm = torch.sum(out_lstm_2, dim=-1)
        out_concat = torch.cat(
            (out_lstm, state_h1.squeeze(dim=0), state_h2.squeeze(dim=0)), dim=1
        )

        out_concat = self.dropout(out_concat)
        out_fc = self.fc(out_concat)
        # Work on a copy: without dropout heads the tensor is the tanh output
        # itself, which autograd saves for backward, so editing it in place
        # would make backward() fail.
        params = self._multi_sample_head(out_fc).unsqueeze(-1).clone()

        abs_idx = [0, 1, 2, 3, 5, 6, 7]
        params[:, abs_idx] = torch.abs(params[:, abs_idx])
        # Parameter 4 keeps its sign and is shifted by 1/180 instead.
        params[:, 4] = 1 / 180 + params[:, 4]
        y_pred = self.sird(inputs, params)
        return y_pred, params

    def _multi_sample_head(self, out_fc):
        """Average ``tanh(fc1(.))`` over dropout samples; single pass if none."""
        if len(self.dropouts) == 0:
            # Without this guard an empty ModuleList left the result unbound
            # and divided by zero.
            return self.tanh(self.fc1(out_fc))
        heads = [self.tanh(self.fc1(dropout(out_fc))) for dropout in self.dropouts]
        # sum() accumulates left to right, matching a running-total loop.
        return sum(heads) / len(heads)

class Covid19ModelV3(nn.Module):
    """Purely convolutional variant feeding an MLP head into ``SEIRD_LVEC``.

    The first ``n_components`` features are treated as channels and pass
    through three convolutional stages (with CBAM blocks after the first two).
    The result is flattened, mapped by a three-layer MLP to
    ``sird_parameters`` values, and squashed into ``[0, 1)`` with
    ``abs(tanh(.))``.

    ``initN``, ``days`` and ``n_dropout`` are accepted for signature
    compatibility with the other variants but are not used here.
    """

    def __init__(self, initN, n_components, days, sird_parameters, n_dropout):
        super().__init__()
        self.n_components = n_components
        # The order below fixes parameter order and initialisation RNG draws.
        self.block1 = nn.Sequential(
            nn.Conv1d(n_components, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            CNNBlock(32, 32, 32, kernel_size=3, padding=1, dilation=1,
                     residual=True, norm=True),
            CBAMBlock(32, 8),
        )
        self.block2 = nn.Sequential(
            CNNBlock(32, 64, 32, residual=True, norm=True),
            CBAMBlock(64, 8),
        )
        self.block3 = nn.Sequential(
            CNNBlock(64, 128, 32, dilation=4, residual=True, norm=True),
        )

        self.flatten = nn.Flatten()

        # NOTE(review): 512 is hard-coded; it presumably equals
        # (channels out of block3) * (remaining time steps) — TODO confirm.
        head_layers = [
            nn.Linear(512, 256), nn.ReLU(inplace=True),
            nn.Linear(256, 128), nn.ReLU(inplace=True),
            nn.Linear(128, sird_parameters),
        ]
        self.fc = nn.Sequential(*head_layers)
        self.sird = SEIRD_LVEC()

    def forward(self, inputs):
        """Return ``(self.sird(inputs, params), params)``.

        ``params`` has shape ``(batch, sird_parameters, 1)`` with values in
        ``[0, 1)``. ``inputs`` is indexed as ``(batch, time, features)`` and
        is permuted to put features in the channel position.
        """
        features = inputs[..., :self.n_components].permute(0, 2, 1)
        for stage in (self.block1, self.block2, self.block3):
            features = stage(features)
        raw = self.fc(self.flatten(features)).unsqueeze(-1)
        params = torch.abs(torch.tanh(raw))
        return self.sird(inputs, params), params