Untitled

 avatar
unknown
python
a year ago
4.3 kB
6
Indexable
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as tg
import torch_geometric.nn as geom_nn
from torch_geometric.nn.conv import TransformerConv, GCNConv, RGCNConv
# from utils.tools import catch_lone_sender, fully_connected_edge_index
from ..layers.layers import EGNNLayer


class EGNN(nn.Module):
    """Graph network that predicts an energy per graph and forces per node.

    Node features are embedded, refined by a stack of ``EGNNLayer`` message
    passing layers driven by periodic-boundary edge distances, optionally
    pooled per graph, and mapped to an energy by an MLP. Forces are obtained
    as the negative gradient of the energy with respect to node positions.
    """

    def __init__(
            self,
            depth,
            hidden_features,
            node_features,
            out_features,
            norm,
            activation="swish",
            aggr="sum",
            pool="add",
            residual=True,
            RFF_dim=None,
            RFF_sigma=None,
            return_pos=False,
            **kwargs
    ):
        """E(n) Equivariant GNN model

        Args:
            depth: (int) - number of message passing layers
            hidden_features: (int) - hidden dimension
            node_features: (int) - initial node feature dimension
            out_features: (int) - output dimension of the predictor MLP
            norm: (str) - normalisation layer, forwarded to EGNNLayer
                (layer/batch)
            activation: (str) - non-linearity within MLPs, forwarded to
                EGNNLayer (swish/relu)
            aggr: (str) - message aggregation function, forwarded to
                EGNNLayer (sum/mean/max)
            pool: (str) - global pooling function: "add" (alias "sum"),
                "mean", or "none" to skip pooling and predict per node
            residual: (bool) - whether to use residual connections
            RFF_dim: forwarded unchanged to EGNNLayer
                (presumably a random-Fourier-feature size - TODO confirm)
            RFF_sigma: forwarded unchanged to EGNNLayer
                (presumably a random-Fourier-feature scale - TODO confirm)
            return_pos: accepted for interface compatibility; not used by
                this class
            **kwargs: accepted and ignored

        Raises:
            KeyError: if ``pool`` is not one of the supported names.
        """
        super().__init__()
        # Name of the network
        self.name = "EGNN"

        # Linear embedding of the initial node features
        self.emb_in = nn.Linear(node_features, hidden_features)

        # Edge distance computation under periodic boundary conditions
        self.make_dist = PBCConvLayer()

        # Stack of GNN layers
        self.convs = torch.nn.ModuleList([
            EGNNLayer(hidden_features, activation, norm, aggr, RFF_dim, RFF_sigma)
            for _ in range(depth)
        ])

        # Global pooling/readout function ("sum" is an alias of "add")
        pool_fns = {
            "mean": tg.nn.global_mean_pool,
            "add": tg.nn.global_add_pool,
            "sum": tg.nn.global_add_pool,
            "none": None,
        }
        self.pool = pool_fns[pool]

        # Predictor MLP
        self.pred = torch.nn.Sequential(
            torch.nn.Linear(hidden_features, hidden_features),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_features, out_features)
        )
        self.residual = residual

    def forward(self, batch):
        """Predict energy and forces for a batch of graphs.

        Args:
            batch: graph batch exposing ``x``, ``pos``, ``edge_index``,
                ``cell_offset``, ``unit_cell`` and ``batch``. ``batch.pos``
                is replaced in place by a detached copy that requires grad.

        Returns:
            Tuple ``(energy, force)``. ``energy`` is the predictor output
            (one row per graph when pooling is enabled, one row per node
            otherwise, ``out_features`` columns). ``force`` is the negative
            gradient of the summed energy with respect to ``batch.pos`` and
            has the same shape as ``batch.pos``.
        """
        # Forces are d(energy)/d(pos), so autograd must be recording even
        # when the caller runs the model inside torch.no_grad().
        with torch.enable_grad():
            # Fresh leaf tensor for positions (replaces the deprecated
            # torch.autograd.Variable wrapper).
            batch.pos = batch.pos.detach().requires_grad_(True)

            h = self.emb_in(batch.x)  # (n, node_features) -> (n, d)
            distances = self.make_dist(
                batch.pos, batch.edge_index, batch.cell_offset, batch.unit_cell[:3]
            )

            for conv in self.convs:
                # Message passing layer
                h_update = conv(h, batch.edge_index, distances)

                # Update node features (n, d) -> (n, d)
                h = h + h_update if self.residual else h_update

            out = h if self.pool is None else self.pool(h, batch.batch)
            energy = self.pred(out)

            # NOTE: this will not be the final implementation since there
            # will be pos1, pos2, etc., but it is fine for now.
            # grad_outputs of ones differentiates the sum of all energy
            # entries; create_graph keeps the result differentiable so a
            # force loss can be backpropagated.
            force = -1.0 * torch.autograd.grad(
                energy,
                batch.pos,
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
                retain_graph=True
            )[0]

        return energy, force

class PBCConvLayer(nn.Module):
    """Parameter-free module computing edge lengths under periodic boundary
    conditions (PBC)."""

    def __init__(self):
        super().__init__()

    def forward(self, pos, edge_index, offsets, cell_vectors):
        """Return the length of every edge after applying PBC image shifts.

        Args:
            pos: node positions, indexed by the node ids in ``edge_index``
                (documented by the caller as (N, 3)).
            edge_index: edge endpoints (2, E); row 0 and row 1 index ``pos``.
            offsets: per-edge periodic image offsets (E, 3), e.g. -1, 0, 1.
                May be an integer tensor; it is cast to the dtype of
                ``cell_vectors`` before the matrix product.
            cell_vectors: cell vectors defining the unit cell (3, 3).

        Returns:
            Tensor of shape (E,) with
            ``||pos[edge_index[1]] - offsets @ cell_vectors - pos[edge_index[0]]||``.
        """
        # torch.matmul does not promote dtypes, so integer image offsets
        # would raise against a floating-point cell; cast them first.
        pbc_shift = torch.matmul(offsets.to(cell_vectors.dtype), cell_vectors)

        # Move the second endpoint of each edge into the correct periodic image
        shifted_endpoint = pos[edge_index[1]] - pbc_shift

        # Euclidean length of each edge vector
        return torch.linalg.vector_norm(shifted_endpoint - pos[edge_index[0]], dim=-1)
Editor is loading...
Leave a Comment