# Paste-site metadata (not part of the script): Untitled | unknown | plain_text | 10 months ago | 2.6 kB | 5 | Indexable
import torch
import torch.nn as nn
import torch.optim as optim
from dataset import load_dataset
from models.JERA.encoder import Encoder
from models.JERA.distillNet import DistillNet
from models.JERA.randomNet import RandomNet
# --- Data locations and hyper-parameters ---------------------------------
path = '../data/tiny-imagenet-200/train'
val_path = '../data/tiny-imagenet-200/val'
image_size = 64
batch_size = 100
hidden_dim = 128

# --- Data ----------------------------------------------------------------
# load_dataset is project-local; it is called with the image size, batch
# size and training path and returns a (dataset, dataloader) pair.
dataset, dataloader = load_dataset(image_size, batch_size, path)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# --- Trainable networks --------------------------------------------------
encoder = Encoder(image_size, hidden_dim).to(device)
distill_net = DistillNet(image_size, hidden_dim).to(device)

# --- Frozen random target networks ---------------------------------------
# These networks are never optimized, so gradients are disabled on every
# one of their parameters as soon as each network is built.
num_random_nets = 3
random_nets = []
for _ in range(num_random_nets):
    frozen_net = RandomNet(image_size, hidden_dim).to(device)
    frozen_net.requires_grad_(False)
    random_nets.append(frozen_net)

# --- Optimisation --------------------------------------------------------
# One Adam optimizer updates the encoder and the distillation network
# together.
trainable_params = [*encoder.parameters(), *distill_net.parameters()]
optimizer = optim.Adam(trainable_params, lr=0.001)

# Mean-squared error between the distilled output and each random target.
loss_fn = nn.MSELoss()
# Function to reset weights of a network
def reset_weights(network: nn.Module) -> None:
    """Re-initialise every resettable layer inside *network*, in place.

    Walks the whole module tree (``network.modules()``), not only the direct
    children, so layers nested inside containers such as ``nn.Sequential``
    are reset as well. Any module exposing a ``reset_parameters`` method has
    that method called; modules without one are left untouched.

    Args:
        network: The module whose layers should be re-initialised.
    """
    # modules() yields the network itself plus all descendants at any depth;
    # children() would stop at the first level and skip nested layers.
    for layer in network.modules():
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()
# Number of epochs between re-initialisations of the random networks and the
# distillation network, and the total number of training epochs.
random_reset_interval = 10
num_epochs = 500

# Training loop.
# Each step trains the encoder and the distillation network so that
# distill_net(encoder(x)) matches, in the MSE sense, the outputs of the frozen
# random networks on the same batch x. Every `random_reset_interval` epochs the
# random networks and the distillation network are re-initialised while the
# encoder keeps its weights.
for epoch in range(num_epochs):
    for data in dataloader:
        # Prepare batch; only the first element of each batch is used.
        x = data[0].to(device)

        # Forward pass through encoder, then the distillation network.
        encoder_output = encoder(x)
        distill_output = distill_net(encoder_output)

        # The random networks only provide regression targets and are never
        # optimized, so their forward passes run without autograd tracking.
        with torch.no_grad():
            random_outputs = [random_net(x) for random_net in random_nets]

        # Loss is the average MSE against all random networks.
        loss = sum(
            loss_fn(distill_output, random_output)
            for random_output in random_outputs
        ) / num_random_nets

        # Update encoder and distill_net.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Periodically reset random_nets and distill_net (once per epoch boundary,
    # not per batch: the condition depends on the epoch only).
    if (epoch + 1) % random_reset_interval == 0:
        for random_net in random_nets:
            reset_weights(random_net)
        reset_weights(distill_net)
        # The optimizer still holds running statistics computed for the
        # discarded distill_net weights; drop them so the freshly initialised
        # parameters start from a clean optimizer state. The optimizer
        # re-creates the per-parameter state lazily on the next step.
        for param in distill_net.parameters():
            optimizer.state.pop(param, None)
        print(f"Random and Distill networks reset at epoch {epoch + 1}")

    # Logging: reports the loss of the last batch of the epoch.
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}")
# Paste-site page residue (not part of the script): "Editor is loading..." / "Leave a Comment"