import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class SGConv(nn.Module):
    def __init__(self, kdim, sqlen, inchan, chan, outchan, device='cpu'):
        super().__init__()
        self.device = device

        # number of sub-kernels needed so the concatenated kernel can cover sqlen
        num_subkernels = int(math.ceil(math.log2(sqlen / kdim))) + 1

        #How much to stretch and how much to decay
        self.interp_factors = [2**max(0, i-1) for i in range(num_subkernels)]
        self.decay_coefs = [.8**i for i in range(num_subkernels)]
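        # e.g. kdim=16, sqlen=1024 gives num_subkernels = ceil(log2(64)) + 1 = 7 and
        # interp_factors = [1, 1, 2, 4, 8, 16, 32], so the concatenated kernel spans
        # 16 * (1 + 1 + 2 + 4 + 8 + 16 + 32) = 1024 positions, i.e. the full sequence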

        # Construct subkernels
        self.subkernels = nn.ParameterList([nn.Parameter(torch.randn(chan, inchan, kdim)) for _ in range(num_subkernels)])
        
        # per-channel decay rates, drawn uniformly from [0, 1/1.8)
        decay_rates = torch.rand(inchan) / 1.8
        # decay envelope exp(-rate * log(t + 1)) = (t + 1)^(-rate) over positions t
        positions = torch.arange(sqlen, dtype=torch.float32)
        self.decay = torch.exp(-decay_rates.view(1, -1, 1) * torch.log(positions + 1).view(1, 1, -1)).to(device)


        self.kernel_norm = None

        #skip connection
        self.D = nn.Parameter(torch.randn(chan, inchan))

        self.norm = nn.LayerNorm(chan * inchan, device=device)
        self.act = nn.GELU()
        self.output_linear = nn.Linear(chan * inchan, outchan)

        #different strategy for channel mixing
        #self.output_linear = nn.Conv1d(chan*inchan, outchan, kernel_size=1,stride=1)

    def forward(self, x, return_kernel=False):
        x = x.to(self.device)
        L = x.shape[-1]

        # stretch each sub-kernel by its interpolation factor and weight it by its decay coefficient
        scaled_subkernels = [
            F.interpolate(subkernel, scale_factor=interp_factor, mode='linear') * decay_coef
            for decay_coef, interp_factor, subkernel in zip(self.decay_coefs, self.interp_factors, self.subkernels)
        ]
        # concatenate along the length dimension into the full kernel: (chan, inchan, total_len)
        k = torch.cat(scaled_subkernels, dim=-1)

        # on the first forward pass, record the kernel norm and keep it fixed thereafter
        if self.kernel_norm is None:
            self.kernel_norm = k.norm(dim=-1, keepdim=True).detach()

        #cut or pad to sequence length
        if k.shape[-1] > L:
            k = k[..., :L]
        elif k.shape[-1] < L:
            k = F.pad(k, (0, L - k.size(-1)))
        
        # apply the decay envelope and normalize by the stored kernel norm
        # (assumes L <= sqlen so the precomputed decay covers the sequence)
        k = k * self.decay[..., :L]
        k = k / self.kernel_norm

        #convolution calculation using FFT
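        # zero-padding both FFTs to length 2*L makes the circular convolution
        # equivalent to a linear convolution; only the first L outputs are kept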
        k_f = torch.fft.rfft(k, n=2*L)
        x_f = torch.fft.rfft(x, n=2*L)
        y_f = torch.einsum('bhl,chl->bchl', x_f, k_f)
        y = torch.fft.irfft(y_f, n=2*L)[..., :L]

        #skip connection
        y = (y + torch.einsum('bhl,ch->bchl', x, self.D))

        # flatten chan and inchan into a single channel dimension
        y = y.flatten(1, 2).to(self.device)

        #normalization
        y = self.norm(y.permute(0, 2, 1)).permute(0, 2, 1)

        #activation
        y = self.act(y)

        #linear output
        y = self.output_linear(y.permute(0,2,1)).permute(0,2,1)
    
        if return_kernel:
            return y, k

        return y
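
# Minimal usage sketch (not part of the original snippet): the shapes and
# hyperparameters below are illustrative assumptions, chosen so that the
# concatenated kernel length matches sqlen and everything stays on the CPU.
if __name__ == '__main__':
    batch, inchan, chan, outchan, sqlen = 4, 8, 2, 32, 1024
    layer = SGConv(kdim=16, sqlen=sqlen, inchan=inchan, chan=chan, outchan=outchan)
    x = torch.randn(batch, inchan, sqlen)   # (batch, input channels, sequence length)
    y, k = layer(x, return_kernel=True)
    print(y.shape)                           # torch.Size([4, 32, 1024])
    print(k.shape)                           # torch.Size([2, 8, 1024])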