Untitled
unknown
plain_text
a month ago
2.6 kB
3
Indexable
# Training script: an encoder + distillation network are trained (MSE) to
# reproduce, from the encoder's features, the outputs of several frozen random
# networks applied to the raw images. The random networks and the distillation
# network are re-initialised every `random_reset_interval` epochs.
import torch
import torch.nn as nn
import torch.optim as optim
from dataset import load_dataset
from models.JERA.encoder import Encoder
from models.JERA.distillNet import DistillNet
from models.JERA.randomNet import RandomNet

path = '../data/tiny-imagenet-200/train'
val_path = '../data/tiny-imagenet-200/val'
image_size = 64
batch_size = 100
hidden_dim = 128

# Load dataset
dataset, dataloader = load_dataset(image_size, batch_size, path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize networks
encoder = Encoder(image_size, hidden_dim).to(device)
distill_net = DistillNet(image_size, hidden_dim).to(device)

# Define multiple random (target) networks
num_random_nets = 3
random_nets = [RandomNet(image_size, hidden_dim).to(device) for _ in range(num_random_nets)]

# Freeze all random_net parameters so they act as fixed target functions.
for random_net in random_nets:
    for param in random_net.parameters():
        param.requires_grad = False
    # requires_grad=False alone does not stop dropout or batch-norm running-stat
    # updates; eval() does (a no-op if RandomNet has neither — not visible here).
    random_net.eval()

# Define a single optimizer for both encoder and distill_net
optimizer = optim.Adam(
    list(encoder.parameters()) + list(distill_net.parameters()),
    lr=0.001,
)

# Define a loss function
loss_fn = nn.MSELoss()


def reset_weights(network):
    """Re-initialise every submodule of ``network`` that defines ``reset_parameters``.

    Walks ``network.modules()`` recursively, so layers nested inside container
    modules (e.g. ``nn.Sequential``) are reset as well. Iterating only
    ``children()`` would silently skip them. The root module itself is not
    reset; only its submodules are.

    Args:
        network: the ``nn.Module`` whose submodules are re-initialised in place.
    """
    for module in network.modules():
        if module is network:
            continue
        reset = getattr(module, 'reset_parameters', None)
        if callable(reset):
            reset()


def _clear_optimizer_state(optim_, network):
    """Drop the optimizer's per-parameter state (e.g. Adam moments) for ``network``.

    Used after ``network`` is re-initialised. Otherwise the moment estimates
    accumulated for the old weights would steer the first updates of the fresh
    weights. Adam lazily re-creates empty state on the next ``step()``.
    """
    for param in network.parameters():
        optim_.state.pop(param, None)


random_reset_interval = 10
num_epochs = 500

encoder.train()
distill_net.train()

# Training Loop
for epoch in range(num_epochs):
    # Accumulate on-device to avoid a host sync (loss.item()) every batch.
    epoch_loss_sum = torch.zeros((), device=device)
    num_batches = 0

    for data in dataloader:
        # Prepare batch
        x = data[0].to(device)

        # Forward pass through encoder, then distillation network
        encoder_output = encoder(x)
        distill_output = distill_net(encoder_output)

        # Targets come from frozen networks: no autograd graph is needed.
        with torch.no_grad():
            random_outputs = [random_net(x) for random_net in random_nets]

        # Loss is the mean of the MSEs against every random network
        loss = sum(
            loss_fn(distill_output, random_output) for random_output in random_outputs
        ) / len(random_nets)

        # Update encoder and distill_net
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.detach()
        num_batches += 1

    # Periodically reset random_nets and distill_net (once per qualifying epoch)
    if (epoch + 1) % random_reset_interval == 0:
        for random_net in random_nets:
            reset_weights(random_net)
        reset_weights(distill_net)
        # The fresh distill_net weights must not inherit stale Adam moments.
        _clear_optimizer_state(optimizer, distill_net)
        print(f"Random and Distill networks reset at epoch {epoch + 1}")

    # Logging: mean loss over the epoch (0.0 if the dataloader yielded nothing)
    epoch_loss = (epoch_loss_sum / num_batches).item() if num_batches else 0.0
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")
Editor is loading...
Leave a Comment