import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import copy
import wandb
from torch_scatter import scatter_add
from torch_geometric.nn import global_add_pool
from torch_geometric.datasets import QM9
from torch_geometric.transforms import RadiusGraph, AddRandomWalkPE, Compose


class EGNNLayer(nn.Module):
    """E(n)-invariant message-passing layer: messages depend only on node
    features and the pairwise Euclidean distance, so the layer is invariant
    to rotations and translations of the input positions."""

    def __init__(self, num_hidden):
        super().__init__()
        self.message_mlp = nn.Sequential(
            nn.Linear(2 * num_hidden + 1, num_hidden), nn.SiLU(),
            nn.Linear(num_hidden, num_hidden), nn.SiLU())
        self.update_mlp = nn.Sequential(
            nn.Linear(2 * num_hidden, num_hidden), nn.SiLU(),
            nn.Linear(num_hidden, num_hidden))
        # Optional soft edge gating (unused in this run, see forward()).
        self.edge_net = nn.Sequential(nn.Linear(num_hidden, 1), nn.Sigmoid())

    def forward(self, x, pos, edge_index):
        send, rec = edge_index
        # Pairwise distance between sender and receiver nodes.
        dist = torch.linalg.norm(pos[send] - pos[rec], dim=1).unsqueeze(1)
        state = torch.cat((x[send], x[rec], dist), dim=1)
        message = self.message_mlp(state)
        # message = self.edge_net(message_pre) * message_pre
        # Distance-weighted sum of incoming messages per receiver node
        # (the "weighted_distance" variant in the run name).
        aggr = scatter_add(dist * message, rec, dim=0)
        update = self.update_mlp(torch.cat((x, aggr), dim=1))
        return update


class EGNN(nn.Module):
    def __init__(self, num_in, num_hidden, num_out, num_layers, pe_dim):
        super().__init__()
        self.embed = nn.Sequential(nn.Linear(num_in + pe_dim, num_hidden))
        self.layers = nn.ModuleList(
            [EGNNLayer(num_hidden) for _ in range(num_layers)])
        self.pre_readout = nn.Sequential(
            nn.Linear(num_hidden, num_hidden), nn.SiLU(),
            nn.Linear(num_hidden, num_hidden))
        self.readout = nn.Sequential(
            nn.Linear(num_hidden, num_hidden), nn.SiLU(),
            nn.Linear(num_hidden, num_out))

    def forward(self, data):
        x, pos, edge_index, batch, rw = (data.x, data.pos, data.edge_index,
                                         data.batch, data.random_walk_pe)
        # Concatenate the random-walk positional encoding to the node features.
        x = torch.cat([x, rw], dim=-1)
        x = self.embed(x)
        for layer in self.layers:
            x = x + layer(x, pos, edge_index)  # residual connection
        x = self.pre_readout(x)
        x = global_add_pool(x, batch)  # sum-pool node features per graph
        out = self.readout(x)
        return torch.squeeze(out)


if __name__ == '__main__':
    wandb.init(project="DL2-EGNN", name="NO_FC_Yes_PE_Yes_weighted_distance")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # transform = RadiusGraph(r=1e6)
    pe_dim = 24
    t_compose = Compose([AddRandomWalkPE(walk_length=pe_dim)])
    dataset = QM9('./data_PE24', pre_transform=t_compose)

    epochs = 1000
    # Split: first 100k graphs for training, next 10k for test, rest for validation.
    n_train, n_test = 100000, 110000
    train_dataset = dataset[:n_train]
    test_dataset = dataset[n_train:n_test]
    val_dataset = dataset[n_test:]
    print("Total number of edges: ",
          train_dataset.data.edge_index.shape[1]
          + val_dataset.data.edge_index.shape[1]
          + test_dataset.data.edge_index.shape[1])

    # dataloaders
    train_loader = DataLoader(train_dataset, batch_size=96, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=96, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=96, shuffle=False)

    # Target statistics for normalization: mean and mean absolute deviation
    # (MAD) of the training targets. graph.y[:, 1] is the alpha target.
    values = [torch.squeeze(graph.y[:, 1]) for graph in train_loader.dataset]
    mean = sum(values) / len(values)
    mad = sum([abs(v - mean) for v in values]) / len(values)
    mean, mad = mean.to(device), mad.to(device)

    model = EGNN(
        num_in=11,
        num_hidden=128,
        num_out=1,
        num_layers=7,
        pe_dim=pe_dim
    ).to(device)

    # L1Loss with reduction='sum' so per-epoch totals can be divided by the
    # dataset size to get a per-graph MAE.
    criterion = torch.nn.L1Loss(reduction='sum')
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-16)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    best_train_mae, best_val_mae, best_model = float('inf'), float('inf'), None

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Number of parameters: {num_params}')

    for _ in tqdm(range(epochs)):
        epoch_mae_train, epoch_mae_val = 0, 0

        model.train()
        for _, batch in enumerate(train_loader):
            optimizer.zero_grad()
            batch = batch.to(device)
            pred = model(batch)
            target = torch.squeeze(batch.y[:, 1])  # batch.y[:, 1] for alpha
            # Train on MAD-normalized targets; report MAE in original units.
            loss = criterion(pred, (target - mean) / mad)
            mae = criterion(pred * mad + mean, target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_mae_train += mae.item()

        model.eval()
        with torch.no_grad():  # no gradients needed during evaluation
            for _, batch in enumerate(val_loader):
                batch = batch.to(device)
                target = torch.squeeze(batch.y[:, 1])  # batch.y[:, 1] for alpha
                pred = model(batch)
                mae = criterion(pred * mad + mean, target)
                epoch_mae_val += mae.item()

        epoch_mae_train /= len(train_loader.dataset)
        epoch_mae_val /= len(val_loader.dataset)

        # Keep a copy of the model with the best validation MAE so far.
        if epoch_mae_val < best_val_mae:
            best_val_mae = epoch_mae_val
            best_model = copy.deepcopy(model)

        scheduler.step()
        wandb.log({
            'Train MAE': epoch_mae_train,
            'Validation MAE': epoch_mae_val
        })

    # Evaluate the best checkpoint (by validation MAE) on the test split.
    test_mae = 0
    best_model.eval()
    with torch.no_grad():
        for _, batch in enumerate(test_loader):
            batch = batch.to(device)
            target = torch.squeeze(batch.y[:, 1])  # batch.y[:, 1] for alpha
            pred = best_model(batch)
            mae = criterion(pred * mad + mean, target)
            test_mae += mae.item()
    test_mae /= len(test_loader.dataset)

    print(f'Test MAE: {test_mae}')
    wandb.log({
        'Test MAE': test_mae,
    })