Untitled

mail@pastecode.io avatar
unknown
python
2 years ago
6.5 kB
2
Indexable
Never
def main():
    """Adversarially fine-tune an SRGAN generator against a discriminator.

    Loads a PSNR-pretrained generator checkpoint, then for each epoch runs
    one adversarial training pass (train_model) followed by validation
    (validate_model), stepping both LR schedulers and saving the generator
    weights after every epoch.

    Side effects: reads dataset folders and a checkpoint from hard-coded
    Windows paths, writes checkpoint files to the working directory, and
    requires a CUDA device.
    """
    batch_size      = 30
    best_loss_train = 99999.0   # sentinel "worst" loss; lowered as training improves
    best_psnr_train = 0.0
    best_loss_val   = 99999.0
    best_psnr_val   = 0.0
    epochs          = 9

    training_path = "C:/Users/Samuel/PycharmProjects/Super_ressolution/dataset4"
    testing_path  = "C:/Users/Samuel/PycharmProjects/Super_ressolution/datik"
    # 4x super-resolution: 64px inputs upscaled to 256px outputs.
    loader = Process_dataset(in_ress=64, out_ress=64 * 4,
                             training_path=training_path,
                             testing_path=testing_path, aug_count=4)

    # Ceiling division: (n + batch_size - 1) // batch_size.  The previous
    # (n + batch_size) // batch_size yielded one extra batch whenever the
    # sample count was an exact multiple of batch_size.
    batch_count_train = (loader.get_training_count() + batch_size - 1) // batch_size
    batch_count_val   = (loader.get_testing_count() + batch_size - 1) // batch_size
    loss_chart_train  = CreateGraph(batch_count_train, "Generator and discriminator loss")

    # Warm-start the generator from the PSNR-oriented pretraining checkpoint.
    generator = Generator().to("cuda")
    PATH = './Ich_generator_PSNR.pth'
    generator.load_state_dict(torch.load(PATH))

    discriminator = Discriminator().to("cuda")

    g_optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.9, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.9, 0.999))

    # Decay both learning rates by 10x at the halfway point of training.
    d_scheduler = lr_scheduler.StepLR(d_optimizer, epochs // 2, 0.1)
    g_scheduler = lr_scheduler.StepLR(g_optimizer, epochs // 2, 0.1)

    psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion = define_loss()

    # Gradient scaler for mixed-precision (fp16) training.
    scaler = amp.GradScaler()

    for epoch in range(epochs):
        print(best_loss_train, best_psnr_train, "TRAIN")
        best_loss_train, best_psnr_train = train_model(generator, discriminator, g_optimizer, d_optimizer,
                                                       pixel_criterion, content_criterion, adversarial_criterion,
                                                       loader, batch_size, best_loss_train, best_psnr_train, scaler,
                                                       loss_chart_train, batch_count_train, epoch)

        d_scheduler.step()
        g_scheduler.step()
        loss_chart_train.count(epoch)

        print(best_loss_val, best_psnr_val, "VAL")
        best_loss_val, best_psnr_val = validate_model(generator, pixel_criterion, content_criterion, loader, batch_size,
                                                      best_loss_val, best_psnr_val, batch_count_val, epoch)

        # Unconditional per-epoch checkpoint (in addition to the "best" saves
        # made inside validate_model).
        torch.save(generator.state_dict(), './Generator_SR_epoch_GIT{}.pth'.format(epoch + 1))
        print('Model saved at {} epoch'.format(epoch + 1))


def define_loss() -> "tuple[nn.MSELoss, nn.MSELoss, ContentLoss, nn.BCEWithLogitsLoss]":
    """Build the four training criteria on the CUDA device.

    Returns:
        A 4-tuple of (psnr_criterion, pixel_criterion, content_criterion,
        adversarial_criterion).  PSNR and pixel losses are both MSE; the
        content loss is the project's perceptual ContentLoss; the
        adversarial loss is BCE-with-logits for the discriminator outputs.

    Note: the original annotation used a list literal ``[A, B, C, D]``,
    which is not a valid type for a returned tuple.
    """
    psnr_criterion          = nn.MSELoss().to("cuda")
    pixel_criterion         = nn.MSELoss().to("cuda")
    content_criterion       = ContentLoss().to("cuda")
    adversarial_criterion   = nn.BCEWithLogitsLoss().to("cuda")

    return psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion



def train_model(generator,
                discriminator,
                g_optimizer,
                d_optimizer,
                pixel_loss,
                content_loss,
                adversarial_loss,
                loader,
                batch_size,
                best_loss,
                best_psnr,
                scaler,
                loss_chart,
                batch_count,
                epoch):
    """Run one epoch of adversarial SRGAN training.

    For each batch: update the discriminator on real (HR) and generated (SR)
    images, then update the generator with a weighted sum of pixel, content,
    and adversarial losses.  Per-batch losses are accumulated into
    ``loss_chart`` for plotting.

    Fix: the ``scaler`` (amp.GradScaler) was previously accepted but never
    used even though the forward passes run under ``amp.autocast()`` — fp16
    gradients could underflow.  Backward/step now go through the scaler.

    Returns:
        (best_loss, best_psnr) unchanged — bookkeeping of bests happens in
        validation; training progress is tracked via ``loss_chart``.
        ``epoch`` is currently unused here.
    """
    discriminator.train()
    generator.train()
    for batch in range(batch_count):
        lr, hr = loader.get_training_batch(batch_size)

        hr = hr.to("cuda")
        lr = lr.to("cuda")

        # Target labels for the adversarial loss: 1 = real, 0 = generated.
        real_label = torch.ones((lr.size(0), 1)).to("cuda")
        fake_label = torch.zeros((lr.size(0), 1)).to("cuda")

        sr = generator(lr)

        # ---- Discriminator update ----
        d_optimizer.zero_grad()

        # Loss of the discriminator on real high-resolution images.
        with amp.autocast():
            hr_output = discriminator(hr)
            d_loss_hr = adversarial_loss(hr_output, real_label)
        # Scale before backward so fp16 gradients do not underflow.
        scaler.scale(d_loss_hr).backward()

        # Loss of the discriminator on the super-resolution image.
        # ``detach`` blocks gradients from flowing into the generator here.
        with amp.autocast():
            sr_output = discriminator(sr.detach())
            d_loss_sr = adversarial_loss(sr_output, fake_label)
        scaler.scale(d_loss_sr).backward()

        # Update discriminator parameters (skipped by the scaler on inf/nan).
        scaler.step(d_optimizer)

        # Accumulate discriminator total loss for the chart.
        d_loss = d_loss_hr + d_loss_sr
        loss_chart.num_for_D += float(d_loss.item())

        # ---- Generator update ----
        g_optimizer.zero_grad()

        with amp.autocast():
            output = discriminator(sr)

            # SRGAN loss weights: pixel 1.0, content 1.0, adversarial 0.001.
            pixel_loss_tensor = 1.0 * pixel_loss(sr, hr.detach())
            content_loss_tensor = 1.0 * content_loss(sr, hr.detach())
            adversarial_loss_tensor = 0.001 * adversarial_loss(output, real_label)
            g_loss = pixel_loss_tensor + content_loss_tensor + adversarial_loss_tensor

        scaler.scale(g_loss).backward()
        # Update generator parameters, then refresh the scale factor once
        # per iteration (after all optimizer steps that used this scaler).
        scaler.step(g_optimizer)
        scaler.update()

        loss_chart.num_for_G += float(g_loss.item())

    return best_loss, best_psnr


def validate_model(generator,
                    pixel_loss,
                    content_loss,
                    loader,
                    batch_size,
                    best_loss,
                    best_psnr,
                    batch_count,
                    epoch):
    """Evaluate the generator on the test set and checkpoint improvements.

    For each validation batch, computes PSNR and a pixel+content loss; saves
    the generator whenever a new best PSNR or a new lowest loss is seen.

    Cleanups vs. the original: ``epoch > -1`` was always true (epochs start
    at 0) and is removed; the ``or best_psnr < PSNR`` clause on the loss
    check was dead code (best_psnr had already been raised to PSNR above);
    ``.item()`` replaces the dated ``.data.cpu().numpy()``.

    Returns:
        (best_loss, best_psnr) updated with any improvements found.
        ``epoch`` is currently unused here.
    """
    with torch.no_grad():
        generator.eval()
        for batch in range(batch_count):
            lr, hr  = loader.get_testing_batch(batch_size)
            lr, hr  = lr.to("cuda"), hr.to("cuda")
            hr_pred = generator(lr)

            # PSNR = 10*log10(MAX^2 / MSE) with MAX assumed to be 1.0,
            # i.e. pixel values in [0, 1] — TODO confirm against the loader.
            psnr = float((10. * torch.log10(1. / (((hr_pred - hr) ** 2).mean()))).item())

            if best_psnr < psnr:
                torch.save(generator.state_dict(), './GAN_generator_PSNR_GIT.pth')
                best_psnr = psnr
                print("Model saved at PSNR{}".format(best_psnr))

            # Same weighted sum as the generator's training objective,
            # minus the adversarial term.
            pixel_loss_tensor   = 1.0 * pixel_loss(hr_pred, hr)
            content_loss_tensor = 1.0 * content_loss(hr_pred, hr)

            g_loss = pixel_loss_tensor + content_loss_tensor

            if best_loss > float(g_loss.item()):
                # NOTE(review): this path has no .pth extension — intentional?
                torch.save(generator.state_dict(), './GAN_lowest_loss_GIT')
                best_loss = float(g_loss.item())
                print("Model saved at loss {}".format(best_loss))

    return best_loss, best_psnr


# Entry point: run the full SRGAN training loop when executed as a script.
if __name__ == "__main__":
    main()