Untitled

 avatar
unknown
python
a year ago
6.9 kB
5
Indexable
#!/usr/bin/env python
# coding: utf-8
import ase.db.sqlite
import ase.io.trajectory
import numpy as np
import torch
from pymatgen.io.ase import AseAtomsAdaptor
from torch_geometric.data import Data
from torch_geometric.data import Dataset

class EGNNClass(Dataset):
    """Sliding-window dataset over a sequence of frames for EGNN training.

    Sample ``idx`` is built from frames ``atoms[idx : idx + N]``. Frame ``idx``
    supplies the graph and per-atom inputs. Unless ``essential`` is set, forces
    and energies are added for every frame of the window.

    Args:
        N: number of consecutive frames per sample (window length).
        atoms: indexable sequence of frames. Each item must provide the
            accessors used below (``get_positions``, ``numbers``, ``get_cell``,
            ``get_forces``, ``get_potential_energy``, ``get_velocities``,
            ``get_masses``, ``get_global_number_of_atoms``). Presumably a list
            of ``ase.Atoms`` -- confirm with the caller.
        cutoff: radius passed to pymatgen's neighbor search.
        unw: stored as ``self.unw``; not used inside this class.
        fetch_pos: also store positions of later window frames as ``pos<k>``.
        fetch_vel: also store velocities of later window frames as ``vel<k>``.
        fetch_w: stored as ``self.fetch_w``; not used inside this class.
        potim: scalar repeated once per atom under the ``potim`` key
            (presumably the MD time step -- confirm units with the caller).
        max_neigh: maximum number of nearest neighbors kept per center atom.
        essential: return only graph/structure properties (used by the
            training pipeline).
        validation: skip the periodic neighbor-list refresh inside the window
            (used when creating the training dataset).
        supplement_every: interval, in frames, of that neighbor-list refresh.
            Defaults to 5, the value that used to be hard-coded.
    """

    def __init__(self, N: int, atoms: list, cutoff: float, unw: list = None,
                 fetch_pos: bool = False,
                 fetch_vel: bool = False,
                 fetch_w: bool = False,
                 potim: float = 0.5,
                 max_neigh: int = 4000,
                 essential: bool = False,
                 validation: bool = False,
                 supplement_every: int = 5):
        super().__init__()
        if supplement_every < 1:
            raise ValueError(
                "supplement_every must be a positive integer, got {}".format(
                    supplement_every
                )
            )
        self.potim = potim
        self.atoms = atoms
        self.cutoff = cutoff
        self.N = N
        self.fetch_pos = fetch_pos
        self.fetch_vel = fetch_vel
        self.fetch_w = fetch_w
        self.unw = unw
        self.max_neigh = max_neigh
        self.essential = essential
        self.validation = validation
        self.supplement_every = supplement_every

    def get(self, idx: int):
        """Return sample ``idx`` as a ``Data`` object, tagged with ``_idx``."""
        properties = self.get_properties(idx)
        properties["_idx"] = np.array([idx], dtype=np.int32)
        return torchify_dict(properties)

    def len(self):
        """Number of complete ``N``-frame windows (never negative)."""
        return max(0, len(self.atoms) - (self.N - 1))

    def get_properties(self, idx: int):
        """Assemble the numpy properties of sample ``idx``.

        Frame ``idx`` always supplies ``edge_index``, ``cell_offset``, ``pos``,
        ``Z``, ``unit_cell`` and the one-hot ``x``. In ``essential`` mode only
        those are returned.

        Otherwise the sample also carries ``n_atoms``, ``vel``, ``masses`` and
        ``potim`` for frame ``idx``. For each window position ``k`` in
        ``0..N-1`` (frame ``idx + k``) it adds ``forces<k>`` and ``energy<k>``.
        For ``k > 0`` it adds ``pos<k>`` when ``fetch_pos`` is set and
        ``vel<k>`` when ``fetch_vel`` is set.

        Unless ``validation`` is set, the neighbor list of frame ``idx + k`` is
        recomputed whenever ``k + 1`` is a multiple of ``supplement_every``. It
        is stored as ``edge_index<m>`` / ``cell_offset<m>`` with
        ``m = (k // supplement_every) * supplement_every``.
        """
        frame = self.atoms[idx]
        edge_index, _, offsets = self._get_neighbors_pymatgen(frame)
        outputs = {}

        if self.essential:
            outputs['edge_index'] = edge_index.astype(np.int64)
            outputs['pos'] = frame.get_positions().astype(np.float32)
            outputs['Z'] = frame.numbers.astype(np.int64)
            outputs['cell_offset'] = offsets.astype(np.float32)
            outputs['unit_cell'] = frame.get_cell().astype(np.float32)
            outputs['x'] = atoms_to_onehot_np(frame)
            return outputs

        outputs['pos'] = frame.get_positions().astype(np.float32)
        outputs['x'] = atoms_to_onehot_np(frame)
        outputs['Z'] = frame.numbers.astype(np.int64)
        outputs['n_atoms'] = np.array(
            [frame.get_global_number_of_atoms()], dtype=np.int64
        )
        outputs['vel'] = frame.get_velocities().astype(np.float32)
        # Changed from float64 to float32 to match the other float properties.
        outputs['unit_cell'] = frame.get_cell().astype(np.float32)
        outputs['masses'] = frame.get_masses().astype(np.float32)
        outputs['potim'] = np.full(
            outputs['masses'].size, self.potim, dtype=np.float32
        )
        outputs['edge_index'] = edge_index.astype(np.int64)
        outputs['cell_offset'] = offsets.astype(np.float32)

        every = self.supplement_every
        for count, i in enumerate(range(idx, idx + self.N)):
            step = self.atoms[i]

            if not self.validation and (count + 1) % every == 0:
                # Refresh the graph from the frame reached at this step (the
                # original code re-used frame ``idx`` here, duplicating
                # ``edge_index``). The key suffix is kept unchanged.
                step_edges, _, step_offsets = self._get_neighbors_pymatgen(step)
                tag = str((count // every) * every)
                outputs['edge_index' + tag] = step_edges.astype(np.int64)
                outputs['cell_offset' + tag] = step_offsets.astype(np.float32)

            outputs['forces' + str(count)] = step.get_forces().astype(np.float32)
            outputs['energy' + str(count)] = np.array(
                [step.get_potential_energy()], dtype=np.float32
            )

            if i > idx:
                # Stored as float32, consistent with ``pos`` / ``vel`` of frame idx.
                if self.fetch_pos:
                    outputs['pos' + str(count)] = (
                        step.get_positions().astype(np.float32)
                    )
                if self.fetch_vel:
                    outputs['vel' + str(count)] = (
                        step.get_velocities().astype(np.float32)
                    )
        return outputs

    def _get_neighbors_pymatgen(self, atoms):
        """Find neighbor pairs within ``self.cutoff`` using pymatgen.

        At most ``self.max_neigh`` nearest neighbors are kept per center atom.
        Pairs closer than 1e-8 are dropped.

        Returns:
            edge_index: ``(2, E)`` array; row 0 holds neighbor indices and
                row 1 center indices.
            edge_distances: ``(E,)`` pair distances.
            cell_offsets: per-edge periodic image offsets exactly as returned by
                pymatgen's ``get_neighbor_list`` (presumably ``(E, 3)`` image
                vectors -- confirm).
        """
        struct = AseAtomsAdaptor.get_structure(atoms)
        centers, neighbors, offsets, distances = struct.get_neighbor_list(
            r=self.cutoff, numerical_tol=0, exclude_self=True
        )

        # Sort edges by center atom, then by distance within each center.
        # lexsort treats the LAST key as primary. This replaces a per-atom
        # boolean scan (O(n_atoms * n_edges)) and also works when there are
        # no atoms or no edges.
        order = np.lexsort((distances, centers))
        if self.max_neigh is not None:
            sorted_centers = centers[order]
            # Rank of each sorted edge inside its center's group.
            group_start = np.searchsorted(sorted_centers, sorted_centers, side="left")
            rank = np.arange(order.size) - group_start
            order = order[rank < self.max_neigh]

        edge_index = np.vstack((neighbors[order], centers[order]))
        edge_distances = distances[order]
        cell_offsets = offsets[order]

        # Drop coincident pairs (distance below 1e-8).
        keep = edge_distances >= 1e-8
        return edge_index[:, keep], edge_distances[keep], cell_offsets[keep]
        
class CellDataError(TypeError):
    """Raised by ``torchify_dict`` for a property with an unsupported dtype."""


# numpy dtypes accepted by torchify_dict. torch.from_numpy maps each to the
# matching torch dtype (int32 -> int32, int64 -> int64, float32 -> float32, ...).
_TORCHIFY_DTYPES = (np.int32, np.int64, np.float32, np.float64)


def torchify_dict(data: dict):
    """Convert a dict of numpy arrays into a torch_geometric ``Data`` object.

    Each value is copied into a new tensor with the same dtype (int32, int64,
    float32 or float64). Each key becomes an attribute of the returned
    ``Data``.

    Args:
        data: mapping from property name to an array (array-likes are passed
            through ``np.asarray`` first).

    Returns:
        A ``Data`` instance with one tensor attribute per key.

    Raises:
        CellDataError: if a value's dtype is not one of the four supported.
            The original code referenced this name without defining it, which
            turned the intended error into a ``NameError``.
    """
    final_data = Data()
    for pname, prop in data.items():
        prop = np.asarray(prop)
        if prop.dtype not in _TORCHIFY_DTYPES:
            raise CellDataError(
                "Invalid datatype {} for property {}!".format(prop.dtype, pname)
            )
        # copy() gives a writable, C-contiguous array that the tensor can own.
        # torch.from_numpy rejects read-only or negatively-strided inputs.
        setattr(final_data, pname, torch.from_numpy(prop.copy()))
    return final_data

def atoms_to_onehot_np(atom, species=None):
    """One-hot encode the element of each atom as a float32 matrix.

    Args:
        atom: object exposing an integer ``numbers`` array of atomic numbers
            (e.g. an ``ase.Atoms``).
        species: optional sequence of atomic numbers defining the columns,
            column ``c`` corresponding to ``species[c]``. When omitted, the
            columns are the sorted distinct atomic numbers of *this* structure.
            In that case the width and column meaning can differ between
            structures that contain different elements. Pass ``species`` to get
            a fixed encoding.

    Returns:
        ``(n_atoms, n_types)`` float32 array with a single 1.0 per row.

    Raises:
        ValueError: if ``species`` contains duplicates or lacks an atomic
            number present in ``atom``.
    """
    atomic_numbers = np.asarray(atom.numbers)

    if species is None:
        # Inverse indices map each atom straight to its column; unlike
        # np.vectorize this also works for an empty structure.
        unique_numbers, column_idx = np.unique(atomic_numbers, return_inverse=True)
        num_atom_types = len(unique_numbers)
    else:
        column_of = {int(z): col for col, z in enumerate(species)}
        if len(column_of) != len(species):
            raise ValueError("species contains duplicate atomic numbers")
        try:
            column_idx = np.array(
                [column_of[int(z)] for z in atomic_numbers], dtype=np.intp
            )
        except KeyError as err:
            raise ValueError(
                f"atomic number {err.args[0]} is not listed in species"
            ) from err
        num_atom_types = len(column_of)

    return np.eye(num_atom_types, dtype=np.float32)[column_idx]
Editor is loading...
Leave a Comment