import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as tg
class RFF(nn.Module):
    """Random Fourier features: projects inputs through a fixed random
    Gaussian matrix and applies sin/cos, approximating an RBF kernel."""
def __init__(self, in_features, out_features, sigma=1.0):
super().__init__()
self.sigma = sigma
self.in_features = in_features
self.out_features = out_features
        # If out_features is odd, draw one extra frequency and trim the
        # final output back down to out_features in forward().
        self.compensation = out_features % 2
        B = torch.randn(out_features // 2 + self.compensation, in_features) * sigma
        B /= math.sqrt(2)
self.register_buffer("B", B)
def forward(self, x):
x = F.linear(x, self.B)
x = torch.cat((x.sin(), x.cos()), dim=-1)
if self.compensation:
x = x[..., :-1]
return x
def extra_repr(self) -> str:
return "in_features={}, out_features={}, sigma={}".format(
self.in_features, self.out_features, self.sigma
)
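
# --- Usage sketch (not in the original file; values are illustrative) ---
# Lifts scalar edge distances into a 16-dim random Fourier feature vector;
# all shapes here are made up for demonstration:
#   rff = RFF(in_features=1, out_features=16, sigma=1.0)
#   dists = torch.rand(32, 1)   # 32 edge distances
#   feats = rff(dists)          # -> [32, 16]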
class EGNNLayer(tg.nn.MessagePassing):
    """E(n)-invariant message-passing layer: messages are conditioned on
    precomputed pairwise distances (optionally lifted with random Fourier
    features) rather than on raw coordinates."""

    def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add",
                 RFF_dim=None, RFF_sigma=None, mask=None):
super().__init__(aggr=aggr)
self.emb_dim = emb_dim
self.activation = {"swish": nn.SiLU(), "relu": nn.ReLU()}[activation]
        self.norm = {"layer": nn.LayerNorm, "batch": nn.BatchNorm1d, "none": nn.Identity}[norm]
self.RFF_dim = RFF_dim
self.RFF_sigma = RFF_sigma
self.mask = mask
        # Message MLP input: [h_i || h_j || distance features].
        self.mlp_msg = nn.Sequential(
            nn.Linear(2 * emb_dim + (1 if RFF_dim is None else RFF_dim), emb_dim),
self.norm(emb_dim),
self.activation,
nn.Linear(emb_dim, emb_dim),
self.norm(emb_dim),
self.activation,
)
self.mlp_upd = nn.Sequential(
nn.Linear(2 * emb_dim, emb_dim),
self.norm(emb_dim),
self.activation,
nn.Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
self.activation,
)
if self.RFF_dim is not None:
self.RFF = RFF(1, RFF_dim, RFF_sigma)
    def forward(self, h, edge_index, distances, mask=None):
        # Store the node mask so update() can leave masked-out nodes unchanged.
        self.mask = mask
        # Pass the precomputed edge distances through to message().
        return self.propagate(edge_index, h=h, distances=distances)
    def message(self, h_i, h_j, distances):
        # distances holds precomputed per-edge lengths: [E] -> [E, 1].
        dists = distances.unsqueeze(-1)
        if self.RFF_dim is not None:
            dists = self.RFF(dists)
        msg = torch.cat([h_i, h_j, dists], dim=-1)
        return self.mlp_msg(msg)
    def update(self, aggr_out, h):
        upd_out = self.mlp_upd(torch.cat([h, aggr_out], dim=-1))
        if self.mask is not None:
            # Keep the original embedding for masked-out (False) nodes.
            upd_out = torch.where(self.mask.unsqueeze(-1), upd_out, h)
        return upd_out
def __repr__(self) -> str:
return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"