Untitled
unknown
python
a year ago
4.6 kB
7
Indexable
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model

# NOTE(review): `plt` (matplotlib.pyplot) and `x_train` are used below but are
# not defined in this snippet; they are assumed to come from earlier notebook
# cells -- confirm before running this as a standalone module.


class GAN(Model):
    """Keras model that trains a generator/discriminator pair adversarially.

    Args:
        discriminator: Model scoring images; its output is compared by
            ``loss_fn`` against labels of shape ``(batch, 1)`` where
            1 = real and 0 = fake.
        generator: Model mapping latent vectors of size ``noise`` to images.
        noise: Dimensionality of the latent vectors fed to ``generator``.
    """

    def __init__(self, discriminator, generator, noise):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.noise = noise
        self.gen_loss_tracker = tf.keras.metrics.Mean(name='generator_loss')
        self.disc_loss_tracker = tf.keras.metrics.Mean(name='discriminator_loss')
        # NOTE(review): KLDivergence clips y_true / y_pred to [eps, 1]; with
        # images scaled to [-1, 1] the negative half is clipped, so treat this
        # value as a rough diagnostic only.
        self.kl = tf.keras.metrics.KLDivergence()

    @property
    def metrics(self):
        """Metrics that Keras resets automatically at the start of each epoch."""
        return [self.disc_loss_tracker, self.gen_loss_tracker, self.kl]

    def compile(self, d_optimizer, g_optimizer, loss_fn, **kwargs):
        """Configure the GAN for training.

        Args:
            d_optimizer: Optimizer applied to the discriminator's weights.
            g_optimizer: Optimizer applied to the generator's weights.
            loss_fn: Callable ``loss_fn(labels, predictions)`` used for both
                the discriminator and the generator objectives.
            **kwargs: Forwarded to ``Model.compile`` (e.g. ``run_eagerly``).
        """
        super().compile(**kwargs)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: A batch of real images, or a tuple whose first element
                is that batch; other tuple elements are ignored.

        Returns:
            Dict with the running means of ``d_loss`` and ``g_loss`` and the
            running KL-divergence metric between real and generated batches.
        """
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        batch_size = tf.shape(real_images)[0]

        # ---- Discriminator step ----
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.noise))
        # training=True: inside a custom train_step Keras does not enable
        # training mode for sub-models, so BatchNorm/Dropout would otherwise
        # silently run in inference mode.
        generated_images = self.generator(random_latent_vectors, training=True)
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Fake images are labelled 0, real images 1.
        labels = tf.concat(
            [tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], axis=0
        )
        # Add small uniform noise to the labels - a common GAN stabilisation trick.
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images, training=True)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # ---- Generator step ----
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.noise))
        # Label every fake as "real": the generator is rewarded for fooling
        # the discriminator.
        misleading_labels = tf.ones((batch_size, 1))

        # Only the generator's weights are updated here; the discriminator
        # just provides the loss signal.
        with tf.GradientTape() as tape:
            generated_images = self.generator(random_latent_vectors, training=True)
            predictions = self.discriminator(generated_images, training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_weights)
        )

        # ---- Monitor losses ----
        self.disc_loss_tracker.update_state(d_loss)
        self.gen_loss_tracker.update_state(g_loss)
        self.kl.update_state(y_true=real_images, y_pred=generated_images)
        return {
            "d_loss": self.disc_loss_tracker.result(),
            "g_loss": self.gen_loss_tracker.result(),
            "KL Divergence": self.kl.result(),
        }


class GANMonitor(tf.keras.callbacks.Callback):
    """Callback that periodically plots images generated from fixed noise.

    Args:
        num_img: Number of images to plot, laid out 5 per row.
        noise: Latent dimensionality; must match the generator's input size.
        patience: Plot at the start of every ``patience``-th epoch.
        vmin: Lower bound of the generator's output range.
        vmax: Upper bound of the generator's output range; images are
            rescaled from [vmin, vmax] to [0, 1] for display.
    """

    def __init__(self, num_img=10, noise=128, patience=10, vmin=0, vmax=1):
        super().__init__()
        self.num_img = num_img
        self.noise = noise
        self.patience = patience
        self.vmin = vmin
        self.vmax = vmax
        # The same latent vectors are reused for every plot so the images
        # show how identical samples evolve over training.
        self.constant_noise = tf.random.normal(shape=(self.num_img, self.noise))

    def generate_plot(self):
        """Generate images from the fixed noise and show them in a 5-column grid."""
        if self.num_img <= 0:
            return  # nothing to draw; a zero-height figure would be rejected

        images = np.asarray(
            self.model.generator(self.constant_noise, training=False)
        )
        # Normalise from [vmin, vmax] to [0, 1]; clip so imshow doesn't warn
        # about slightly out-of-range values.
        images = np.clip((images - self.vmin) / (self.vmax - self.vmin), 0.0, 1.0)
        # imshow rejects (H, W, 1); drop a trailing singleton channel.
        if images.ndim == 4 and images.shape[-1] == 1:
            images = images[..., 0]

        n_cols = 5
        n_rows = int(np.ceil(self.num_img / n_cols))
        fig = plt.figure(figsize=(10, 2 * n_rows), tight_layout=True)
        for i in range(self.num_img):
            ax = fig.add_subplot(n_rows, n_cols, i + 1)
            # cmap only affects 2-D (grayscale) images; RGB ignores it.
            ax.imshow(images[i], cmap='gray')
            ax.axis('off')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across epochs

    def on_epoch_begin(self, epoch, logs=None):
        """Plot at the start of every ``patience``-th epoch (including epoch 0)."""
        if epoch % self.patience == 0:
            self.generate_plot()

    def on_train_end(self, logs=None):
        """Plot a final grid and save the trained model's weights."""
        self.generate_plot()
        # A Callback has no save_weights(); the weights belong to self.model.
        # NOTE(review): Keras 3 requires a '.weights.h5' suffix here; the bare
        # name only works with the TF-checkpoint format of tf.keras 2.x.
        self.model.save_weights('Full Train')


BATCH_SIZE = 64
AUTO = tf.data.AUTOTUNE
EPOCHS = 100
BUFFER_SIZE = 1024
NOISE = 128

callbacks = [
    GANMonitor(num_img=15, noise=NOISE, patience=5, vmin=-1, vmax=1),
]

# Scale pixels to [-1, 1] to match the monitor's vmin/vmax above.
# Assumes x_train holds 0-255 pixel values -- confirm against the loader.
x_train = x_train.astype('float32')
x_train /= (255 / 2)
x_train -= 1

latent_dim = NOISE
Editor is loading...
Leave a Comment