Untitled
unknown
python
3 years ago
9.1 kB
2
Indexable
Never
from process_dataset_2 import Process_dataset
from generator_model import Generator
from discriminator_model import Discriminator
from test_model import Super_ress_model
from vgg19_model import VGG19
import torch
import time
from datetime import datetime
# from sobel_operator_model import Sobel_operator
from torch.optim import lr_scheduler
import numpy
import PIL
from PIL import Image
from graphing_class_SR import CreateGraph
import numpy as np
import matplotlib.pyplot as plt
import config
from torch import nn
from torch import optim
from torch.cuda import amp
from tqdm import tqdm
from torchvision.models import vgg19
from torch import Tensor
import torchvision.models as models
import torch.nn.functional as F
import sys
import os
from typing import Tuple

torch.backends.cudnn.benchmark = True

# All models, losses and batches are placed on this device.
DEVICE = "cuda"

# Checkpoints written by train_model (best single training-batch PSNR / generator loss).
TRAIN_BEST_PSNR_PATH = './GAN_generator_PSNR_GIT.pth'
TRAIN_BEST_LOSS_PATH = './GAN_lowest_loss_GIT'
# Checkpoints written by validate_model. These previously reused the training paths,
# so whichever loop saved last overwrote the other's "best" model.
VAL_BEST_PSNR_PATH = './GAN_generator_PSNR_GIT_val.pth'
VAL_BEST_LOSS_PATH = './GAN_lowest_loss_GIT_val'


class ContentLoss(nn.Module):
    """Perceptual (content) loss: L1 distance between VGG19 feature maps.

    Both images are normalized with the ImageNet channel mean/std registered
    below, passed through the first 36 modules of torchvision's pretrained
    ``vgg19().features`` (frozen), and compared with ``F.l1_loss``.
    """

    def __init__(self) -> None:
        super(ContentLoss, self).__init__()
        vgg19 = models.vgg19(pretrained=True).eval()
        # Keep the first 36 layers of the convolutional trunk as a fixed feature extractor.
        self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:36])
        # VGG is a fixed reference network; it is never trained.
        for parameters in self.feature_extractor.parameters():
            parameters.requires_grad = False
        # ImageNet normalization, shaped (1, 3, 1, 1) to broadcast over NCHW batches.
        # NOTE(review): assumes inputs are RGB scaled to [0, 1]; confirm against the loader.
        self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
        """Return the mean L1 distance between the VGG features of ``sr`` and ``hr``."""
        sr = sr.sub(self.mean).div(self.std)
        hr = hr.sub(self.mean).div(self.std)
        return F.l1_loss(self.feature_extractor(sr), self.feature_extractor(hr))


def _num_batches(sample_count: int, batch_size: int) -> int:
    """Return ceil(sample_count / batch_size): batches needed to cover a split once."""
    return (sample_count + batch_size - 1) // batch_size


def _psnr(pred: Tensor, target: Tensor) -> float:
    """Return 10 * log10(1 / MSE(pred, target)) as a Python float.

    The peak value 1.0 presumes images scaled to [0, 1] (NOTE(review): confirm).
    Computed under ``no_grad`` so it never extends the autograd graph of ``pred``.
    """
    with torch.no_grad():
        mse = ((pred - target) ** 2).mean()
        return float(10. * torch.log10(1. / mse))


def main():
    """Adversarially fine-tune the super-resolution generator.

    Loads initial generator weights from './Ich_generator_PSNR.pth'. For each of
    ``epochs`` epochs it runs one training pass (``train_model``), steps both StepLR
    schedulers, runs one validation pass (``validate_model``) and saves a per-epoch
    generator checkpoint.
    """
    batch_size = 30
    best_loss_train = 99999.0
    best_psnr_train = 0.0
    best_loss_val = 99999.0
    best_psnr_val = 0.0
    epochs = 9
    training_path = ("C:/Users/Samuel/PycharmProjects/Super_ressolution/dataset4")
    testing_path = ("C:/Users/Samuel/PycharmProjects/Super_ressolution/datik")
    # 64px inputs, 64 * 4 px targets (presumably a x4 upscaling task).
    loader = Process_dataset(in_ress=64, out_ress=64 * 4,
                             training_path=training_path, testing_path=testing_path)
    # Ceiling division. The previous `(count + batch_size) // batch_size` scheduled one
    # extra batch whenever count was an exact multiple of batch_size.
    batch_count_train = _num_batches(loader.get_training_count(), batch_size)
    batch_count_val = _num_batches(loader.get_testing_count(), batch_size)
    loss_chart_train = CreateGraph(batch_count_train, "Generator and discriminator loss")

    generator = Generator().to(DEVICE)
    # Initial generator weights (the file name suggests a PSNR-oriented pre-training run).
    PATH = './Ich_generator_PSNR.pth'
    generator.load_state_dict(torch.load(PATH))
    discriminator = Discriminator().to(DEVICE)

    g_optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.9, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.9, 0.999))
    # Multiply both learning rates by 0.1 every `epochs // 2` epochs.
    d_scheduler = lr_scheduler.StepLR(d_optimizer, epochs // 2, 0.1)
    g_scheduler = lr_scheduler.StepLR(g_optimizer, epochs // 2, 0.1)

    psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion = define_loss()
    scaler = amp.GradScaler()

    for epoch in range(epochs):
        print(best_loss_train, best_psnr_train, "TRAIN")
        best_loss_train, best_psnr_train = train_model(
            generator, discriminator, g_optimizer, d_optimizer,
            pixel_criterion, content_criterion, adversarial_criterion,
            loader, batch_size, best_loss_train, best_psnr_train,
            scaler, loss_chart_train, batch_count_train, epoch)
        d_scheduler.step()
        g_scheduler.step()
        loss_chart_train.count(epoch)

        print(best_loss_val, best_psnr_val, "VAL")
        best_loss_val, best_psnr_val = validate_model(
            generator, pixel_criterion, content_criterion, loader, batch_size,
            best_loss_val, best_psnr_val, batch_count_val, epoch)

        torch.save(generator.state_dict(), './Generator_SR_epoch_GIT{}.pth'.format(epoch + 1))
        print('Model saved at {} epoch'.format(epoch + 1))


def define_loss() -> Tuple[nn.MSELoss, nn.MSELoss, ContentLoss, nn.BCEWithLogitsLoss]:
    """Build the loss modules on ``DEVICE``.

    Returns:
        ``(psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion)``:
        two ``nn.MSELoss`` instances, a VGG ``ContentLoss``, and an
        ``nn.BCEWithLogitsLoss``, which expects raw (pre-sigmoid) discriminator outputs.
    """
    psnr_criterion = nn.MSELoss().to(DEVICE)
    pixel_criterion = nn.MSELoss().to(DEVICE)
    content_criterion = ContentLoss().to(DEVICE)
    adversarial_criterion = nn.BCEWithLogitsLoss().to(DEVICE)
    return psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion


def train_model(generator, discriminator, g_optimizer, d_optimizer, pixel_loss, content_loss,
                adversarial_loss, loader, batch_size, best_loss, best_psnr, scaler, loss_chart,
                batch_count, epoch):
    """Run one adversarial training epoch of ``batch_count`` batches.

    Per batch: update the discriminator on real (hr) vs. generated (sr) images, then
    update the generator on pixel MSE + VGG content + 0.001 x adversarial loss, with
    mixed precision via ``scaler``. Discriminator and generator losses are accumulated
    into ``loss_chart.num_for_D`` / ``loss_chart.num_for_G``.

    The generator is checkpointed whenever a batch beats ``best_psnr`` (to
    TRAIN_BEST_PSNR_PATH) or ``best_loss`` (to TRAIN_BEST_LOSS_PATH).

    ``epoch`` is accepted for interface compatibility and is not used.

    Returns:
        ``(best_loss, best_psnr)`` including any improvements seen this epoch.
    """
    discriminator.train()
    generator.train()
    for _ in range(batch_count):
        lr, hr = loader.get_training_batch(batch_size)
        hr = hr.to(DEVICE)
        lr = lr.to(DEVICE)
        real_label = torch.ones((lr.size(0), 1)).to(DEVICE)
        fake_label = torch.zeros((lr.size(0), 1)).to(DEVICE)

        sr = generator(lr)

        # Nothing has been stepped yet, so the current weights are the ones that produced sr.
        psnr = _psnr(sr, hr)
        if best_psnr < psnr:
            torch.save(generator.state_dict(), TRAIN_BEST_PSNR_PATH)
            best_psnr = psnr
            print("Model saved at PSNR{}".format(best_psnr))

        # ---- Discriminator update ----
        for p in discriminator.parameters():
            p.requires_grad = True
        d_optimizer.zero_grad()
        with amp.autocast():
            hr_output = discriminator(hr)
            d_loss_hr = adversarial_loss(hr_output, real_label)
        scaler.scale(d_loss_hr).backward()
        with amp.autocast():
            # detach(): the discriminator update must not backpropagate into the generator.
            sr_output = discriminator(sr.detach())
            d_loss_sr = adversarial_loss(sr_output, fake_label)
        scaler.scale(d_loss_sr).backward()
        scaler.step(d_optimizer)
        # NOTE(review): the PyTorch AMP docs call scaler.update() once per iteration, after
        # all optimizers have stepped. Here it runs after each optimizer; kept unchanged.
        scaler.update()
        d_loss = d_loss_hr + d_loss_sr
        loss_chart.num_for_D += float(d_loss.item())

        # ---- Generator update ----
        # Freeze D so the generator backward pass accumulates no gradients in D's parameters.
        for p in discriminator.parameters():
            p.requires_grad = False
        g_optimizer.zero_grad()
        with amp.autocast():
            output = discriminator(sr)
            pixel_loss_tensor = 1.0 * pixel_loss(sr, hr.detach())
            content_loss_tensor = 1.0 * content_loss(sr, hr.detach())
            adversarial_loss_tensor = 0.001 * adversarial_loss(output, real_label)
        # Generator total loss.
        g_loss = pixel_loss_tensor + content_loss_tensor + adversarial_loss_tensor
        g_loss_value = float(g_loss.item())

        # Save before stepping so the checkpoint holds the weights that produced g_loss
        # (previously the post-step weights were saved under this loss value).
        if best_loss > g_loss_value:
            torch.save(generator.state_dict(), TRAIN_BEST_LOSS_PATH)
            best_loss = g_loss_value
            print("Model saved at loss {}".format(best_loss))

        scaler.scale(g_loss).backward()
        scaler.step(g_optimizer)
        scaler.update()
        loss_chart.num_for_G += g_loss_value
    return best_loss, best_psnr


def validate_model(generator, pixel_loss, content_loss, loader, batch_size, best_loss, best_psnr,
                   batch_count, epoch):
    """Evaluate the generator on ``batch_count`` test batches; checkpoint on improvement.

    Weights do not change during validation, so PSNR and loss (pixel MSE + VGG
    content) are averaged over all batches and compared once against the running
    bests. Checking per batch only re-saved identical weights and selected on the
    easiest batch. Runs under ``torch.no_grad()`` to avoid building autograd graphs.

    ``epoch`` is accepted for interface compatibility and is not used.

    Returns:
        ``(best_loss, best_psnr)``, updated when the epoch mean improves on them.
    """
    generator.eval()
    if batch_count <= 0:
        return best_loss, best_psnr

    total_psnr = 0.0
    total_loss = 0.0
    with torch.no_grad():
        for _ in range(batch_count):
            lr, hr = loader.get_testing_batch(batch_size)
            lr, hr = lr.to(DEVICE), hr.to(DEVICE)
            hr_pred = generator(lr)
            total_psnr += _psnr(hr_pred, hr)
            g_loss = 1.0 * pixel_loss(hr_pred, hr) + 1.0 * content_loss(hr_pred, hr)
            total_loss += float(g_loss.item())
    mean_psnr = total_psnr / batch_count
    mean_loss = total_loss / batch_count

    if best_psnr < mean_psnr:
        torch.save(generator.state_dict(), VAL_BEST_PSNR_PATH)
        best_psnr = mean_psnr
        print("Model saved at PSNR{}".format(best_psnr))
    if best_loss > mean_loss:
        torch.save(generator.state_dict(), VAL_BEST_LOSS_PATH)
        best_loss = mean_loss
        print("Model saved at loss {}".format(best_loss))
    return best_loss, best_psnr


if __name__ == "__main__":
    main()