### YOUR CODE IS HERE ######
if ((i + 1) % self.n_critic) == 0:
    # Generator step: maximize the critic score on fake samples.
    self.opt_gen.zero_grad()
    # Keep the autograd graph intact: do NOT wrap the fakes in torch.tensor(),
    # that would detach them from the generator and kill its gradients
    # (assumes generate() returns a torch tensor attached to the generator's graph).
    x_fake = generate(self.generator, cond_batch, self.latent_dim)
    loss_g = -torch.mean(self.discriminator(x_fake, cond_batch))
    loss_g.backward()
    self.opt_gen.step()
else:
    # Critic (discriminator) step: maximize E[D(real)] - E[D(fake)],
    # i.e. minimize the negated Wasserstein estimate.
    self.opt_disc.zero_grad()
    # Detach the fakes so this step does not backpropagate into the generator.
    x_fake = generate(self.generator, cond_batch, self.latent_dim).detach()
    loss_d = -(torch.mean(self.discriminator(real_batch, cond_batch))
               - torch.mean(self.discriminator(x_fake, cond_batch)))
    loss_d.backward()
    self.opt_disc.step()
    # Weight clipping enforces the Lipschitz constraint required by WGAN.
    for param in self.discriminator.parameters():
        param.data.clamp_(-0.01, 0.01)
### THE END OF YOUR CODE ###
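
# A minimal sketch (not part of the original snippet) of the generate() helper
# that the code above assumes: it should sample latent noise, pair it with the
# condition batch, and return the generator output as a torch tensor that is
# still attached to the generator's autograd graph. The generator call
# signature generator(z, cond_batch) and the tensor shapes are assumptions
# made for illustration only.
import torch

def generate(generator, cond_batch, latent_dim):
    # Sample one standard-normal latent vector per condition row (assumed layout).
    z = torch.randn(cond_batch.size(0), latent_dim, device=cond_batch.device)
    # The conditional generator is assumed to map (noise, condition) -> fake samples.
    return generator(z, cond_batch)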