def main():
    """Adversarially fine-tune a generator, validating and checkpointing each epoch.

    Loads generator weights from ``./Ich_generator_PSNR.pth``, then alternates
    ``train_model`` and ``validate_model`` for ``epochs`` epochs. Both learning
    rates are decayed with ``StepLR``, and the generator's state dict is saved
    after every epoch.
    """
    batch_size = 30
    best_loss_train = 99999.0
    best_psnr_train = 0.0
    best_loss_val = 99999.0
    best_psnr_val = 0.0
    epochs = 9
    training_path = ("C:/Users/Samuel/PycharmProjects/Super_ressolution/dataset4")
    testing_path = ("C:/Users/Samuel/PycharmProjects/Super_ressolution/datik")
    # in_ress/out_ress are presumably the LR/HR resolutions (64 -> 256, i.e. a
    # 4x upscale) - confirm against Process_dataset.
    loader = Process_dataset(in_ress=64, out_ress=64 * 4, training_path=training_path,
                             testing_path=testing_path, aug_count=4)
    # Ceiling division: the number of batches needed to cover every sample once.
    # ``(n + batch_size) // batch_size`` would over-count by one whenever n is
    # an exact multiple of batch_size.
    batch_count_train = (loader.get_training_count() + batch_size - 1) // batch_size
    batch_count_val = (loader.get_testing_count() + batch_size - 1) // batch_size
    loss_chart_train = CreateGraph(batch_count_train, "Generator and discriminator loss")
    generator = Generator().to("cuda")
    PATH = './Ich_generator_PSNR.pth'
    # Initialise the generator from the weights saved at PATH.
    generator.load_state_dict(torch.load(PATH))
    discriminator = Discriminator().to("cuda")
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.9, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.9, 0.999))
    # The schedulers are stepped once per epoch below, so each learning rate is
    # multiplied by 0.1 every ``epochs // 2`` epochs.
    d_scheduler = lr_scheduler.StepLR(d_optimizer, epochs // 2, 0.1)
    g_scheduler = lr_scheduler.StepLR(g_optimizer, epochs // 2, 0.1)
    psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion = define_loss()
    # Mixed-precision gradient scaler handed to train_model.
    scaler = amp.GradScaler()
    for epoch in range(epochs):
        print(best_loss_train, best_psnr_train, "TRAIN")
        best_loss_train, best_psnr_train = train_model(generator, discriminator, g_optimizer, d_optimizer,
                                                       pixel_criterion, content_criterion, adversarial_criterion,
                                                       loader, batch_size, best_loss_train, best_psnr_train, scaler,
                                                       loss_chart_train, batch_count_train, epoch)
        # Step the LR schedulers after the epoch's optimizer steps.
        d_scheduler.step()
        g_scheduler.step()
        loss_chart_train.count(epoch)
        print(best_loss_val, best_psnr_val, "VAL")
        best_loss_val, best_psnr_val = validate_model(generator, pixel_criterion, content_criterion, loader,
                                                      batch_size, best_loss_val, best_psnr_val,
                                                      batch_count_val, epoch)
        # Unconditional per-epoch snapshot (1-based epoch number in the file name).
        torch.save(generator.state_dict(), './Generator_SR_epoch_GIT{}.pth'.format(epoch + 1))
        print('Model saved at {} epoch'.format(epoch + 1))
def define_loss(device="cuda") -> "tuple[nn.MSELoss, nn.MSELoss, ContentLoss, nn.BCEWithLogitsLoss]":
    """Create the loss criteria used for training and validation.

    The return annotation is a string so it is not evaluated at definition
    time (a bracketed list is not a valid type hint).

    Args:
        device: Device every criterion is moved to (default ``"cuda"``).

    Returns:
        ``(psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion)``:
        two independent ``nn.MSELoss`` instances, a ``ContentLoss`` (defined
        elsewhere in the project), and an ``nn.BCEWithLogitsLoss``. The last one
        takes raw logits, not probabilities.
    """
    psnr_criterion = nn.MSELoss().to(device)
    pixel_criterion = nn.MSELoss().to(device)
    content_criterion = ContentLoss().to(device)
    adversarial_criterion = nn.BCEWithLogitsLoss().to(device)
    return psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion
def train_model(generator,
                discriminator,
                g_optimizer,
                d_optimizer,
                pixel_loss,
                content_loss,
                adversarial_loss,
                loader,
                batch_size,
                best_loss,
                best_psnr,
                scaler,
                loss_chart,
                batch_count,
                epoch,
                device="cuda"):
    """Run one epoch of adversarial training for the generator/discriminator pair.

    Each of the ``batch_count`` iterations does two updates:

    * Discriminator: real HR images are labelled 1 and detached generator
      outputs are labelled 0.
    * Generator: loss is ``pixel + content + 0.001 * adversarial``.

    Losses are passed through ``scaler`` so that the parts computed under
    ``amp.autocast`` are gradient-scaled. Without scaling, small float16
    gradients can underflow to zero.

    Args:
        generator, discriminator: Models, expected to already live on ``device``.
        g_optimizer, d_optimizer: Optimizers for the generator / discriminator.
        pixel_loss, content_loss: Criteria comparing the generator output to ``hr``.
        adversarial_loss: Criterion applied to discriminator outputs against
            ``(batch, 1)`` target tensors of ones/zeros.
        loader: Object providing ``get_training_batch(batch_size) -> (lr, hr)``.
        batch_size: Number of samples requested per batch.
        best_loss, best_psnr: Not computed here; returned unchanged.
        scaler: AMP gradient scaler. A single scaler drives both optimizers,
            and ``update()`` is called once per iteration after both steps.
        loss_chart: Object whose ``num_for_D`` / ``num_for_G`` attributes are
            incremented by each batch's discriminator / generator loss.
        batch_count: Number of batches to draw this epoch.
        epoch: Current epoch index (unused here).
        device: Device tensors are moved to (default ``"cuda"``).

    Returns:
        ``(best_loss, best_psnr)``, unchanged.
    """
    discriminator.train()
    generator.train()
    for _ in range(batch_count):
        lr, hr = loader.get_training_batch(batch_size)
        hr = hr.to(device)
        lr = lr.to(device)
        real_label = torch.ones((lr.size(0), 1), device=device)
        fake_label = torch.zeros((lr.size(0), 1), device=device)

        sr = generator(lr)

        # ---- Discriminator update -------------------------------------------
        d_optimizer.zero_grad()
        with amp.autocast():
            hr_output = discriminator(hr)
            d_loss_hr = adversarial_loss(hr_output, real_label)
        scaler.scale(d_loss_hr).backward()
        with amp.autocast():
            # Detach so this loss does not backpropagate into the generator.
            sr_output = discriminator(sr.detach())
            d_loss_sr = adversarial_loss(sr_output, fake_label)
        scaler.scale(d_loss_sr).backward()
        # step() unscales the gradients and skips the update if any are inf/NaN.
        scaler.step(d_optimizer)
        d_loss = d_loss_hr + d_loss_sr
        loss_chart.num_for_D += float(d_loss.item())

        # ---- Generator update -----------------------------------------------
        g_optimizer.zero_grad()
        with amp.autocast():
            output = discriminator(sr)
            pixel_loss_tensor = 1.0 * pixel_loss(sr, hr.detach())
            content_loss_tensor = 1.0 * content_loss(sr, hr.detach())
            adversarial_loss_tensor = 0.001 * adversarial_loss(output, real_label)
        g_loss = pixel_loss_tensor + content_loss_tensor + adversarial_loss_tensor
        # This backward also deposits gradients on the discriminator's params
        # (if they require grad). Presumably d_optimizer.zero_grad() clears them
        # at the start of the next iteration.
        scaler.scale(g_loss).backward()
        scaler.step(g_optimizer)
        # Adjust the loss scale once per iteration, after both optimizers stepped.
        scaler.update()
        loss_chart.num_for_G += float(g_loss.item())
    return best_loss, best_psnr
def validate_model(generator,
                   pixel_loss,
                   content_loss,
                   loader,
                   batch_size,
                   best_loss,
                   best_psnr,
                   batch_count,
                   epoch,
                   device="cuda"):
    """Evaluate the generator on the test split and checkpoint improvements.

    PSNR and ``pixel + content`` loss are averaged over all ``batch_count``
    validation batches. The checkpoint decisions are then made once, on those
    averages, rather than on whichever single batch happens to score best.

    * Mean PSNR above ``best_psnr`` (and ``epoch > -1``): save the weights to
      ``./GAN_generator_PSNR_GIT.pth``.
    * Mean loss below ``best_loss``: save the weights to ``./GAN_lowest_loss_GIT``.

    Args:
        generator: Model to evaluate; it is switched to ``eval()`` mode and left there.
        pixel_loss, content_loss: Criteria comparing predictions to ``hr``.
        loader: Object providing ``get_testing_batch(batch_size) -> (lr, hr)``.
        batch_size: Number of samples requested per batch.
        best_loss, best_psnr: Best values seen so far.
        batch_count: Number of validation batches to evaluate.
        epoch: Current epoch. ``epoch > -1`` is always true for non-negative
            epochs; raise the threshold to skip PSNR checkpoints early on.
        device: Device tensors are moved to (default ``"cuda"``).

    Returns:
        Updated ``(best_loss, best_psnr)``. Both are returned unchanged when
        ``batch_count <= 0``.
    """
    generator.eval()
    psnr_sum = 0.0
    loss_sum = 0.0
    with torch.no_grad():
        for _ in range(batch_count):
            lr, hr = loader.get_testing_batch(batch_size)
            lr, hr = lr.to(device), hr.to(device)
            hr_pred = generator(lr)
            # PSNR with a peak value of 1.0. This assumes pixel values are
            # scaled to [0, 1] - TODO confirm against the loader.
            mse = ((hr_pred - hr) ** 2).mean()
            psnr_sum += float((10. * torch.log10(1. / mse)).item())
            g_loss = 1.0 * pixel_loss(hr_pred, hr) + 1.0 * content_loss(hr_pred, hr)
            loss_sum += float(g_loss.item())

    if batch_count <= 0:
        # Nothing was evaluated: keep the previous bests and save nothing.
        return best_loss, best_psnr

    mean_psnr = psnr_sum / batch_count
    mean_loss = loss_sum / batch_count

    if epoch > -1 and best_psnr < mean_psnr:
        torch.save(generator.state_dict(), './GAN_generator_PSNR_GIT.pth')
        best_psnr = mean_psnr
        print("Model saved at PSNR{}".format(best_psnr))
    if best_loss > mean_loss:
        torch.save(generator.state_dict(), './GAN_lowest_loss_GIT')
        best_loss = mean_loss
        print("Model saved at loss {}".format(best_loss))
    return best_loss, best_psnr
if __name__ == "__main__":
    # Launch training only when this file is run as a script, not on import.
    main()