Untitled
unknown
plain_text
2 years ago
1.3 kB
7
Indexable
### YOUR CODE IS HERE ######
# One WGAN iteration with weight clipping (Arjovsky et al., 2017).
# Schedule: self.n_critic critic (discriminator) updates followed by one
# generator update. Using `% (self.n_critic + 1)` also keeps the critic
# training when n_critic == 1 (`% 1` would select the generator every time).
if (i + 1) % (self.n_critic + 1) == 0:
    # Generator step.
    self.opt_gen.zero_grad()
    # The fake batch must stay attached to the generator's autograd graph;
    # wrapping samples in torch.tensor(...) copies them into a fresh leaf
    # tensor, so loss_g.backward() would give the generator no gradients.
    # NOTE(review): assumes the generator is called as generator(noise, cond)
    # with standard-normal noise of width self.latent_dim, mirroring what
    # generate(self.generator, cond_batch, self.latent_dim) presumably does —
    # confirm against generate().
    z = torch.randn(len(cond_batch), self.latent_dim, device=cond_batch.device)
    x_fake = self.generator(z, cond_batch)
    # Generator maximizes the critic score of fake samples. Same gradient as
    # mean(1 - D(fake)); the constant term is dropped.
    loss_g = -torch.mean(self.discriminator(x_fake, cond_batch))
    loss_g.backward()
    self.opt_gen.step()
else:
    # Critic step.
    self.opt_disc.zero_grad()
    # Fake samples cut off from the generator graph: the critic update must
    # not backpropagate into the generator.
    x_fake = torch.tensor(
        generate(self.generator, cond_batch, self.latent_dim)
    )
    # Critic maximizes D(real) - D(fake), i.e. minimizes its negation:
    # real scores are pushed up, fake scores are pushed down.
    loss_d = -(
        torch.mean(self.discriminator(real_batch, cond_batch))
        - torch.mean(self.discriminator(x_fake, cond_batch))
    )
    loss_d.backward()  # the graph is not reused, so no retain_graph
    self.opt_disc.step()
    # Weight clipping to [-0.01, 0.01] is the WGAN device for keeping the
    # critic (approximately) Lipschitz; done in place outside autograd.
    with torch.no_grad():
        for param in self.discriminator.parameters():
            param.clamp_(-0.01, 0.01)
### THE END OF YOUR CODE ###
Editor is loading...