"""SRGAN training script: adversarially fine-tune a PSNR-pretrained generator.

(Header cleaned up: the original lines here were paste-site page metadata —
"Untitled", avatar, size, language tag — not Python, and broke the file.)
"""
from process_dataset_2 import Process_dataset
from generator_model import Generator
from discriminator_model import Discriminator
from test_model import Super_ress_model
from vgg19_model import VGG19
import torch
import time
from datetime import datetime
# from sobel_operator_model import Sobel_operator
from torch.optim import lr_scheduler
import numpy
import PIL
from PIL import Image
from graphing_class_SR import CreateGraph
import numpy as np
import matplotlib.pyplot as plt
import torch
import config
from torch import nn
from torch import optim
from torch.cuda import amp
from tqdm import tqdm
from torchvision.models import vgg19
torch.backends.cudnn.benchmark = True
from torch import Tensor
import torchvision.models as models
import torch.nn.functional as F
import sys
import os



class ContentLoss(nn.Module):
    """Perceptual (VGG) loss: L1 distance between VGG19 feature maps of the
    super-resolved image and the ground-truth high-resolution image.

    The extractor is the first 36 layers of torchvision's VGG19 ``features``
    (up to the last conv activation before the final max-pool), frozen and
    kept in eval mode.
    """

    def __init__(self) -> None:
        super(ContentLoss, self).__init__()

        # Named `backbone` (not `vgg19`) so it does not shadow the
        # module-level `from torchvision.models import vgg19` import.
        backbone = models.vgg19(pretrained=True).eval()

        # Keep only the first 36 feature layers for perceptual comparison.
        self.feature_extractor = nn.Sequential(*list(backbone.features.children())[:36])

        # Freeze the extractor — it is used purely as a fixed loss network.
        for parameters in self.feature_extractor.parameters():
            parameters.requires_grad = False

        # ImageNet normalization constants, registered as buffers so they
        # follow the module across .to(device) calls.
        self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
        """Return L1 loss between VGG19 features of `sr` and `hr`.

        Inputs are assumed to be RGB in [0, 1] (TODO confirm against the
        generator's output range); both are ImageNet-normalized first.
        """
        sr = sr.sub(self.mean).div(self.std)
        hr = hr.sub(self.mean).div(self.std)

        loss = F.l1_loss(self.feature_extractor(sr), self.feature_extractor(hr))

        return loss


def main():
    """Run adversarial SRGAN training.

    Loads a PSNR-pretrained generator checkpoint, then alternates
    discriminator/generator updates for `epochs` epochs, validating and
    checkpointing the generator after every epoch.
    """
    batch_size      = 30
    best_loss_train = 99999.0
    best_psnr_train = 0.0
    best_loss_val   = 99999.0
    best_psnr_val   = 0.0
    epochs          = 9

    training_path   = ("C:/Users/Samuel/PycharmProjects/Super_ressolution/dataset4")
    testing_path    = ("C:/Users/Samuel/PycharmProjects/Super_ressolution/datik")
    # 64px low-res inputs upscaled 4x to 256px high-res targets.
    loader          = Process_dataset(in_ress=64, out_ress=64 * 4, training_path=training_path,testing_path=testing_path)

    # Ceil-division: batches needed to cover every sample once.
    # (The previous `(count + batch_size) // batch_size` over-counted by a
    # full extra batch whenever count was an exact multiple of batch_size.)
    batch_count_train   = (loader.get_training_count() + batch_size - 1) // batch_size
    batch_count_val     = (loader.get_testing_count() + batch_size - 1) // batch_size
    loss_chart_train    = CreateGraph(batch_count_train, "Generator and discriminator loss")

    # Warm-start the generator from the PSNR-oriented pretraining checkpoint.
    generator   = Generator().to("cuda")
    PATH        = './Ich_generator_PSNR.pth'
    generator.load_state_dict(torch.load(PATH))

    discriminator = Discriminator().to("cuda")

    g_optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.9, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.9, 0.999))

    # Decay both learning rates by 10x at the halfway point.
    d_scheduler = lr_scheduler.StepLR(d_optimizer, epochs // 2, 0.1)
    g_scheduler = lr_scheduler.StepLR(g_optimizer, epochs // 2, 0.1)

    psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion = define_loss()

    # Single GradScaler shared by both optimizers for mixed-precision training.
    scaler = amp.GradScaler()

    for epoch in range(epochs):
        print(best_loss_train, best_psnr_train,"TRAIN")
        best_loss_train, best_psnr_train = train_model(generator, discriminator, g_optimizer, d_optimizer,
                                                       pixel_criterion,content_criterion, adversarial_criterion,
                                                       loader, batch_size, best_loss_train,best_psnr_train, scaler,
                                                       loss_chart_train, batch_count_train, epoch)

        d_scheduler.step()
        g_scheduler.step()
        loss_chart_train.count(epoch)

        print(best_loss_val, best_psnr_val,"VAL")
        best_loss_val, best_psnr_val = validate_model(generator,pixel_criterion,content_criterion,loader,batch_size,
                                                      best_loss_val,best_psnr_val,batch_count_val,epoch)

        # Unconditional per-epoch checkpoint (in addition to the best-model
        # saves inside train/validate).
        torch.save(generator.state_dict(), './Generator_SR_epoch_GIT{}.pth'.format(epoch + 1))
        print('Model saved at {} epoch'.format(epoch + 1))


def define_loss(device: str = "cuda"):
    """Build the four loss criteria used during training.

    Args:
        device: device to place the criteria on; defaults to "cuda",
            matching the original hard-coded behavior.

    Returns:
        Tuple of (psnr_criterion, pixel_criterion, content_criterion,
        adversarial_criterion).

    Note: the previous return annotation `[nn.MSELoss, ...]` was a list
    literal, not a valid type annotation, so it was replaced by this
    docstring.
    """
    psnr_criterion = nn.MSELoss().to(device)
    pixel_criterion = nn.MSELoss().to(device)
    content_criterion = ContentLoss().to(device)
    adversarial_criterion = nn.BCEWithLogitsLoss().to(device)

    return psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion



def train_model(generator,
                discriminator,
                g_optimizer,
                d_optimizer,
                pixel_loss,
                content_loss,
                adversarial_loss,
                loader,
                batch_size,
                best_loss,
                best_psnr,
                scaler,
                loss_chart,
                batch_count,
                epoch):
    """Run one epoch of adversarial training.

    Alternates, per batch: a discriminator update (real HR vs. detached SR),
    then a generator update driven by pixel + content (VGG) + weighted
    adversarial loss, all under AMP autocast with a shared GradScaler.
    Saves the generator whenever batch PSNR exceeds `best_psnr` or the
    generator loss drops below `best_loss`.

    Returns:
        (best_loss, best_psnr) updated with any per-batch improvements.
    """
    discriminator.train()
    generator.train()
    for batch in range(batch_count):
        lr, hr = loader.get_training_batch(batch_size)

        hr = hr.to("cuda")
        lr = lr.to("cuda")

        # Target labels for BCEWithLogitsLoss: 1 = real HR, 0 = generated SR.
        real_label = torch.ones((lr.size(0), 1)).to("cuda")
        fake_label = torch.zeros((lr.size(0), 1)).to("cuda")

        sr = generator(lr)

        # PSNR assuming pixel values in [0, 1] (max signal = 1).
        PSNR = 10. * torch.log10(1. / (((sr - hr) ** 2).mean()))

        # NOTE(review): `epoch > -1` is always true; presumably a leftover
        # warm-up guard — confirm whether a real threshold was intended.
        if epoch > -1 and best_psnr < float(PSNR.data.cpu().numpy()):
            torch.save(generator.state_dict(), './GAN_generator_PSNR_GIT.pth')
            best_psnr = float(PSNR.data.cpu().numpy())
            print("Model saved at PSNR{}".format(best_psnr))

        # Unfreeze the discriminator for its update step.
        for p in discriminator.parameters():
            p.requires_grad = True

        # Initialize the discriminator optimizer gradient
        d_optimizer.zero_grad()

        # Calculate the loss of the discriminator on the high-resolution image
        with amp.autocast():
            hr_output = discriminator(hr)
            d_loss_hr = adversarial_loss(hr_output, real_label)
        # Gradient zoom

        scaler.scale(d_loss_hr).backward()

        # Calculate the loss of the discriminator on the super-resolution image.
        # `sr.detach()` keeps this backward pass out of the generator's graph.
        with amp.autocast():
            sr_output = discriminator(sr.detach())
            d_loss_sr = adversarial_loss(sr_output, fake_label)

        # Gradient zoom
        scaler.scale(d_loss_sr).backward()
        # Update discriminator parameters
        scaler.step(d_optimizer)
        scaler.update()

        # Count discriminator total loss
        d_loss = d_loss_hr + d_loss_sr

        loss_chart.num_for_D += float(d_loss.item())

        # Freeze the discriminator so the generator update does not
        # accumulate gradients into it.
        for p in discriminator.parameters():
            p.requires_grad = False


        g_optimizer.zero_grad()

        with amp.autocast():
            output = discriminator(sr)

            # Loss weights: 1.0 pixel, 1.0 content (VGG), 0.001 adversarial.
            pixel_loss_tensor = 1.0 * pixel_loss(sr, hr.detach())
            content_loss_tensor = 1.0 * content_loss(sr, hr.detach())
            adversarial_loss_tensor = 0.001 * adversarial_loss(output, real_label)

        # Count discriminator total loss
        g_loss = pixel_loss_tensor + content_loss_tensor + adversarial_loss_tensor
        # Gradient zoom
        scaler.scale(g_loss).backward()
        # Update generator parameters
        scaler.step(g_optimizer)
        scaler.update()

        loss_chart.num_for_G += float(g_loss.item())

        # NOTE(review): checkpoint path has no .pth extension — confirm intended.
        if epoch > -1 and best_loss > float(g_loss.data.cpu().numpy()):
            torch.save(generator.state_dict(), './GAN_lowest_loss_GIT')
            best_loss = float(g_loss.data.cpu().numpy())
            print("Model saved at loss {}".format(best_loss))


    return best_loss, best_psnr




def validate_model(generator,
                    pixel_loss,
                    content_loss,
                    loader,
                    batch_size,
                    best_loss,
                    best_psnr,
                    batch_count,
                    epoch):
    """Run one validation epoch: PSNR + (pixel + content) loss per batch.

    Saves the generator whenever a batch sets a new best PSNR or a new
    lowest validation loss.

    Returns:
        (best_loss, best_psnr) updated with any per-batch improvements.

    Fixes vs. the original:
      * runs under torch.no_grad() — validation was building autograd
        graphs for nothing;
      * `best_loss` was previously also overwritten (with a possibly WORSE
        value) whenever only PSNR improved, due to an `or` in the save
        condition; it is now updated only when the loss itself improves.
    """
    generator.eval()

    with torch.no_grad():
        for batch in range(batch_count):
            lr, hr  = loader.get_testing_batch(batch_size)
            lr, hr  = lr.to("cuda"), hr.to("cuda")
            hr_pred = generator(lr)

            # PSNR assuming pixel values in [0, 1] (max signal = 1).
            PSNR = 10. * torch.log10(1. / (((hr_pred - hr) ** 2).mean()))

            if epoch > -1 and best_psnr < float(PSNR.data.cpu().numpy()):
                torch.save(generator.state_dict(), './GAN_generator_PSNR_GIT.pth')
                best_psnr = float(PSNR.data.cpu().numpy())
                print("Model saved at PSNR{}".format(best_psnr))

            # Validation loss has no adversarial term: pixel + content only.
            pixel_loss_tensor   = 1.0 * pixel_loss(hr_pred, hr.detach())
            content_loss_tensor = 1.0 * content_loss(hr_pred, hr.detach())

            g_loss = pixel_loss_tensor + content_loss_tensor

            if best_loss > float(g_loss.item()):
                torch.save(generator.state_dict(), './GAN_lowest_loss_GIT')
                best_loss = float(g_loss.item())
                print("Model saved at loss {}".format(best_loss))


    return best_loss, best_psnr


# Script entry point: kick off the full training/validation loop.
if __name__ == "__main__":
    main()