Untitled

 avatar
unknown
plain_text
a month ago
2.6 kB
3
Indexable
import torch
import torch.nn as nn
import torch.optim as optim
from dataset import load_dataset
from models.JERA.encoder import Encoder
from models.JERA.distillNet import DistillNet
from models.JERA.randomNet import RandomNet

# --- Configuration -----------------------------------------------------------
path = '../data/tiny-imagenet-200/train'    # training split
val_path = '../data/tiny-imagenet-200/val'  # validation split (not used below)
image_size = 64
batch_size = 100
hidden_dim = 128

# --- Data and device ---------------------------------------------------------
dataset, dataloader = load_dataset(image_size, batch_size, path)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Networks ----------------------------------------------------------------
# Trainable pair: the encoder, and the distillation head applied to its output.
encoder = Encoder(image_size, hidden_dim).to(device)
distill_net = DistillNet(image_size, hidden_dim).to(device)

# Target networks: randomly initialised and never optimised.
num_random_nets = 3
random_nets = [
    RandomNet(image_size, hidden_dim).to(device)
    for _ in range(num_random_nets)
]

# Turn off gradients for every target parameter.
for random_net in random_nets:
    random_net.requires_grad_(False)

# --- Optimisation ------------------------------------------------------------
# A single Adam instance updates both the encoder and the distillation head.
optimizer = optim.Adam(
    [*encoder.parameters(), *distill_net.parameters()], lr=0.001
)

# Mean-squared-error criterion used against each random target.
loss_fn = nn.MSELoss()

# Function to reset weights of a network
def reset_weights(network):
    """Re-initialise every submodule of *network* that can reset itself.

    Walks the whole module tree via ``network.modules()``. That tree includes
    *network* itself and layers nested inside containers such as
    ``nn.Sequential``. Iterating only the direct children would silently skip
    nested layers, and would skip a bare layer passed in directly.

    Standard ``torch.nn`` layers re-initialise their existing Parameter
    tensors in place. Any ``requires_grad`` flags set earlier (e.g. on frozen
    networks) are therefore kept.

    Args:
        network: the ``nn.Module`` to re-initialise.
    """
    for module in network.modules():
        reset = getattr(module, 'reset_parameters', None)
        if callable(reset):
            reset()


random_reset_interval = 10
num_epochs = 500
# Training loop.
# encoder + distill_net are trained so that distill_net(encoder(x)) matches the
# output of every frozen random network on the same batch.
# Every `random_reset_interval` epochs the random targets and distill_net are
# re-initialised; the encoder keeps its weights.
for epoch in range(num_epochs):
    # Accumulate on-device so the host syncs once per epoch, not per batch.
    running_loss = torch.zeros((), device=device)
    num_batches = 0

    for data in dataloader:
        # Prepare batch. Only data[0] is used; presumably the images, with any
        # labels ignored (TODO confirm against load_dataset).
        x = data[0].to(device)

        # Forward pass: encoder, then the distillation head.
        encoder_output = encoder(x)
        distill_output = distill_net(encoder_output)

        # Targets from the frozen random networks; no autograd graph needed.
        with torch.no_grad():
            random_outputs = [random_net(x) for random_net in random_nets]

        # Loss is the mean of the per-target MSE losses.
        loss = sum(loss_fn(distill_output, random_output)
                   for random_output in random_outputs) / len(random_outputs)

        # Update encoder and distill_net.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.detach()
        num_batches += 1

    # Periodically reset random_nets and distill_net.
    if (epoch + 1) % random_reset_interval == 0:
        for random_net in random_nets:
            reset_weights(random_net)
        reset_weights(distill_net)
        # Discard Adam's moment estimates for the old distill_net weights so
        # stale momentum does not steer the freshly initialised parameters.
        # They are re-created on the next optimizer.step().
        for param in distill_net.parameters():
            optimizer.state.pop(param, None)
        print(f"Random and Distill networks reset at epoch {epoch + 1}")

    # Logging: mean loss over all batches of this epoch (NaN if none ran).
    mean_loss = running_loss.item() / num_batches if num_batches else float('nan')
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {mean_loss:.4f}")
Editor is loading...
Leave a Comment