Untitled
unknown
python
3 years ago
2.3 kB
6
Indexable
def define_loss(
    device="cuda",
) -> "tuple[nn.MSELoss, nn.MSELoss, ContentLoss, nn.BCEWithLogitsLoss]":
    """Instantiate the loss criteria and move each one to ``device``.

    Args:
        device: Target device passed to ``.to()`` for every criterion.
            Defaults to ``"cuda"``, which is what the original code
            hard-coded.

    Returns:
        A 4-tuple ``(psnr_criterion, pixel_criterion, content_criterion,
        adversarial_criterion)``: two ``nn.MSELoss`` instances, one
        ``ContentLoss`` (defined elsewhere in the project) and one
        ``nn.BCEWithLogitsLoss``.
    """
    # The return annotation is a string on purpose: it is never evaluated at
    # definition time, so it works on any Python 3 version and does not
    # require ``ContentLoss`` to be resolvable when the function is defined.
    psnr_criterion = nn.MSELoss().to(device)
    pixel_criterion = nn.MSELoss().to(device)
    content_criterion = ContentLoss().to(device)
    adversarial_criterion = nn.BCEWithLogitsLoss().to(device)
    return psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion


psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion = define_loss()


def train_model(
    generator,
    discriminator,
    g_optimizer,
    d_optimizer,
    pixel_loss,
    content_loss,
    adversarial_loss,
    loader,
    batch_size,
    best_loss,
    best_psnr,
    loss_chart,
    batch_count,
    device="cuda",
):
    """Run ``batch_count`` alternating discriminator/generator updates.

    For every batch the discriminator is stepped first (on the real
    high-resolution images and on the detached generator output), then the
    generator is stepped on a weighted sum of pixel, content and adversarial
    terms.

    Args:
        generator: Module mapping a low-resolution batch to a
            super-resolved batch.
        discriminator: Module scoring an image batch.
            NOTE(review): labels are built with shape ``(N, 1)``, so this
            presumably returns one logit per sample - confirm.
        g_optimizer: Optimizer stepping the generator.
        d_optimizer: Optimizer stepping the discriminator.
        pixel_loss: Callable criterion ``(sr, hr) -> scalar tensor``.
        content_loss: Callable criterion ``(sr, hr) -> scalar tensor``.
        adversarial_loss: Callable criterion ``(output, label) -> scalar
            tensor``; used for both the discriminator and generator steps.
        loader: Object exposing ``get_training_batch(batch_size)`` that
            returns a ``(lr, hr)`` pair of tensors.
        batch_size: Forwarded to ``loader.get_training_batch``.
        best_loss: Not used by this function body; kept for interface
            compatibility.
        best_psnr: Not used by this function body; kept for interface
            compatibility.
        loss_chart: Not used by this function body; kept for interface
            compatibility.
        batch_count: Number of batches (iterations) to train on.
        device: Device the batches and labels are moved to. Defaults to
            ``"cuda"`` (the value the original code hard-coded).

    Returns:
        None. The models are updated in place through the optimizers.
    """
    # Weights of the individual terms in the generator objective.
    pixel_weight = 1.0
    content_weight = 1.0
    adversarial_weight = 0.001

    for _ in range(batch_count):
        lr, hr = loader.get_training_batch(batch_size)
        hr = hr.to(device)
        lr = lr.to(device)

        # Build the labels directly on the target device instead of
        # allocating on the CPU and copying afterwards.
        sample_count = lr.size(0)
        real_label = torch.ones((sample_count, 1), device=device)
        fake_label = torch.zeros((sample_count, 1), device=device)

        sr = generator(lr)

        # ---- Discriminator update ----
        d_optimizer.zero_grad()

        # Loss of the discriminator on the real high-resolution images.
        hr_output = discriminator(hr)
        d_loss_hr = adversarial_loss(hr_output, real_label)
        d_loss_hr.backward()

        # Loss of the discriminator on the super-resolved images. The
        # generator output is detached so this backward pass does not
        # propagate into the generator.
        sr_output = discriminator(sr.detach())
        d_loss_sr = adversarial_loss(sr_output, fake_label)
        d_loss_sr.backward()

        d_optimizer.step()

        # ---- Generator update ----
        g_optimizer.zero_grad()

        output = discriminator(sr)

        # The weighted terms get their own names. Rebinding ``pixel_loss``,
        # ``content_loss`` or ``adversarial_loss`` here would replace the
        # criterion callables with tensors and make the next iteration fail
        # with "'Tensor' object is not callable".
        pixel_term = pixel_weight * pixel_loss(sr, hr.detach())
        content_term = content_weight * content_loss(sr, hr.detach())
        adversarial_term = adversarial_weight * adversarial_loss(output, real_label)

        # Total generator loss.
        g_loss = pixel_term + content_term + adversarial_term
        g_loss.backward()

        g_optimizer.step()
Editor is loading...