Untitled
unknown
python
a year ago
4.3 kB
6
Indexable
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as tg
import torch_geometric.nn as geom_nn
from torch_geometric.nn.conv import TransformerConv, GCNConv, RGCNConv

# from utils.tools import catch_lone_sender, fully_connected_edge_index
from ..layers.layers import EGNNLayer


class EGNN(nn.Module):
    """E(n)-equivariant GNN predicting an energy and forces (-dE/dpos) from periodic graphs."""

    def __init__(
        self,
        depth,
        hidden_features,
        node_features,
        out_features,
        norm,
        activation="swish",
        aggr="sum",
        pool="add",
        residual=True,
        RFF_dim=None,
        RFF_sigma=None,
        return_pos=False,
        **kwargs,
    ):
        r"""E(n) Equivariant GNN model.

        Args:
            depth: (int) - number of message passing layers (``EGNNLayer`` instances)
            hidden_features: (int) - hidden dimension
            node_features: (int) - input node feature dimension (last dim of ``batch.x``)
            out_features: (int) - output dimension of the final predictor MLP
            norm: (str) - normalisation layer (layer/batch), forwarded to ``EGNNLayer``
            activation: (str) - non-linearity within MLPs (swish/relu), forwarded to ``EGNNLayer``
            aggr: (str) - aggregation function `\oplus` (sum/mean/max), forwarded to ``EGNNLayer``
            pool: (str) - global pooling: "add" (alias "sum"), "mean", or "none" to keep
                per-node outputs
            residual: (bool) - whether to use residual connections between message passing layers
            RFF_dim, RFF_sigma: forwarded unchanged to every ``EGNNLayer``
            return_pos: (bool) - accepted for interface compatibility; currently unused
            **kwargs: accepted and ignored

        Raises:
            KeyError: if ``pool`` is not one of the supported names.
        """
        super().__init__()
        # Name of the network
        self.name = "EGNN"

        # Linear embedding of the initial node features
        self.emb_in = nn.Linear(node_features, hidden_features)
        # Parameter-free helper computing periodic edge distances
        self.make_dist = PBCConvLayer()

        # Stack of message passing layers (registered as convs.0, convs.1, ...)
        self.convs = nn.ModuleList(
            EGNNLayer(hidden_features, activation, norm, aggr, RFF_dim, RFF_sigma)
            for _ in range(depth)
        )

        # Global pooling/readout function; "sum" is accepted as the docstring promised
        pools = {
            "mean": tg.nn.global_mean_pool,
            "add": tg.nn.global_add_pool,
            "sum": tg.nn.global_add_pool,
            "none": None,
        }
        if pool not in pools:
            raise KeyError(f"Unknown pool {pool!r}; expected one of {sorted(pools)}")
        self.pool = pools[pool]

        # Predictor MLP
        self.pred = nn.Sequential(
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        )
        self.residual = residual

    def forward(self, batch):
        """Predict energy and forces for a batch of periodic graphs.

        Args:
            batch: graph batch exposing ``x``, ``pos``, ``edge_index``, ``cell_offset``,
                ``unit_cell`` and, when pooling is enabled, ``batch``.

        Returns:
            ``(energy, force)``. ``energy`` is the predictor output: one row per graph when
            pooling, one row per node when ``pool="none"``. ``force`` has the shape of
            ``batch.pos``. It is minus the gradient of ``energy`` (summed over all entries)
            with respect to ``batch.pos``, built with ``create_graph=True`` so it can be used
            in a training loss.

        Side effects:
            ``batch.pos`` is replaced by a detached leaf tensor that requires grad.
        """
        h = self.emb_in(batch.x)  # (n, node_features) -> (n, hidden_features)

        # Make positions a fresh autograd leaf so forces can be taken as -dE/dpos.
        # Equivalent to the deprecated torch.autograd.Variable(pos, requires_grad=True).
        batch.pos = batch.pos.detach().requires_grad_(True)

        # NOTE(review): only the first three rows of ``unit_cell`` are used, i.e. a single
        # (3, 3) cell for every edge in the batch. If ``unit_cell`` stacks one cell per
        # graph, this is only correct for batch size 1 -- confirm against the data pipeline.
        distances = self.make_dist(
            batch.pos, batch.edge_index, batch.cell_offset, batch.unit_cell[:3]
        )

        for conv in self.convs:
            # Message passing layer; node features (n, d) -> (n, d)
            h_update = conv(h, batch.edge_index, distances)
            h = h + h_update if self.residual else h_update

        out = h if self.pool is None else self.pool(h, batch.batch)
        energy = self.pred(out)

        # TODO: this will need reworking once several position sets (pos1, pos2, ...)
        # are supported.
        (grad_pos,) = torch.autograd.grad(
            energy,
            batch.pos,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
            retain_graph=True,
            allow_unused=True,
        )
        if grad_pos is None:
            # Energy does not depend on positions (e.g. depth=0): forces are zero
            # instead of autograd raising "not used in the graph".
            grad_pos = torch.zeros_like(batch.pos)
        force = -grad_pos
        return energy, force


class PBCConvLayer(nn.Module):
    """Compute per-edge distances under periodic boundary conditions.

    The module has no parameters. For edge ``e`` with ``i = edge_index[0, e]`` and
    ``j = edge_index[1, e]``, the distance is
    ``|| pos[j] - offsets[e] @ cell_vectors - pos[i] ||``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pos, edge_index, offsets, cell_vectors):
        """Return the periodic distance of every edge.

        Args:
            pos: (N, 3) node positions.
            edge_index: (2, E) edge indices; the image shift is applied to the
                row-1 (``edge_index[1]``) endpoint.
            offsets: (E, 3) periodic image offsets per edge (values like -1, 0, 1),
                in units of the cell vectors; integer or floating dtype.
            cell_vectors: (3, 3) cell vectors, one per row.

        Returns:
            (E,) tensor of edge distances.
        """
        # Integer offsets cannot be matmul'd with a floating cell; cast to the cell dtype.
        shift = torch.matmul(offsets.to(cell_vectors.dtype), cell_vectors)  # (E, 3)
        # NOTE(review): the shift is subtracted from the row-1 endpoint; confirm this
        # sign matches how ``cell_offset`` is generated upstream.
        shifted = pos[edge_index[1]] - shift
        return torch.linalg.vector_norm(shifted - pos[edge_index[0]], dim=-1)
Editor is loading...
Leave a Comment