Untitled
unknown
python
3 years ago
6.5 kB
6
Indexable
def num_batches(sample_count, batch_size):
    """Return how many batches of ``batch_size`` are needed to cover ``sample_count`` samples.

    This is ceiling division: 60 samples in batches of 30 -> 2 batches, 61 -> 3,
    0 -> 0. (``(count + batch_size) // batch_size`` would over-count by one
    whenever ``count`` is an exact multiple of ``batch_size``.)

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return (sample_count + batch_size - 1) // batch_size


def compute_psnr(prediction, target, max_value=1.0):
    """Return the peak signal-to-noise ratio (dB) between two tensors as a float.

    PSNR = 10 * log10(max_value**2 / MSE), where MSE is averaged over every
    element. The default ``max_value=1.0`` reproduces the formula previously
    used here; it presumably assumes images scaled to [0, 1] -- confirm
    against Process_dataset. Identical inputs give ``inf``.
    """
    mse = torch.mean((prediction.float() - target.float()) ** 2)
    return float(10.0 * torch.log10(max_value ** 2 / mse))


def main():
    """Fine-tune a pretrained generator adversarially, validating after every epoch.

    Loads './Ich_generator_PSNR.pth' into the generator, then trains it against
    a fresh discriminator for ``epochs`` epochs. Both Adam learning rates are
    multiplied by 0.1 every ``epochs // 2`` epochs. After each epoch it runs
    validation (which may save best-PSNR / lowest-loss checkpoints) and saves a
    per-epoch generator checkpoint. All models live on "cuda".
    """
    batch_size = 30
    best_loss_train = 99999.0
    best_psnr_train = 0.0
    best_loss_val = 99999.0
    best_psnr_val = 0.0
    epochs = 9
    training_path = "C:/Users/Samuel/PycharmProjects/Super_ressolution/dataset4"
    testing_path = "C:/Users/Samuel/PycharmProjects/Super_ressolution/datik"

    # 64px low-res inputs, 4x upscaled (256px) targets.
    loader = Process_dataset(in_ress=64, out_ress=64 * 4, training_path=training_path,
                             testing_path=testing_path, aug_count=4)
    batch_count_train = num_batches(loader.get_training_count(), batch_size)
    batch_count_val = num_batches(loader.get_testing_count(), batch_size)
    loss_chart_train = CreateGraph(batch_count_train, "Generator and discriminator loss")

    # Start from the PSNR-pretrained generator weights.
    generator = Generator().to("cuda")
    pretrained_path = './Ich_generator_PSNR.pth'
    generator.load_state_dict(torch.load(pretrained_path))
    discriminator = Discriminator().to("cuda")

    g_optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.9, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.9, 0.999))
    d_scheduler = lr_scheduler.StepLR(d_optimizer, step_size=epochs // 2, gamma=0.1)
    g_scheduler = lr_scheduler.StepLR(g_optimizer, step_size=epochs // 2, gamma=0.1)

    _psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion = define_loss()
    scaler = amp.GradScaler()

    for epoch in range(epochs):
        print(best_loss_train, best_psnr_train, "TRAIN")
        best_loss_train, best_psnr_train = train_model(
            generator, discriminator, g_optimizer, d_optimizer, pixel_criterion,
            content_criterion, adversarial_criterion, loader, batch_size, best_loss_train,
            best_psnr_train, scaler, loss_chart_train, batch_count_train, epoch)
        d_scheduler.step()
        g_scheduler.step()
        loss_chart_train.count(epoch)

        print(best_loss_val, best_psnr_val, "VAL")
        best_loss_val, best_psnr_val = validate_model(
            generator, pixel_criterion, content_criterion, loader, batch_size,
            best_loss_val, best_psnr_val, batch_count_val, epoch)

        torch.save(generator.state_dict(), './Generator_SR_epoch_GIT{}.pth'.format(epoch + 1))
        print('Model saved at {} epoch'.format(epoch + 1))


def define_loss() -> "tuple[nn.MSELoss, nn.MSELoss, ContentLoss, nn.BCEWithLogitsLoss]":
    """Build the loss modules on "cuda".

    Returns:
        (psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion):
        two ``nn.MSELoss`` instances, a ``ContentLoss`` and an ``nn.BCEWithLogitsLoss``
        (the discriminator is therefore expected to output raw logits).
    """
    psnr_criterion = nn.MSELoss().to("cuda")
    pixel_criterion = nn.MSELoss().to("cuda")
    content_criterion = ContentLoss().to("cuda")
    adversarial_criterion = nn.BCEWithLogitsLoss().to("cuda")
    return psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion


def train_model(generator, discriminator, g_optimizer, d_optimizer, pixel_loss, content_loss,
                adversarial_loss, loader, batch_size, best_loss, best_psnr, scaler, loss_chart,
                batch_count, epoch):
    """Train the generator and discriminator for one epoch.

    For each of ``batch_count`` batches:
      1. update the discriminator on real ``hr`` images (label 1) vs. detached
         generated ``sr`` images (label 0);
      2. update the generator with pixel + content + 0.001 * adversarial loss.

    Forward passes run under ``amp.autocast()`` and every backward/step goes
    through ``scaler`` (a GradScaler) so reduced-precision gradients do not
    underflow. Per-batch losses are added to ``loss_chart.num_for_D`` and
    ``loss_chart.num_for_G``.

    Returns:
        ``(best_loss, best_psnr)`` unchanged -- training does not track them;
        they (and ``epoch``) are kept for the caller's signature.
    """
    discriminator.train()
    generator.train()
    for _ in range(batch_count):
        lr, hr = loader.get_training_batch(batch_size)
        hr = hr.to("cuda")
        lr = lr.to("cuda")
        real_label = torch.ones((lr.size(0), 1), device="cuda")
        fake_label = torch.zeros((lr.size(0), 1), device="cuda")

        # ---- Discriminator step -------------------------------------------
        d_optimizer.zero_grad()
        with amp.autocast():
            sr = generator(lr)
            d_loss_hr = adversarial_loss(discriminator(hr), real_label)
            # detach: the discriminator update must not push gradients into G.
            d_loss_sr = adversarial_loss(discriminator(sr.detach()), fake_label)
            d_loss = d_loss_hr + d_loss_sr
        scaler.scale(d_loss).backward()
        scaler.step(d_optimizer)
        loss_chart.num_for_D += float(d_loss.item())

        # ---- Generator step ------------------------------------------------
        g_optimizer.zero_grad()
        with amp.autocast():
            output = discriminator(sr)
            pixel_loss_tensor = 1.0 * pixel_loss(sr, hr)
            content_loss_tensor = 1.0 * content_loss(sr, hr)
            adversarial_loss_tensor = 0.001 * adversarial_loss(output, real_label)
            g_loss = pixel_loss_tensor + content_loss_tensor + adversarial_loss_tensor
        # The backward also accumulates grads on the discriminator's params;
        # those are cleared by d_optimizer.zero_grad() before the next D step.
        scaler.scale(g_loss).backward()
        scaler.step(g_optimizer)
        # One update per iteration, after every optimizer using this scaler stepped.
        scaler.update()
        loss_chart.num_for_G += float(g_loss.item())

    return best_loss, best_psnr


def validate_model(generator, pixel_loss, content_loss, loader, batch_size, best_loss, best_psnr,
                   batch_count, epoch):
    """Evaluate the generator on the test split and checkpoint it if it improved.

    Computes the sample-weighted mean PSNR and mean (pixel + content) loss over
    all ``batch_count`` test batches. It then:
      * saves to './GAN_generator_PSNR_GIT.pth' if mean PSNR beats ``best_psnr``;
      * saves to './GAN_lowest_loss_GIT' if mean loss is below ``best_loss``.

    ``epoch`` is unused and kept for the caller's signature. Leaves the
    generator in eval mode.

    Returns:
        The updated ``(best_loss, best_psnr)``. If there were no test samples,
        the inputs are returned unchanged.
    """
    generator.eval()
    total_loss = 0.0
    total_psnr = 0.0
    total_samples = 0
    with torch.no_grad():
        for _ in range(batch_count):
            lr, hr = loader.get_testing_batch(batch_size)
            lr, hr = lr.to("cuda"), hr.to("cuda")
            hr_pred = generator(lr)
            samples = hr.size(0)
            g_loss = 1.0 * pixel_loss(hr_pred, hr) + 1.0 * content_loss(hr_pred, hr)
            # Weight by batch size so a short final batch is not over-represented.
            total_loss += float(g_loss.item()) * samples
            total_psnr += compute_psnr(hr_pred, hr) * samples
            total_samples += samples

    if total_samples == 0:
        return best_loss, best_psnr

    mean_loss = total_loss / total_samples
    mean_psnr = total_psnr / total_samples

    if mean_psnr > best_psnr:
        torch.save(generator.state_dict(), './GAN_generator_PSNR_GIT.pth')
        best_psnr = mean_psnr
        print("Model saved at PSNR{}".format(best_psnr))

    if mean_loss < best_loss:
        torch.save(generator.state_dict(), './GAN_lowest_loss_GIT')
        best_loss = mean_loss
        print("Model saved at loss {}".format(best_loss))

    return best_loss, best_psnr


if __name__ == "__main__":
    main()
Editor is loading...