Untitled
unknown
python
a year ago
2.0 kB
19
Indexable
def forward(self, batch):
    """Predict energies and forces for a batch of graphs.

    Forces are the negative gradient of the predicted energy with respect
    to the node positions, so positions are marked as requiring grad before
    any distance is computed.

    Args:
        batch: Graph batch exposing ``x``, ``pos``, ``edge_index``,
            ``cell_offset``, ``unit_cell`` and ``batch``.
            NOTE: ``batch.pos`` is replaced in place by a detached copy
            that requires grad.

    Returns:
        Tuple ``(energy, force)``. ``energy`` is the output of
        ``self.pred`` (pooled with ``self.pool`` when it is set, per node
        otherwise); ``force`` has the same shape as ``batch.pos``.
    """
    # Forces need autograd even when the caller runs under
    # ``torch.no_grad()`` (e.g. evaluation); otherwise ``energy`` would not
    # be connected to ``pos`` and ``torch.autograd.grad`` would raise.
    with torch.enable_grad():
        h = self.emb_in(batch.x)  # (n,) -> (n, d)

        # Same as the legacy ``torch.autograd.Variable(pos, requires_grad=True)``:
        # a detached leaf tensor that tracks gradients.
        batch.pos = batch.pos.detach().requires_grad_(True)

        # NOTE(review): ``unit_cell[:3]`` keeps only the first three rows,
        # i.e. one cell; confirm this is intended for multi-structure batches.
        distances = self.make_dist(
            batch.pos, batch.edge_index, batch.cell_offset, batch.unit_cell[:3]
        )

        for conv in self.convs:
            # Message passing layer, (n, d) -> (n, d)
            h_update = conv(h, batch.edge_index, distances)
            h = h + h_update if self.residual else h_update

        out = h if self.pool is None else self.pool(h, batch.batch)
        energy = self.pred(out)  # (batch_size, out_features) when pooled

        # TODO: this assumes a single position tensor; it will need rework
        # once the model takes several position inputs (pos1, pos2, ...).
        # create_graph=True keeps the force differentiable so a force loss
        # can be back-propagated into the model parameters.
        force = -1.0 * torch.autograd.grad(
            energy,
            batch.pos,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
            retain_graph=True,
        )[0]
    return energy, force
class PBCConvLayer(nn.Module):
    """Compute edge lengths under periodic boundary conditions (PBC).

    Despite the name, this module defines no parameters; it only maps node
    positions and per-edge cell offsets to edge distances.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pos, edge_index, offsets, cell_vectors):
        """Return the PBC-corrected length of every edge.

        Args:
            pos: Node positions, shape (N, 3).
            edge_index: Edge indices, shape (2, E). The node indexed by row 1
                is shifted by its cell offset; the distance is measured from
                the node indexed by row 0.
            offsets: Per-edge cell shifts, shape (E, 3), values like -1, 0, 1.
                Integer tensors are accepted and cast to the dtype of
                ``cell_vectors``.
            cell_vectors: Unit-cell vectors as rows, shape (3, 3).

        Returns:
            Tensor of shape (E,) with the corrected distance of each edge.
        """
        # ``torch.matmul`` does not promote dtypes, so integer offsets must be
        # cast before multiplying with a floating-point cell.
        pbc_adjustments = torch.matmul(offsets.to(cell_vectors.dtype), cell_vectors)
        # Subtract the lattice translation from the moved node's position; this
        # sign convention must match how ``offsets`` were generated.
        corrected = pos[edge_index[1]] - pbc_adjustments
        distances = torch.linalg.vector_norm(corrected - pos[edge_index[0]], dim=-1)
        return distances
Leave a Comment