import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class SIRD(nn.Module):
    def __init__(self, days):
        super().__init__()
        self.days = days
    def forward(self, inputs, params):
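        # Inferred layout (not asserted in the original code): inputs appears to be
        # (batch, days, 7) holding [S, E, I, R, D, beta, N], and params keeps a
        # trailing singleton dim so params[:, k] tiles/broadcasts across the days axis.
        # Despite the class name, the update below is SEIRD-style (it carries an E compartment).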
        if self.training:
            std = torch.cat([params[:, 6]] * self.days, dim = 1)
            # draw the noise on the same device as the inputs (torch.zeros defaults to CPU)
            rand = torch.normal(torch.zeros(inputs.shape[0], self.days, device = inputs.device), std)
            beta = params[:, 6] + rand * params[:, 5]
        else:
            beta = torch.cat([params[:, 6]] * self.days, dim = 1)
        S = inputs[..., 0] - beta * inputs[..., 0] * inputs[..., 1] + params[:, 3] * inputs[..., 3] - params[:, 4] * inputs[..., 0]
        E = inputs[..., 1] +  beta * inputs[..., 0] * inputs[..., 1] - params[:, 2] * inputs[..., 1]
        I = inputs[..., 2] + params[:, 4] * inputs[..., 0] - params[:, 1] * inputs[..., 2] + params[:, 2] * inputs[..., 1] - params[:, 0] * inputs[..., 2]
        R = inputs[..., 3] + params[:, 0] * inputs[..., 2] - params[:, 3] * inputs[..., 3]
        D = inputs[..., 4] + params[:, 1] * inputs[..., 2]
        N = inputs[..., -1]
        return torch.stack((S, E, I, R, D, beta, N), dim = -1)

class SEIRDTest(nn.Module):
    def __init__(self, initN, days):
        super().__init__()
        self.days = days
        self.initN = initN
    def forward(self, inputs, params):
        # inputs: [Sn (0), E (1), St (2), I (3), R (4), D (5), beta, N]
        # params [sigma (0), beta_bar (1), tau (2), rho (3), xi (4), delta (5), gamma (6), mu (7)]
        params = params.clone()  # avoid an in-place edit of the caller's tensor
        params[:, 7] = params[:, 7] / 100  # mu is rescaled by 1/100 before use
        if self.training:
            std = torch.cat([params[:, 1]] * self.days, dim = 1)
            # draw the noise on the same device as the inputs (torch.zeros defaults to CPU)
            rand = torch.normal(torch.zeros(inputs.shape[0], self.days, device = inputs.device), std)
            beta = params[:, 1] + rand * params[:, 0]
        else:
            beta = torch.cat([params[:, 1]] * self.days, dim = 1)
        log_Sn = torch.log(inputs[..., 0] * self.initN / 1000 + 2)
        log_E = torch.log(inputs[..., 1] * self.initN / 1000 + 2)
        tau_s = params[:, 2] * log_Sn / (log_Sn + log_E)
        tau_e = params[:, 2] * log_E / (log_Sn + log_E)
        Sn = inputs[..., 0] - beta * inputs[..., 0] * inputs[..., 1] - tau_s * inputs[..., 0] + params[:, 3] * inputs[..., 2] + params[:, 4] * inputs[..., 4]
        E = inputs[..., 1] +  beta * inputs[..., 0] * inputs[..., 1] - tau_e * inputs[..., 1]
        St = inputs[..., 2] + tau_s * inputs[..., 0] - params[:, 5] * inputs[..., 2] - params[:, 3] * inputs[..., 2]
        I = inputs[..., 3] + tau_e * inputs[..., 1] + params[:, 5] * inputs[..., 2] - params[:, 6] * inputs[..., 3] - params[:, 7] * inputs[..., 3]
        R = inputs[..., 4] + params[:, 6] * inputs[..., 3] - params[:, 4] * inputs[..., 4]
        D = inputs[..., 5] + params[:, 7] * inputs[..., 3]
        N = inputs[..., -1]
        return torch.stack((Sn, E, St, I, R, D, beta, N), dim = -1)

class WeightedMSELoss(nn.Module):
    def __init__(self, weight = None, reduction='mean'):
        super().__init__()
        # weight defaults to None (plain MSE); torch.Tensor(None) would raise, so guard it
        self.weight = torch.Tensor(weight) if weight is not None else None
        self.reduction = reduction
    def forward(self, inputs, targets):
        if self.weight is None:
            return F.mse_loss(inputs, targets, reduction = self.reduction)
        loss = torch.sum(self.weight * ((inputs - targets) ** 2), dim = 1)
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss
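
# Usage sketch for WeightedMSELoss (illustrative; the per-compartment weights and
# batch shapes below are hypothetical, not taken from the training pipeline).
def _example_weighted_mse_loss():
    criterion = WeightedMSELoss(weight = [1.0, 1.0, 5.0, 1.0, 10.0])  # e.g. emphasise I and D errors
    preds, targets = torch.rand(8, 5), torch.rand(8, 5)
    return criterion(preds, targets)  # scalar, since reduction defaults to 'mean'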

class CBAMBlock(nn.Module):
    def __init__(self, channels, ratio, kernel_size = 3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.max_pool = nn.AdaptiveMaxPool1d(1)
        self.fc1 = nn.Conv1d(channels, channels // ratio, kernel_size = 1)
        self.fc2 = nn.Conv1d(channels // ratio, channels, kernel_size = 1)
        self.relu = nn.ReLU(inplace = True)
        self.sigmoid = nn.Sigmoid()
        self.conv = nn.Conv1d(2, 1, kernel_size = kernel_size, padding = kernel_size // 2)
    def forward(self, x):
        out_avg_pool = self.avg_pool(x) 
        out_max_pool = self.max_pool(x)
        out_avg_fc1 = self.fc1(out_avg_pool)
        out_max_fc1 = self.fc1(out_max_pool)
        out_avg_relu = self.relu(out_avg_fc1)
        out_max_relu = self.relu(out_max_fc1)
        out_avg_fc2 = self.fc2(out_avg_relu)
        out_max_fc2 = self.fc2(out_max_relu) # (N, C, 1): per-channel attention logits
        out = out_avg_fc2 + out_max_fc2
        out_sigmoid = self.sigmoid(out)
        out_channel = x * out_sigmoid
        out_avg = torch.mean(out_channel, 1, True)
        out_max, _ = torch.max(out_channel, 1, True)
        out = torch.cat((out_avg, out_max), dim = 1)
        out = self.conv(out)
        out = self.sigmoid(out)
        return out * out_channel + x
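
# Shape sketch for CBAMBlock (illustrative sizes): channel attention followed by
# spatial attention over a 1-D sequence, returned with a residual connection.
def _example_cbam_block():
    block = CBAMBlock(channels = 32, ratio = 8)
    x = torch.rand(4, 32, 10)  # (batch, channels, time steps)
    return block(x)  # same shape as x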
        

class AlphaLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size = 2),
            nn.ReLU(),
            CBAMBlock(32, 8),
            nn.Dropout(0.4),
            nn.Conv1d(32, 32, kernel_size = 2, dilation = 2),
            nn.ReLU(),
            CBAMBlock(32, 8),
            nn.Dropout(0.4),
            nn.Conv1d(32, 64, kernel_size = 2, dilation = 4),
            nn.ReLU(),
            CBAMBlock(64, 8),
            nn.Dropout(0.4),
            nn.Flatten(),
            nn.Linear(192, 64), # 64 channels x 3 remaining steps: the dilated convs above assume a 10-step input window
            nn.ReLU(inplace = True),
            nn.Dropout(0.4),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
    def forward(self, input):
        return self.net(input)
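
# Shape check for AlphaLayer (illustrative): given the Linear(192, 64) head, the
# dilated convolutions above imply a single-channel, 10-step input window.
def _example_alpha_layer():
    net = AlphaLayer()
    x = torch.rand(4, 1, 10)  # (batch, 1 channel, 10 time steps)
    return net(x)  # (4, 1), squashed into (0, 1) by the final Sigmoid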

class SIRD_T(nn.Module):
    def forward(self, inputs, params):
        # inputs: NUK, PU, PK, R, D, V, EC
        # params: tau - 0, zeta - 1, beta - 2, eta1 - 3, mu - 4, gamma_1 - 5, theta - 6, gamma_2 - 7
        xi = 1 / 180
        tau, zeta, beta, eta_1, mu, gamma_1, theta, gamma_2 = params[:, 0], params[:, 1], params[:, 2], params[:, 3], params[:, 4], params[:, 5], params[:, 6], params[:, 7]
        NUK, PU, PK, R, D, V, EC = inputs[..., 0], inputs[..., 1], inputs[..., 2], inputs[..., 3], inputs[..., 4], inputs[..., 5], inputs[..., 6]

        # one step of the discrete map; fresh names so every equation reads the
        # time-t state rather than an already-updated compartment
        _NUK = NUK - tau*NUK - beta*NUK*PU + xi*R
        _PU = PU + beta*NUK*PU + zeta*beta*V*PU - eta_1*PU
        _PK = PK + eta_1*PU - theta*PK - gamma_1*PK
        _EC = EC + theta*PK - gamma_2*EC - mu*EC
        _R = R + gamma_1*PK + gamma_2*EC - xi*R
        _D = D + mu*EC
        _V = V + tau*NUK - zeta*beta*V*PU

        return torch.stack((_NUK, _PU, _PK, _R, _D, _V, _EC), dim = -1)

    @staticmethod
    def apply_constraints(input, param):
        xi = 1/180
        tau, zeta, beta, eta_1, mu, gamma_1, theta, gamma_2 = param[:, 0], param[:, 1], param[:, 2], param[:, 3], param[:, 4], param[:, 5], param[:, 6], param[:, 7]
        NUK, PU, PK, R, D, V, EC = input[..., 0], input[..., 1], input[..., 2], input[..., 3], input[..., 4], input[..., 5], input[..., 6]
        return (torch.relu(xi*R - gamma_1*PK - gamma_2*EC).sum(dim=1).mean()  # R must not shrink: waning outflow <= recovery inflow
                + torch.relu(theta + gamma_1 - 1).mean()  # theta + gamma_1 <= 1
                + torch.relu(mu + gamma_2 - 1).mean()     # mu + gamma_2 <= 1
                + torch.relu(beta - 0.01).mean()          # beta <= 0.01
                + torch.relu(zeta - 0.05).mean())         # zeta <= 0.05
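
# Illustrative rollout of the SIRD_T discrete map. A sketch only: the parameter
# values, initial state and 30-day horizon below are hypothetical placeholders.
def _example_sird_t_rollout():
    model = SIRD_T()
    params = torch.rand(1, 8) * 0.05  # tau, zeta, beta, eta1, mu, gamma_1, theta, gamma_2
    state = torch.tensor([[0.95, 0.01, 0.0, 0.0, 0.0, 0.04, 0.0]])  # NUK, PU, PK, R, D, V, EC (fractions of N)
    states = [state]
    for _ in range(30):  # iterate the one-day update
        state = model(state, params)
        states.append(state)
    traj = torch.stack(states, dim = 1)  # (1, 31, 7) trajectory
    penalty = SIRD_T.apply_constraints(traj, params)  # soft-constraint term for the loss
    return traj, penalty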

class SEIRD_LVEC(nn.Module):
    def forward(self, inputs, params):
        # inputs: NUK, PU, PK, R, D, V, EC, L
        # params: tau - 0, zeta - 1, beta - 2, eta1 - 3, mu - 4, gamma_1 - 5, theta - 6, gamma_2 - 7, rho - 8, delta - 9, rho_hat - 10
        xi = 1 / 180
        tau, zeta, beta, eta_1, mu, gamma_1, theta, gamma_2, rho, delta, rho_hat = params[:, 0], params[:, 1], params[:, 2], params[:, 3], params[:, 4], params[:, 5], params[:, 6], params[:, 7], params[:, 8], params[:, 9], params[:, 10]
        NUK, PU, PK, R, D, V, EC, L = inputs[..., 0], inputs[..., 1], inputs[..., 2], inputs[..., 3], inputs[..., 4], inputs[..., 5], inputs[..., 6], inputs[..., 7]

        _NUK = NUK - tau*NUK - beta*NUK*PU + xi*R - rho*NUK + rho_hat*L
        _PU = PU + beta*NUK*PU + zeta*beta*V*PU - eta_1*PU
        _PK = PK + eta_1*PU - theta*PK - gamma_1*PK + delta*L
        _R = R + gamma_1*PK + gamma_2*EC - xi * R
        _D = D + mu*EC
        _V = V + tau*NUK - zeta*beta*V*PU
        _EC = EC + theta*PK - gamma_2*EC - mu*EC
        _L = L + rho*NUK - (delta + rho_hat)*L
        return torch.stack((_NUK,_PU,_PK,_R,_D,_V,_EC,_L),dim=-1)

    @staticmethod
    def apply_constraints(input, param):
        xi = 1 / 180
        tau, zeta, beta, eta_1, mu, gamma_1, theta, gamma_2, rho, delta, rho_hat = param[:, 0], param[:, 1], param[:, 2], param[:, 3], param[:, 4], param[:, 5], param[:, 6], param[:, 7], param[:, 8], param[:, 9], param[:, 10]
        NUK, PU, PK, R, D, V, EC, L = input[..., 0], input[..., 1], input[..., 2], input[..., 3], input[..., 4], input[..., 5], input[..., 6], input[..., 7]
        return (torch.relu(xi*R - gamma_1*PK - gamma_2*EC).sum(dim=1).mean()  # R must not shrink: waning outflow <= recovery inflow
                + torch.relu(theta + gamma_1 - 1).mean()      # theta + gamma_1 <= 1
                + torch.relu(mu + gamma_2 - 1).mean()         # mu + gamma_2 <= 1
                + torch.relu(beta - 0.001).mean()             # beta <= 0.001
                + torch.relu(zeta - 0.5).mean()               # zeta <= 0.5
                + torch.relu(delta + rho_hat - 1).mean())     # delta + rho_hat <= 1


class SIRD_T_v2(nn.Module):
    """
    2 steps
    """
    def forward(self, inputs, params):
        # inputs: NUK, PU, PK, R, D, V
        # params: tau - 0, zeta - 1, beta - 2, eta1 - 3, mu - 4, gamma - 5, xi - 6
        xi = 1 / 180
        #N, t+1 -> t+11, 6
        NUK, PU, PK, R, D, V = inputs[..., 0], inputs[..., 1], inputs[..., 2], inputs[..., 3], inputs[..., 4], inputs[..., 5]
        NUK1 = NUK - params[:, 0] * NUK - params[:, 2] * NUK * PU + xi * R
        PU1 = PU + params[:, 2] * NUK * PU + params[:, 2] * (params[:, 1]) * V * PU - params[:, 3] * PU
        PK1 = PK + params[:, 3] * PU - (params[:, 4] + params[:, 5]) * PK
        R1 = R + params[:, 5] * PK - xi * R
        D1 = D + params[:, 4] * PK
        V1 = V + params[:, 0] * NUK - params[:, 2] * (params[:, 1]) * V * PU
        
        #N, t+2 -> t+12, 6
        NUK2 = NUK1 - params[:, 0] * NUK1 - params[:, 2] * NUK1 * PU1 + xi * R1
        PU2 = PU1 + params[:, 2] * NUK1 * PU1 + params[:, 2] * (params[:, 1]) * V1 * PU1 - params[:, 3] * PU1
        PK2 = PK1 + params[:, 3] * PU1 - (params[:, 4] + params[:, 5]) * PK1
        R2 = R1 + params[:, 5] * PK1 - xi * R1
        D2 = D1 + params[:, 4] * PK1
        V2 = V1 + params[:, 0] * NUK1 - params[:, 2] * (params[:, 1]) * V1 * PU1

        
        # return the state after the second update step (the *2 variables)
        return torch.stack((NUK2, PU2, PK2, R2, D2, V2), dim = -1)
    @staticmethod
    def apply_constraints(input, param):
        return torch.relu(1 / 180 * input[..., 3] - param[:, 5] * input[..., 2]).mean()

def ODEFunc(params):
    class SEIRDV(nn.Module):
        def forward(self, t, y):
            # inputs: NUK, PU, PK, R, D, V
            # params: tau - 0, zeta - 1, beta - 2, eta1 - 3, mu - 4, gamma - 5, xi - 6
            # params[:, 6] = 1 / 180 + params[:, 6]
            xi = 1 / 180
            NUK, PU, PK, R, D, V = y[..., 0], y[..., 1], y[..., 2], y[..., 3], y[..., 4], y[..., 5]
            # time derivatives, kept under separate names so every equation uses
            # the current state y(t) rather than an already-overwritten value
            dNUK = - params[:, 0] * NUK - params[:, 2] * NUK * PU + xi * R
            dPU = params[:, 2] * NUK * PU + params[:, 2] * params[:, 1] * V * PU - params[:, 3] * PU
            dPK = params[:, 3] * PU - (params[:, 4] + params[:, 5]) * PK
            dR = params[:, 5] * PK - xi * R
            dD = params[:, 4] * PK
            dV = params[:, 0] * NUK - params[:, 2] * params[:, 1] * V * PU
            return torch.stack((dNUK, dPU, dPK, dR, dD, dV), dim = -1)
    
    return SEIRDV()
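
# Sketch of driving ODEFunc with a continuous-time solver. This assumes the
# optional `torchdiffeq` package; the parameter values and initial state are
# hypothetical placeholders, not taken from the original experiments.
def _example_odefunc_integration():
    from torchdiffeq import odeint  # pip install torchdiffeq
    params = torch.rand(1, 6) * 0.1  # tau, zeta, beta, eta1, mu, gamma
    y0 = torch.tensor([[0.95, 0.01, 0.0, 0.0, 0.0, 0.04]])  # NUK, PU, PK, R, D, V
    t = torch.linspace(0, 30, 31)  # daily grid over 30 days
    return odeint(ODEFunc(params), y0, t)  # (31, 1, 6) trajectory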


class CNNBlock(nn.Module):
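    # Grouped-convolution (ResNeXt-style) 1-D block: 1x1 conv -> grouped kxk conv
    # -> 1x1 conv, with an optional residual branch and optional DAIN normalisation.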
    def __init__(self, in_channels, out_channels, cardinality, kernel_size = 3, dilation = 1, padding = 1, residual = False, norm = False):
        super().__init__()
        if residual:
            if norm:
                self.res_block = nn.Sequential(
                    nn.Conv1d(in_channels, out_channels, kernel_size = kernel_size, padding = padding, dilation = dilation),
                    # nn.LayerNorm(out_channels),
                    DAIN_Layer(out_channels),
                    nn.ReLU(inplace = True)
                )
            else:
                self.res_block = nn.Sequential(
                    nn.Conv1d(in_channels, out_channels, kernel_size = kernel_size, padding = padding, dilation = dilation),
                    nn.ReLU(inplace = True)
                )
        if norm:
            self.block1 = nn.Sequential(
                nn.Conv1d(in_channels, cardinality * 4, kernel_size = 1),
                DAIN_Layer(cardinality * 4),
                nn.ReLU(inplace = True),
                nn.Conv1d(cardinality * 4, cardinality * 4, kernel_size = kernel_size, padding = padding, dilation = dilation, groups = cardinality),
                DAIN_Layer(cardinality * 4),
                nn.ReLU(inplace = True),
                nn.Conv1d(cardinality * 4, out_channels, kernel_size = 1),
                DAIN_Layer(out_channels),
                nn.ReLU(inplace = True)
            )
        else:
            self.block1 = nn.Sequential(
                nn.Conv1d(in_channels, cardinality * 4, kernel_size = 1),
                nn.ReLU(inplace = True),
                nn.Conv1d(cardinality * 4, cardinality * 4, kernel_size = kernel_size, padding = padding, dilation = dilation, groups = cardinality),
                nn.ReLU(inplace = True),
                nn.Conv1d(cardinality * 4, out_channels, kernel_size = 1),
                nn.ReLU(inplace = True)
            )
        self.residual = residual
    def forward(self, input):
        output = self.block1(input)
        if self.residual:
            output = output + self.res_block(input)
        return output

class AttnLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, time_steps):
        super().__init__()
        self.fc = nn.Linear(2 * hidden_dim + time_steps, 1)
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers = 1, batch_first = True)
        self.softmax = nn.Softmax(dim = 1)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.time_steps = time_steps
    def forward(self, input):
        device = input.device
        h_t, c_t = Variable(torch.zeros(1, input.size(0), self.hidden_dim, device = device)), Variable(torch.zeros(1, input.size(0), self.hidden_dim, device = device))
        output = Variable(torch.zeros(input.size(0), input.size(1), self.hidden_dim, device = device))

        for i in range(self.time_steps):
            x = torch.cat((h_t.repeat(self.input_dim, 1, 1).permute(1, 0, 2), # 1, batch_size, hidden_dim -> batch_size, input_dim, hidden_dim
                           c_t.repeat(self.input_dim, 1, 1).permute(1, 0, 2),
                           input.permute(0, 2, 1)), dim = 2) # input -> batch_size, input_dim, time_steps
            e_t = self.fc(x.view(-1, 2 * self.hidden_dim + self.time_steps)) # batch_size, input_dim
            a_t = self.softmax(e_t.view(-1, self.input_dim))
            w_input = torch.mul(a_t, input[:, i]) # batch_size, input_dim x batch_size, input_dim -> batch_size, input_dim
            _, (h_t, c_t) = self.lstm(w_input.unsqueeze(1), (h_t, c_t))
            output[:, i] = h_t.squeeze(0)
        return output, (h_t, c_t)
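
# Shape sketch for AttnLSTM (an input-attention LSTM encoder). The sizes are
# hypothetical; the module expects input of shape (batch, time_steps, input_dim).
def _example_attn_lstm():
    encoder = AttnLSTM(input_dim = 7, hidden_dim = 32, time_steps = 14)
    x = torch.rand(4, 14, 7)  # 4 series, 14 days, 7 driving features
    output, (h_n, c_n) = encoder(x)  # output: (4, 14, 32); h_n, c_n: (1, 4, 32)
    return output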


class DAIN_Layer(nn.Module):
    def __init__(self, input_dim=144, mode='adaptive_avg', mean_lr=0.00001, gate_lr=0.001, scale_lr=0.00001):
        super(DAIN_Layer, self).__init__()

        self.mode = mode
        self.mean_lr = mean_lr
        self.gate_lr = gate_lr
        self.scale_lr = scale_lr
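        # Note: the *_lr values are not used inside the layer itself; they appear
        # intended for building per-sublayer optimiser parameter groups, as in the
        # reference DAIN implementation.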

        # Parameters for adaptive average
        self.mean_layer = nn.Linear(input_dim, input_dim, bias=False)
        self.mean_layer.weight.data = torch.FloatTensor(data=np.eye(input_dim, input_dim))

        # Parameters for adaptive std
        self.scaling_layer = nn.Linear(input_dim, input_dim, bias=False)
        self.scaling_layer.weight.data = torch.FloatTensor(data=np.eye(input_dim, input_dim))

        # Parameters for gating (suppressing uninformative features)
        self.gating_layer = nn.Linear(input_dim, input_dim)

        self.eps = 1e-8

    def forward(self, x):
        # Expecting  (n_samples, dim,  n_feature_vectors)

        # Nothing to normalize
        if self.mode is None:
            pass

        # Do simple average normalization
        elif self.mode == 'avg':
            avg = torch.mean(x, 2)
            avg = avg.view(avg.size(0), avg.size(1), 1)
            x = x - avg

        # Perform only the first step (adaptive averaging)
        elif self.mode == 'adaptive_avg':
            avg = torch.mean(x, 2)
            adaptive_avg = self.mean_layer(avg)
            adaptive_avg = adaptive_avg.view(adaptive_avg.size(0), adaptive_avg.size(1), 1)
            x = x - adaptive_avg

        # Perform the first + second step (adaptive averaging + adaptive scaling )
        elif self.mode == 'adaptive_scale':

            # Step 1:
            avg = torch.mean(x, 2)
            adaptive_avg = self.mean_layer(avg)
            adaptive_avg = adaptive_avg.view(adaptive_avg.size(0), adaptive_avg.size(1), 1)
            x = x - adaptive_avg

            # Step 2:
            std = torch.mean(x ** 2, 2)
            std = torch.sqrt(std + self.eps)
            adaptive_std = self.scaling_layer(std)
            adaptive_std[adaptive_std <= self.eps] = 1

            adaptive_std = adaptive_std.view(adaptive_std.size(0), adaptive_std.size(1), 1)
            x = x / (adaptive_std)

        elif self.mode == 'full':

            # Step 1:
            avg = torch.mean(x, 2)
            adaptive_avg = self.mean_layer(avg)
            adaptive_avg = adaptive_avg.view(adaptive_avg.size(0), adaptive_avg.size(1), 1)
            x = x - adaptive_avg

            # Step 2:
            std = torch.mean(x ** 2, 2)
            std = torch.sqrt(std + self.eps)
            adaptive_std = self.scaling_layer(std)
            adaptive_std[adaptive_std <= self.eps] = 1

            adaptive_std = adaptive_std.view(adaptive_std.size(0), adaptive_std.size(1), 1)
            x = x / adaptive_std

            # Step 3: 
            avg = torch.mean(x, 2)
            gate = torch.sigmoid(self.gating_layer(avg))
            gate = gate.view(gate.size(0), gate.size(1), 1)
            x = x * gate

        else:
            raise ValueError(f'Unknown DAIN mode: {self.mode}')
        return x
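
# Minimal shape check for DAIN_Layer (illustrative sizes). The layer expects
# (n_samples, input_dim, n_feature_vectors) and preserves that shape.
def _example_dain_layer():
    layer = DAIN_Layer(input_dim = 7, mode = 'full')
    x = torch.rand(4, 7, 14)  # 4 samples, 7 features, 14 time steps
    return layer(x)  # (4, 7, 14): mean-shifted, scaled and gated per feature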