# sgconv
# Paste metadata (from the sharing page): unknown | python | 2 years ago | 2.6 kB | 5 | Indexable
import torch
import torch.nn as nn
import torch.nn.functional as F


class SGConv(nn.Module):
    """SGConv-style long convolution layer.

    The long convolution kernel is assembled from ``num_subkernels`` small
    learnable sub-kernels of length ``kdim``. Sub-kernel ``i`` is linearly
    upsampled by ``2 ** max(0, i - 1)`` and scaled by ``0.8 ** i``, and the
    results are concatenated along the last axis. The concatenated kernel is
    therefore ``kdim * 2 ** (num_subkernels - 1)`` samples long, which is at
    least ``sqlen``.

    Before use, the kernel is:

    * cropped or zero-padded to the input length,
    * multiplied by a fixed per-input-channel decay ``(t + 1) ** -r``,
      where ``r`` is drawn once from ``U[0, 1) / 1.8``,
    * divided by its per-(chan, inchan) L2 norm, captured on the first
      forward call.

    The kernel is then applied with an FFT convolution. A learned skip term
    ``D`` is added. The ``chan * inchan`` feature maps are layer-normalised,
    passed through a GELU, and projected to ``outchan`` channels by a 1x1
    convolution.

    Args:
        kdim: length of each learnable sub-kernel (positive int).
        sqlen: sequence length the kernel is built to cover (positive int).
        inchan: number of input channels (``x.shape[1]``).
        chan: number of kernels applied to each input channel.
        outchan: number of output channels.
        device: device the parameters and buffers are placed on.
    """

    def __init__(self, kdim, sqlen, inchan, chan, outchan, device='cpu'):
        super().__init__()
        self.device = device
        kdim, sqlen = int(kdim), int(sqlen)
        num_subkernels = self._num_subkernels(kdim, sqlen)

        self.interp_factors = [2 ** max(0, i - 1) for i in range(num_subkernels)]
        self.decay_coefs = [0.8 ** i for i in range(num_subkernels)]

        # Everything is created on the CPU, in the original RNG order, and
        # moved to ``device`` once at the end. A given CPU seed therefore
        # yields the same initial values whatever ``device`` is.
        self.subkernels = nn.ParameterList(
            [nn.Parameter(torch.randn(chan, inchan, kdim)) for _ in range(num_subkernels)]
        )
        decay_rates = torch.rand(inchan) / 1.8
        # Non-persistent buffers follow ``.to()`` / ``.cuda()`` but are not
        # written to the state dict, so checkpoint keys are unchanged.
        self.register_buffer("decay_rates", decay_rates, persistent=False)
        self.register_buffer("decay", self._decay_profile(decay_rates, sqlen), persistent=False)
        # Filled in lazily on the first forward call.
        self.register_buffer("kernel_norm", None, persistent=False)

        self.D = nn.Parameter(torch.randn(chan, inchan))
        self.norm = nn.LayerNorm(chan * inchan)
        self.act = nn.GELU()
        self.output_linear = nn.Conv1d(chan * inchan, outchan, kernel_size=1, stride=1)
        self.to(device)

    @staticmethod
    def _num_subkernels(kdim, sqlen):
        """Return the smallest ``n >= 1`` with ``kdim * 2 ** (n - 1) >= sqlen``.

        Raises:
            ValueError: if ``kdim`` or ``sqlen`` is not positive.
        """
        if kdim < 1 or sqlen < 1:
            raise ValueError(f"kdim and sqlen must be positive, got kdim={kdim}, sqlen={sqlen}")
        # ceil(sqlen / kdim) in exact integer arithmetic.
        ratio = -(-sqlen // kdim)
        # The smallest m with 2 ** m >= ratio is (ratio - 1).bit_length().
        return (ratio - 1).bit_length() + 1

    @staticmethod
    def _decay_profile(rates, length):
        """Return ``(t + 1) ** -rates`` for ``t`` in ``[0, length)``.

        The result has shape ``(1, len(rates), length)``.
        """
        steps = torch.arange(length, device=rates.device) + 1
        return torch.exp(-rates.view(1, -1, 1) * torch.log(steps).view(1, 1, -1))

    def forward(self, x, return_kernel=False):
        """Apply the layer to ``x`` of shape ``(batch, inchan, L)``.

        ``x`` is moved to the device of the layer's parameters. Any ``L`` is
        accepted: the kernel is cropped or zero-padded to ``L``, and the decay
        is sliced or extended to match.

        Returns:
            ``y`` of shape ``(batch, outchan, L)``. If ``return_kernel`` is
            true, returns ``(y, k)``, where ``k`` is the ``(chan, inchan, L)``
            kernel that was applied.
        """
        x = x.to(self.D.device)
        L = x.shape[-1]

        # Upsample and weight each sub-kernel, then join them into one long
        # kernel of shape (chan, inchan, total_len).
        k = torch.cat(
            [
                F.interpolate(subkernel, scale_factor=factor, mode='linear') * coef
                for coef, factor, subkernel in zip(
                    self.decay_coefs, self.interp_factors, self.subkernels
                )
            ],
            dim=-1,
        )
        if self.kernel_norm is None:
            # The normalisation constant is frozen at the first call.
            self.kernel_norm = k.norm(dim=-1, keepdim=True).detach()

        # Match the kernel length to the input length.
        if k.shape[-1] > L:
            k = k[..., :L]
        elif k.shape[-1] < L:
            k = F.pad(k, (0, L - k.shape[-1]))

        if L <= self.decay.shape[-1]:
            decay = self.decay[..., :L]
        else:
            decay = self._decay_profile(self.decay_rates, L)
        k = k * decay / self.kernel_norm

        # Zero-padding both signals to 2L means the product of their FFTs,
        # truncated to the first L samples, is a linear (non-wrapping)
        # causal convolution.
        n_fft = 2 * L
        k_f = torch.fft.rfft(k, n=n_fft)
        x_f = torch.fft.rfft(x, n=n_fft)
        y = torch.fft.irfft(torch.einsum('bhl,chl->bchl', x_f, k_f), n=n_fft)[..., :L]
        y = y + torch.einsum('bhl,ch->bchl', x, self.D)

        # (batch, chan, inchan, L) -> (batch, chan * inchan, L). LayerNorm
        # normalises over the channel axis, so it is moved last and back.
        y = y.flatten(1, 2)
        y = self.norm(y.transpose(1, 2)).transpose(1, 2)
        y = self.act(y)
        y = self.output_linear(y)
        if return_kernel:
            return y, k
        return y
# (Web-editor residue from the sharing page: "Editor is loading..." / "Leave a Comment")