Untitled

mail@pastecode.io avatar
unknown
plain_text
a year ago
5.7 kB
4
Indexable
Never
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import copy
import wandb
from torch_scatter import scatter_add
from torch_geometric.nn import global_add_pool
from torch_geometric.datasets import QM9
from torch_geometric.transforms import RadiusGraph, AddRandomWalkPE, Compose


class EGNNLayer(nn.Module):
    """One message-passing layer over node features and pairwise distances.

    For every edge (sender -> receiver) a message is computed from the sender
    features, the receiver features and the Euclidean distance between the two
    node positions. Each message is scaled by that distance and summed at the
    receiver. The layer returns the updated features only; no residual
    connection is added here.

    Args:
        num_hidden: width of the node feature vectors (both input and output).
    """

    def __init__(self, num_hidden):
        super().__init__()
        # Input: [x_send | x_rec | dist] -> 2 * num_hidden + 1 features.
        self.message_mlp = nn.Sequential(
            nn.Linear(2 * num_hidden + 1, num_hidden), nn.SiLU(),
            nn.Linear(num_hidden, num_hidden), nn.SiLU(),
        )
        # Input: [x | aggregated messages] -> 2 * num_hidden features.
        self.update_mlp = nn.Sequential(
            nn.Linear(2 * num_hidden, num_hidden), nn.SiLU(),
            nn.Linear(num_hidden, num_hidden),
        )
        # NOTE: edge_net is not used by forward() (the gated-message variant is
        # disabled). It is kept so parameter counts and state_dict keys remain
        # compatible with existing checkpoints.
        self.edge_net = nn.Sequential(nn.Linear(num_hidden, 1), nn.Sigmoid())

    def forward(self, x, pos, edge_index):
        """Compute updated node features.

        Args:
            x: node features, shape (num_nodes, num_hidden).
            pos: node coordinates, shape (num_nodes, D) -- presumably D == 3 for
                QM9; TODO confirm.
            edge_index: tensor of shape (2, num_edges) holding (sender, receiver)
                node indices.

        Returns:
            Tensor of shape (num_nodes, num_hidden).
        """
        send, rec = edge_index
        # Shape (num_edges, 1) so it concatenates with features and broadcasts
        # over the message dimension.
        dist = torch.linalg.norm(pos[send] - pos[rec], dim=1).unsqueeze(1)
        state = torch.cat((x[send], x[rec], dist), dim=1)
        message = self.message_mlp(state)
        # Sum distance-weighted messages into their receivers. The buffer is
        # sized by the number of nodes (not max(rec) + 1), so nodes that
        # receive no messages -- including the highest-indexed ones -- still get
        # a zero row, and the concatenation below always lines up with x.
        aggr = message.new_zeros((x.size(0), message.size(1)))
        aggr.index_add_(0, rec, dist * message)
        return self.update_mlp(torch.cat((x, aggr), dim=1))


class EGNN(nn.Module):
    """Graph-level regressor: linear embedding, residual EGNNLayer stack, sum-pooled readout.

    Args:
        num_in: number of raw node features in ``data.x``.
        num_hidden: hidden width used throughout the network.
        num_out: size of the per-graph output.
        num_layers: number of stacked message-passing layers.
        pe_dim: width of ``data.random_walk_pe``, concatenated onto ``data.x``
            before the embedding.
    """

    def __init__(self, num_in, num_hidden, num_out, num_layers, pe_dim):
        super().__init__()
        in_features = num_in + pe_dim
        self.embed = nn.Sequential(nn.Linear(in_features, num_hidden))
        self.layers = nn.ModuleList(EGNNLayer(num_hidden) for _ in range(num_layers))
        # Node-wise transform applied before pooling.
        self.pre_readout = nn.Sequential(
            nn.Linear(num_hidden, num_hidden),
            nn.SiLU(),
            nn.Linear(num_hidden, num_hidden),
        )
        # Graph-wise head applied after pooling.
        self.readout = nn.Sequential(
            nn.Linear(num_hidden, num_hidden),
            nn.SiLU(),
            nn.Linear(num_hidden, num_out),
        )

    def forward(self, data):
        """Return per-graph predictions with all size-1 dims squeezed away.

        Reads ``x``, ``pos``, ``edge_index``, ``batch`` and ``random_walk_pe``
        from *data*.
        """
        feats, coords, edges, graph_ids, rw_pe = (
            data.x, data.pos, data.edge_index, data.batch, data.random_walk_pe
        )
        h = self.embed(torch.cat([feats, rw_pe], dim=-1))

        for block in self.layers:
            # Residual connection around every message-passing layer.
            h = h + block(h, coords, edges)

        pooled = global_add_pool(self.pre_readout(h), graph_ids)
        return torch.squeeze(self.readout(pooled))


@torch.no_grad()
def _evaluate_mae(model, loader, criterion, mean, mad, device):
    """Return the per-graph mean absolute error of *model* over *loader*.

    Predictions are de-standardised with ``pred * mad + mean`` and compared
    against column 1 of ``batch.y``. *criterion* is expected to use
    ``reduction='sum'``, so dividing by the dataset length yields a per-graph
    average. Runs under ``torch.no_grad()`` so no autograd graph is built.
    """
    model.eval()
    total = 0.0
    for batch in loader:
        batch = batch.to(device)
        target = torch.squeeze(batch.y[:, 1])  # batch.y[:, 1] for alpha
        pred = model(batch)
        total += criterion(pred * mad + mean, target).item()
    return total / len(loader.dataset)


def _train_epoch(model, loader, optimizer, criterion, mean, mad, device):
    """Run one optimisation pass over *loader* and return the per-graph train MAE.

    The loss is computed on the standardised target ``(y - mean) / mad``. The
    returned MAE is in the target's original units. Gradients are clipped to a
    total norm of 1.0 before each optimizer step.
    """
    model.train()
    total = 0.0
    for batch in loader:
        optimizer.zero_grad()
        batch = batch.to(device)
        pred = model(batch)
        target = torch.squeeze(batch.y[:, 1])  # batch.y[:, 1] for alpha
        loss = criterion(pred, (target - mean) / mad)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        # Metric only: detach so it does not extend the autograd graph.
        total += criterion(pred.detach() * mad + mean, target).item()
    return total / len(loader.dataset)


if __name__ == '__main__':
    wandb.init(project="DL2-EGNN", name="NO_FC_Yes_PE_Yes_weighted_distance")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Edges come from the dataset's own edge_index (no RadiusGraph transform).
    pe_dim = 24
    # pre_transform runs when the dataset is processed, so the random-walk PEs
    # are stored with the processed data under ./data_PE24 (name tied to pe_dim).
    t_compose = Compose([AddRandomWalkPE(walk_length=pe_dim)])
    dataset = QM9('./data_PE24', pre_transform=t_compose)
    epochs = 1000

    # Contiguous split: [0, 100k) train, [100k, 110k) test, [110k, end) validation.
    train_end, test_end = 100000, 110000
    train_dataset = dataset[:train_end]
    test_dataset = dataset[train_end:test_end]
    val_dataset = dataset[test_end:]

    # Slices of an InMemoryDataset share the full collated storage, so `.data`
    # on each split describes the whole dataset; summing the three splits would
    # triple-count. Count once on the unsliced dataset instead.
    print("Total number of edges: ", dataset.data.edge_index.shape[1])

    # dataloaders
    train_loader = DataLoader(train_dataset, batch_size=96, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=96, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=96, shuffle=False)

    # Standardisation statistics over the training targets (graph.y[:, 1], alpha):
    # mean and mean absolute deviation, computed vectorised.
    train_targets = torch.stack([torch.squeeze(graph.y[:, 1]) for graph in train_dataset])
    mean = train_targets.mean()
    mad = (train_targets - mean).abs().mean()
    mean, mad = mean.to(device), mad.to(device)

    model = EGNN(
        num_in=11,
        num_hidden=128,
        num_out=1,
        num_layers=7,
        pe_dim=pe_dim
    ).to(device)

    criterion = torch.nn.L1Loss(reduction='sum')
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-16)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    best_val_mae, best_model = float('inf'), None

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Number of parameters: {num_params}')

    for _ in tqdm(range(epochs)):
        epoch_mae_train = _train_epoch(model, train_loader, optimizer, criterion, mean, mad, device)
        epoch_mae_val = _evaluate_mae(model, val_loader, criterion, mean, mad, device)

        if epoch_mae_val < best_val_mae:
            best_val_mae = epoch_mae_val
            best_model = copy.deepcopy(model)

        scheduler.step()

        wandb.log({
            'Train MAE': epoch_mae_train,
            'Validation MAE': epoch_mae_val
        })

    # If validation MAE never improved (e.g. it was NaN every epoch), best_model
    # is still None; evaluate the final model instead of crashing.
    if best_model is None:
        best_model = model

    test_mae = _evaluate_mae(best_model, test_loader, criterion, mean, mad, device)
    print(f'Test MAE: {test_mae}')

    wandb.log({
        'Test MAE': test_mae,
    })