Untitled

 avatar
unknown
plain_text
2 years ago
1.3 kB
7
Indexable
### YOUR CODE IS HERE ######
                # WGAN schedule: every n_critic-th batch updates the generator,
                # all other batches update the critic (discriminator).
                #
                # NOTE(review): fake samples are produced by calling the generator
                # directly instead of torch.tensor(generate(...)). torch.tensor()
                # copies its input into a new leaf tensor with no autograd
                # history, so the generator loss had no path back to the
                # generator's weights and opt_gen.step() never changed them.
                # The (noise, cond) argument order is assumed from the
                # generate(self.generator, cond_batch, self.latent_dim) call —
                # confirm against the generator's forward().
                noise = torch.randn(
                    len(cond_batch), self.latent_dim, device=cond_batch.device
                )
                if (i + 1) % self.n_critic == 0:
                    # Generator step: raise the critic's scores on fake samples.
                    self.opt_gen.zero_grad()
                    x_fake = self.generator(noise, cond_batch)
                    # Same gradients as the original mean(1 - D(fake)); the
                    # constant offset does not affect optimization.
                    loss_g = -torch.mean(self.discriminator(x_fake, cond_batch))
                    loss_g.backward()
                    self.opt_gen.step()
                else:
                    # Critic step: maximize mean D(real) - mean D(fake).
                    self.opt_disc.zero_grad()
                    with torch.no_grad():
                        # Fakes are constants here: no gradient into the generator.
                        x_fake = self.generator(noise, cond_batch)
                    # Equal to -(mean D(real) - mean D(fake)). The original used
                    # mean(1 - D(fake)), which flipped the sign of the fake term
                    # and taught the critic to score fakes *higher*.
                    loss_d = torch.mean(
                        self.discriminator(x_fake, cond_batch)
                    ) - torch.mean(self.discriminator(real_batch, cond_batch))
                    # The graph is used exactly once, so retain_graph is not needed.
                    loss_d.backward()
                    self.opt_disc.step()

                    # Weight clipping (original WGAN); 0.01 matches the clipping
                    # constant c used in the WGAN paper.
                    with torch.no_grad():
                        for param in self.discriminator.parameters():
                            param.clamp_(-0.01, 0.01)
                ### THE END OF YOUR CODE ###
Editor is loading...