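# sgconv
#
# Pasted from pastecode.io. Original page metadata: author "mail@pastecode.io"
# (avatar), title "unknown", language "python", posted "7 months ago",
# size "2.6 kB", views "2", visibility "Indexable", expiry "Never".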
import torch
import torch.nn as nn
import torch.nn.functional as F

class SGConv(nn.Module):
    """Structured global convolution (SGConv) layer.

    A long convolution kernel is assembled from ``num_subkernels`` learnable
    sub-kernels of length ``kdim``. Sub-kernel ``i`` is linearly upsampled by
    ``2 ** max(0, i - 1)`` and scaled by ``0.8 ** i``. All of them are then
    concatenated, so the total kernel length is
    ``kdim * 2 ** (num_subkernels - 1)``, which is at least ``sqlen``.

    The kernel is fitted (truncated or zero-padded) to the input length. It is
    then multiplied by a fixed per-input-channel decay
    ``(position + 1) ** -rate``, where the rates are drawn uniformly from
    ``[0, 1/1.8)``. Finally it is divided by a normaliser captured on the first
    forward pass and applied with a zero-padded FFT convolution. A learnable
    skip term ``D``, LayerNorm over the ``chan * inchan`` channels, GELU and a
    1x1 convolution follow.

    Args:
        kdim: Length of each sub-kernel before upsampling (must be > 0).
        sqlen: Sequence length the kernel is sized for (must be > 0).
        inchan: Number of input channels (``x.shape[1]``).
        chan: Number of kernels per input channel.
        outchan: Number of output channels.
        device: Device on which parameters and buffers are created.

    Raises:
        ValueError: If ``kdim`` or ``sqlen`` is not positive.
    """

    def __init__(self, kdim, sqlen, inchan, chan, outchan, device='cpu'):
        super().__init__()
        self.device = device

        kdim = int(kdim)
        sqlen = int(sqlen)
        if kdim <= 0 or sqlen <= 0:
            raise ValueError(
                f"kdim and sqlen must be positive, got kdim={kdim}, sqlen={sqlen}"
            )

        # num_subkernels = ceil(log2(sqlen / kdim)) + 1, computed exactly in
        # integers: the smallest n >= 0 with kdim * 2**n >= sqlen, plus one
        # base-resolution sub-kernel. At least one sub-kernel is always kept,
        # even when sqlen < kdim.
        blocks = max(1, -(-sqlen // kdim))  # ceil(sqlen / kdim)
        num_subkernels = (blocks - 1).bit_length() + 1

        self.interp_factors = [2 ** max(0, i - 1) for i in range(num_subkernels)]
        self.decay_coefs = [0.8 ** i for i in range(num_subkernels)]

        self.subkernels = nn.ParameterList(
            [
                nn.Parameter(torch.randn(chan, inchan, kdim).to(device))
                for _ in range(num_subkernels)
            ]
        )

        decay_rates = torch.rand(inchan).to(device) / 1.8
        # Registered as non-persistent buffers so module.to(...) moves them,
        # while state_dict keeps exactly the same keys as before.
        # NOTE(review): the random decay rates are therefore not checkpointed.
        self.register_buffer('decay_rates', decay_rates, persistent=False)
        self.register_buffer(
            'decay', self._decay_curve(decay_rates, sqlen), persistent=False
        )

        # Filled lazily on the first forward pass (see forward).
        self.register_buffer('kernel_norm', None, persistent=False)

        self.D = nn.Parameter(torch.randn(chan, inchan).to(device))

        self.norm = nn.LayerNorm(chan * inchan, device=device)
        self.act = nn.GELU()
        self.output_linear = nn.Conv1d(
            chan * inchan, outchan, kernel_size=1, stride=1
        ).to(device)

    @staticmethod
    def _decay_curve(rates, length):
        """Return ``(position + 1) ** -rate`` with shape ``(1, len(rates), length)``."""
        positions = torch.arange(length, device=rates.device)
        return torch.exp(
            -rates.view(1, -1, 1) * torch.log(positions + 1).view(1, 1, -1)
        )

    def _decay_for_length(self, length):
        """Return the decay curve fitted to ``length`` positions."""
        if length <= self.decay.shape[-1]:
            return self.decay[..., :length]
        # Longer than sqlen: extend the same curve instead of failing to broadcast.
        return self._decay_curve(self.decay_rates, length)

    def forward(self, x, return_kernel=False):
        """Apply the layer.

        Args:
            x: Tensor of shape ``(batch, inchan, L)``; any ``L >= 1``.
            return_kernel: If true, also return the effective kernel.

        Returns:
            ``y`` of shape ``(batch, outchan, L)``. If ``return_kernel`` is
            true, returns ``(y, k)`` where ``k`` has shape ``(chan, inchan, L)``.
        """
        # Use the device the parameters actually live on, so the module also
        # works after being moved with .to(...).
        x = x.to(self.D.device)
        L = x.shape[-1]

        # Upsample and weight each sub-kernel, then concatenate along length.
        k = torch.cat(
            [
                F.interpolate(subkernel, scale_factor=factor, mode='linear') * coef
                for coef, factor, subkernel in zip(
                    self.decay_coefs, self.interp_factors, self.subkernels
                )
            ],
            dim=-1,
        )

        # Fixed normaliser taken from the full kernel on the first call. It is
        # detached, so it is not trained. NOTE(review): it is not saved in
        # state_dict, so it is recomputed after reloading a checkpoint.
        if self.kernel_norm is None:
            self.kernel_norm = k.norm(dim=-1, keepdim=True).detach()

        # Fit the kernel to the input length.
        if k.shape[-1] > L:
            k = k[..., :L]
        elif k.shape[-1] < L:
            k = F.pad(k, (0, L - k.shape[-1]))

        k = k * self._decay_for_length(L)
        k = k / self.kernel_norm

        # Zero-padding both signals to 2L makes the FFT product a linear
        # (non-circular) convolution; keep the first L outputs.
        k_f = torch.fft.rfft(k, n=2 * L)
        x_f = torch.fft.rfft(x, n=2 * L)
        y_f = torch.einsum('bhl,chl->bchl', x_f, k_f)
        y = torch.fft.irfft(y_f, n=2 * L)[..., :L]

        y = y + torch.einsum('bhl,ch->bchl', x, self.D)
        y = y.flatten(1, 2)
        # LayerNorm normalises the last axis, so move channels there and back.
        y = self.norm(y.permute(0, 2, 1)).permute(0, 2, 1)
        y = self.act(y)
        y = self.output_linear(y)

        if return_kernel:
            return y, k
        return y
# (pastecode.io page footer: "Leave a Comment")