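# Paste metadata scraped from pastecode.io (kept as comments so the file parses):
# Untitled
#
# mail@pastecode.io avatar
# unknown
# python
# 2 years ago
# 4.3 kB
# 3
# Indexable
# Never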
def main():
    """Fine-tune a pretrained generator adversarially and checkpoint it every epoch.

    Loads generator weights from ``./Ich_generator_PSNR.pth`` and pairs the
    generator with a newly constructed discriminator. Both are trained for
    ``epochs`` epochs with Adam (lr 1e-4). A StepLR schedule multiplies each
    learning rate by 0.1 every ``epochs // 2`` epochs. The generator's
    weights are saved after every epoch.
    """
    batch_size = 10
    epochs = 9
    best_loss_train = 99999.0
    best_psnr_train = 0.0
    # Reserved for validation, which is currently disabled (see TODO in the loop).
    best_loss_val = 99999.0
    best_psnr_val = 0.0

    training_path = "C:/Users/Samuel/PycharmProjects/Super_ressolution/dataset4"
    testing_path = "C:/Users/Samuel/PycharmProjects/Super_ressolution/datik"
    loader = Process_dataset(in_ress=64, out_ress=64 * 4, training_path=training_path)

    # Ceiling division: (n + b - 1) // b batches cover every sample.
    # (n + b) // b would over-count by one whenever n is a multiple of b.
    training_count = loader.get_training_count()
    batch_count = (training_count + batch_size - 1) // batch_size
    loss_chart_train = CreateGraph(batch_count, "Generator and discriminator loss")

    generator = Generator().to("cuda")
    PATH = './Ich_generator_PSNR.pth'

    # psnr_criterion is only needed by the (disabled) validation step.
    psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion = define_loss()

    generator.load_state_dict(torch.load(PATH))
    discriminator = Discriminator().to("cuda")

    g_optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.9, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.9, 0.999))

    d_scheduler = lr_scheduler.StepLR(d_optimizer, epochs // 2, 0.1)
    g_scheduler = lr_scheduler.StepLR(g_optimizer, epochs // 2, 0.1)

    for epoch in range(epochs):
        # Keyword arguments bind each value to train_model's actual parameters.
        # train_model takes no GradScaler or epoch argument, so passing them
        # positionally raised TypeError and shifted the remaining arguments.
        train_model(
            generator=generator,
            discriminator=discriminator,
            g_optimizer=g_optimizer,
            d_optimizer=d_optimizer,
            pixel_loss=pixel_criterion,
            content_loss=content_criterion,
            adversarial_loss=adversarial_criterion,
            loader=loader,
            batch_size=batch_size,
            best_loss=best_loss_train,
            best_psnr=best_psnr_train,
            loss_chart=loss_chart_train,
            batch_count=batch_count,
        )

        d_scheduler.step()
        g_scheduler.step()
        loss_chart_train.count(epoch)

        # TODO: validation (validate_model) is disabled; re-enable it against
        # testing_path, using best_loss_val / best_psnr_val and psnr_criterion.

        torch.save(generator.state_dict(), f'./Generator_SR_epoch_GIT{epoch + 1}.pth')
        print(f'Model saved at {epoch + 1} epoch')

def define_loss(
    device: "str | torch.device" = "cuda",
) -> "tuple[nn.MSELoss, nn.MSELoss, ContentLoss, nn.BCEWithLogitsLoss]":
    """Build the loss criteria used for training.

    Args:
        device: Device every criterion is moved to. Defaults to ``"cuda"``,
            which matches the previously hard-coded behaviour.

    Returns:
        A 4-tuple ``(psnr_criterion, pixel_criterion, content_criterion,
        adversarial_criterion)``: two independent ``nn.MSELoss`` instances,
        a ``ContentLoss`` and an ``nn.BCEWithLogitsLoss``.
    """
    # The return annotation is a string: a list literal is not a valid type
    # hint, and a string avoids evaluating project names at definition time.
    psnr_criterion = nn.MSELoss().to(device)
    pixel_criterion = nn.MSELoss().to(device)
    content_criterion = ContentLoss().to(device)
    adversarial_criterion = nn.BCEWithLogitsLoss().to(device)

    return psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion

# Module-level loss criteria, created as soon as this module is imported.
# NOTE(review): main() builds its own set of criteria, and define_loss() places
# them on "cuda", so importing this module needs a CUDA device. Consider
# removing this line if nothing else imports these globals.
(
    psnr_criterion,
    pixel_criterion,
    content_criterion,
    adversarial_criterion,
) = define_loss()


def train_model(generator,
                discriminator,
                g_optimizer,
                d_optimizer,
                pixel_loss,
                content_loss,
                adversarial_loss,
                loader,
                batch_size,
                best_loss,
                best_psnr,
                loss_chart,
                batch_count,
                *,
                device="cuda",
                ):
    """Run one epoch of adversarial training for the generator and discriminator.

    Each batch is processed in two steps:

    1. The discriminator is updated to score real high-resolution images
       against label 1 and generated images against label 0.
    2. The generator is updated on ``pixel + content + 0.001 * adversarial``
       loss.

    Args:
        generator: Module mapping a low-resolution batch to a super-resolved one.
        discriminator: Module whose output is compared against ``(batch, 1)``
            label tensors by ``adversarial_loss``.
        g_optimizer: Optimizer that steps the generator.
        d_optimizer: Optimizer that steps the discriminator.
        pixel_loss: Criterion called as ``pixel_loss(sr, hr)``.
        content_loss: Criterion called as ``content_loss(sr, hr)``.
        adversarial_loss: Criterion called as ``adversarial_loss(output, labels)``.
        loader: Object whose ``get_training_batch(batch_size)`` returns an
            ``(lr, hr)`` pair of tensors.
        batch_size: Number of samples requested per batch.
        best_loss: Currently unused.
        best_psnr: Currently unused.
        loss_chart: Currently unused.
        batch_count: Number of batches to process in this epoch.
        device: Keyword-only device that batches and labels are moved to
            (default ``"cuda"``).
    """
    for _ in range(batch_count):
        lr, hr = loader.get_training_batch(batch_size)
        lr = lr.to(device)
        hr = hr.to(device)

        n = lr.size(0)
        real_label = torch.ones((n, 1), device=device)
        fake_label = torch.zeros((n, 1), device=device)

        sr = generator(lr)

        # --- Discriminator update: real -> 1, generated -> 0 ---
        d_optimizer.zero_grad()
        d_loss_hr = adversarial_loss(discriminator(hr), real_label)
        d_loss_hr.backward()
        # detach() keeps this backward pass out of the generator's graph.
        d_loss_sr = adversarial_loss(discriminator(sr.detach()), fake_label)
        d_loss_sr.backward()
        d_optimizer.step()

        # --- Generator update ---
        g_optimizer.zero_grad()
        sr_output = discriminator(sr)
        # Store loss values under separate names so the criterion arguments
        # stay callable for the next batch. Rebinding e.g. `pixel_loss` to its
        # result would make it a tensor.
        pixel_term = 1.0 * pixel_loss(sr, hr.detach())
        content_term = 1.0 * content_loss(sr, hr.detach())
        adversarial_term = 0.001 * adversarial_loss(sr_output, real_label)

        # Total generator loss.
        g_loss = pixel_term + content_term + adversarial_term
        # This backward also accumulates gradients on the discriminator's
        # parameters. They are discarded by d_optimizer.zero_grad() at the
        # start of the next batch.
        g_loss.backward()
        g_optimizer.step()


if __name__ == "__main__":
    # Start training only when this file is executed as a script, not when imported.
    main()