# Untitled
# Paste-site metadata (not part of the program):
#   unknown | python | 4 years ago | 9.1 kB | 11 | Indexable
from process_dataset_2 import Process_dataset
from generator_model import Generator
from discriminator_model import Discriminator
from test_model import Super_ress_model
from vgg19_model import VGG19
import torch
import time
from datetime import datetime
# from sobel_operator_model import Sobel_operator
from torch.optim import lr_scheduler
import numpy
import PIL
from PIL import Image
from graphing_class_SR import CreateGraph
import numpy as np
import matplotlib.pyplot as plt
import torch
import config
from torch import nn
from torch import optim
from torch.cuda import amp
from tqdm import tqdm
from torchvision.models import vgg19
torch.backends.cudnn.benchmark = True
from torch import Tensor
import torchvision.models as models
import torch.nn.functional as F
import sys
import os
class ContentLoss(nn.Module):
    """Feature-space L1 loss computed through a frozen, pretrained VGG19.

    Both inputs are normalised per channel with the ``mean``/``std`` buffers
    (the statistics torchvision documents for its ImageNet-pretrained models),
    pushed through the first 36 layers of VGG19's ``features`` stack, and the
    resulting feature maps are compared with ``F.l1_loss``.

    The buffers are shaped ``(1, 3, 1, 1)``, so inputs are expected to be
    3-channel batches laid out as ``(N, 3, H, W)``.
    """

    def __init__(self) -> None:
        super().__init__()
        backbone = models.vgg19(pretrained=True).eval()
        kept_layers = list(backbone.features.children())[:36]
        self.feature_extractor = nn.Sequential(*kept_layers)
        # The extractor is a fixed feature function; its weights are frozen.
        for weight in self.feature_extractor.parameters():
            weight.requires_grad = False

        # Normalization constants, stored as buffers so they follow .to(device).
        self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def _features(self, image: Tensor) -> Tensor:
        """Normalise ``image`` and return its VGG19 feature map."""
        return self.feature_extractor((image - self.mean) / self.std)

    def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
        """Return the mean absolute difference between the features of ``sr`` and ``hr``."""
        sr_features = self._features(sr)
        hr_features = self._features(hr)
        return F.l1_loss(sr_features, hr_features)
def main():
    """Fine-tune a pretrained generator adversarially and checkpoint it.

    Builds the dataset loader, loads generator weights from
    ``./Ich_generator_PSNR.pth``, creates a fresh discriminator, then for each
    epoch runs one training pass, steps both learning-rate schedulers, updates
    the loss chart, runs one validation pass and saves the generator weights
    under an epoch-numbered file name.
    """
    batch_size = 30
    epochs = 9
    # Running "best so far" trackers, kept separately for training and validation.
    best_loss_train, best_psnr_train = 99999.0, 0.0
    best_loss_val, best_psnr_val = 99999.0, 0.0

    training_path = "C:/Users/Samuel/PycharmProjects/Super_ressolution/dataset4"
    testing_path = "C:/Users/Samuel/PycharmProjects/Super_ressolution/datik"
    loader = Process_dataset(
        in_ress=64,
        out_ress=64 * 4,
        training_path=training_path,
        testing_path=testing_path,
    )

    # One more batch than the number of full batches that fit in the dataset.
    # NOTE(review): when the sample count is an exact multiple of batch_size
    # this requests one batch beyond the data - confirm how Process_dataset
    # responds to that extra request.
    batch_count_train = loader.get_training_count() // batch_size + 1
    batch_count_val = loader.get_testing_count() // batch_size + 1
    loss_chart_train = CreateGraph(batch_count_train, "Generator and discriminator loss")

    generator = Generator().to("cuda")
    pretrained_weights = './Ich_generator_PSNR.pth'
    generator.load_state_dict(torch.load(pretrained_weights))
    discriminator = Discriminator().to("cuda")

    adam_betas = (0.9, 0.999)
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=adam_betas)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=adam_betas)
    # Both learning rates are multiplied by 0.1 every epochs // 2 epochs.
    lr_step = epochs // 2
    d_scheduler = lr_scheduler.StepLR(d_optimizer, lr_step, 0.1)
    g_scheduler = lr_scheduler.StepLR(g_optimizer, lr_step, 0.1)

    # The first criterion returned by define_loss() is not used here.
    _psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion = define_loss()
    scaler = amp.GradScaler()

    for epoch in range(epochs):
        print(best_loss_train, best_psnr_train, "TRAIN")
        best_loss_train, best_psnr_train = train_model(
            generator, discriminator, g_optimizer, d_optimizer,
            pixel_criterion, content_criterion, adversarial_criterion,
            loader, batch_size, best_loss_train, best_psnr_train, scaler,
            loss_chart_train, batch_count_train, epoch,
        )
        d_scheduler.step()
        g_scheduler.step()
        loss_chart_train.count(epoch)

        print(best_loss_val, best_psnr_val, "VAL")
        best_loss_val, best_psnr_val = validate_model(
            generator, pixel_criterion, content_criterion, loader, batch_size,
            best_loss_val, best_psnr_val, batch_count_val, epoch,
        )

        # Per-epoch snapshot, independent of the "best" checkpoints above.
        finished_epoch = epoch + 1
        torch.save(generator.state_dict(), './Generator_SR_epoch_GIT{}.pth'.format(finished_epoch))
        print('Model saved at {} epoch'.format(finished_epoch))
def define_loss(device="cuda") -> "tuple[nn.MSELoss, nn.MSELoss, ContentLoss, nn.BCEWithLogitsLoss]":
    """Create the loss modules used by the training and validation loops.

    Args:
        device: Device the loss modules are moved to. Defaults to ``"cuda"``,
            the device this function previously used unconditionally.

    Returns:
        A 4-tuple ``(psnr_criterion, pixel_criterion, content_criterion,
        adversarial_criterion)``: two independent ``nn.MSELoss`` instances, a
        ``ContentLoss`` and an ``nn.BCEWithLogitsLoss``, all on ``device``.
    """
    # The annotation is quoted so it is never evaluated at definition time;
    # the previous list literal was not a valid type annotation.
    psnr_criterion = nn.MSELoss().to(device)
    pixel_criterion = nn.MSELoss().to(device)
    content_criterion = ContentLoss().to(device)
    adversarial_criterion = nn.BCEWithLogitsLoss().to(device)
    return psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion
def train_model(generator,
                discriminator,
                g_optimizer,
                d_optimizer,
                pixel_loss,
                content_loss,
                adversarial_loss,
                loader,
                batch_size,
                best_loss,
                best_psnr,
                scaler,
                loss_chart,
                batch_count,
                epoch,
                device="cuda"):
    """Run one adversarial training pass over ``batch_count`` batches.

    For every batch the generator output is scored with PSNR, the
    discriminator is updated on real and generated images, and then the
    generator is updated on the weighted sum of pixel, content and
    adversarial losses. Generator weights are written to disk whenever the
    batch PSNR or the batch generator loss beats the best value seen so far.

    Args:
        generator: Module mapping a low-resolution batch to a super-resolved one.
        discriminator: Module producing one logit per image, shaped ``(N, 1)``
            to match the labels built here.
        g_optimizer: Optimizer for ``generator``.
        d_optimizer: Optimizer for ``discriminator``.
        pixel_loss: Callable ``(sr, hr) -> scalar tensor``.
        content_loss: Callable ``(sr, hr) -> scalar tensor``.
        adversarial_loss: Callable ``(logits, labels) -> scalar tensor``.
        loader: Object whose ``get_training_batch(batch_size)`` returns an
            ``(lr, hr)`` pair of tensors.
        batch_size: Number of samples requested per batch.
        best_loss: Lowest generator loss seen so far.
        best_psnr: Highest PSNR seen so far.
        scaler: Gradient scaler shared by both optimizers.
        loss_chart: Object with ``num_for_D`` / ``num_for_G`` accumulators that
            receive the per-batch discriminator / generator loss.
        batch_count: Number of batches to request from ``loader``.
        epoch: Current epoch index; checkpoints are written only when it is
            greater than -1.
        device: Device the batch and labels are placed on. Defaults to
            ``"cuda"``, the previously hard-coded value.

    Returns:
        The updated ``(best_loss, best_psnr)`` pair.
    """
    discriminator.train()
    generator.train()
    # True for every non-negative epoch index; a negative epoch disables saving.
    checkpointing = epoch > -1

    for _ in range(batch_count):
        lr, hr = loader.get_training_batch(batch_size)
        if lr.size(0) == 0:
            # An empty batch makes every mean-reduced loss NaN, which would
            # poison the loss_chart accumulators; there is nothing to learn.
            continue
        hr = hr.to(device)
        lr = lr.to(device)
        # Allocate the targets directly on the device instead of copying them.
        real_label = torch.ones((lr.size(0), 1), device=device)
        fake_label = torch.zeros((lr.size(0), 1), device=device)

        sr = generator(lr)

        # PSNR for a peak value of 1.0. It is only a metric, so it is computed
        # without autograd tracking and converted to a float once.
        with torch.no_grad():
            psnr = (10. * torch.log10(1. / ((sr - hr) ** 2).mean())).item()
        if checkpointing and best_psnr < psnr:
            torch.save(generator.state_dict(), './GAN_generator_PSNR_GIT.pth')
            best_psnr = psnr
            print("Model saved at PSNR{}".format(best_psnr))

        # ---- Discriminator update ----
        for p in discriminator.parameters():
            p.requires_grad = True
        d_optimizer.zero_grad()
        # Loss on the real high-resolution images.
        with amp.autocast():
            hr_output = discriminator(hr)
            d_loss_hr = adversarial_loss(hr_output, real_label)
        scaler.scale(d_loss_hr).backward()
        # Loss on the generated images; detached so the generator gets no gradient.
        with amp.autocast():
            sr_output = discriminator(sr.detach())
            d_loss_sr = adversarial_loss(sr_output, fake_label)
        scaler.scale(d_loss_sr).backward()
        scaler.step(d_optimizer)
        scaler.update()
        # Total discriminator loss, used for reporting only.
        d_loss = d_loss_hr + d_loss_sr
        loss_chart.num_for_D += float(d_loss.item())

        # ---- Generator update ----
        # Freeze the discriminator so only the generator accumulates gradients.
        for p in discriminator.parameters():
            p.requires_grad = False
        g_optimizer.zero_grad()
        with amp.autocast():
            output = discriminator(sr)
            pixel_loss_tensor = 1.0 * pixel_loss(sr, hr.detach())
            content_loss_tensor = 1.0 * content_loss(sr, hr.detach())
            adversarial_loss_tensor = 0.001 * adversarial_loss(output, real_label)
        # Total generator loss.
        g_loss = pixel_loss_tensor + content_loss_tensor + adversarial_loss_tensor
        scaler.scale(g_loss).backward()
        scaler.step(g_optimizer)
        scaler.update()

        g_loss_value = float(g_loss.item())
        loss_chart.num_for_G += g_loss_value
        # Saved after the optimizer step, so the file holds the updated weights.
        if checkpointing and best_loss > g_loss_value:
            torch.save(generator.state_dict(), './GAN_lowest_loss_GIT')
            best_loss = g_loss_value
            print("Model saved at loss {}".format(best_loss))

    return best_loss, best_psnr
def validate_model(generator,
                   pixel_loss,
                   content_loss,
                   loader,
                   batch_size,
                   best_loss,
                   best_psnr,
                   batch_count,
                   epoch,
                   device="cuda"):
    """Evaluate the generator on ``batch_count`` test batches.

    Each batch is scored with PSNR and with the sum of pixel and content
    losses. Generator weights are written to disk whenever a batch beats the
    best PSNR or the best loss seen so far.

    Args:
        generator: Module mapping a low-resolution batch to a super-resolved one.
        pixel_loss: Callable ``(prediction, target) -> scalar tensor``.
        content_loss: Callable ``(prediction, target) -> scalar tensor``.
        loader: Object whose ``get_testing_batch(batch_size)`` returns an
            ``(lr, hr)`` pair of tensors.
        batch_size: Number of samples requested per batch.
        best_loss: Lowest validation loss seen so far.
        best_psnr: Highest validation PSNR seen so far.
        batch_count: Number of batches to request from ``loader``.
        epoch: Current epoch index; the PSNR checkpoint is written only when
            it is greater than -1.
        device: Device the batch is placed on. Defaults to ``"cuda"``, the
            previously hard-coded value.

    Returns:
        The updated ``(best_loss, best_psnr)`` pair.
    """
    generator.eval()
    # True for every non-negative epoch index; a negative epoch disables the
    # PSNR checkpoint (the loss checkpoint below was never gated on it).
    psnr_checkpointing = epoch > -1

    # Nothing is optimised here, so autograd bookkeeping is switched off.
    with torch.no_grad():
        for _ in range(batch_count):
            lr, hr = loader.get_testing_batch(batch_size)
            if lr.size(0) == 0:
                # Mean-reduced metrics of an empty batch are NaN; skip it.
                continue
            lr, hr = lr.to(device), hr.to(device)
            hr_pred = generator(lr)

            # PSNR for a peak value of 1.0.
            psnr = (10. * torch.log10(1. / ((hr_pred - hr) ** 2).mean())).item()
            if psnr_checkpointing and best_psnr < psnr:
                torch.save(generator.state_dict(), './GAN_generator_PSNR_GIT.pth')
                best_psnr = psnr
                print("Model saved at PSNR{}".format(best_psnr))

            pixel_loss_tensor = 1.0 * pixel_loss(hr_pred, hr)
            content_loss_tensor = 1.0 * content_loss(hr_pred, hr)
            g_loss = float((pixel_loss_tensor + content_loss_tensor).item())
            # Only a strictly lower loss counts. The former extra
            # "or best_psnr < PSNR" test could never be true once best_psnr had
            # been raised above, and would have replaced best_loss with a
            # worse value if it ever were.
            if best_loss > g_loss:
                torch.save(generator.state_dict(), './GAN_lowest_loss_GIT')
                best_loss = g_loss
                print("Model saved at loss {}".format(best_loss))

    return best_loss, best_psnr
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()