Untitled
unknown
python
2 years ago
4.6 kB
15
Indexable
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
class GAN(Model):
    """Keras model that alternately trains a discriminator and a generator.

    Each ``train_step`` performs one discriminator update on a batch made of
    generated images (labelled 0) and real images (labelled 1), followed by
    one generator update that pushes the discriminator to label freshly
    generated images as real (1).

    Args:
        discriminator: Model scoring an image batch; its output is compared
            with ``(batch, 1)`` label tensors by ``loss_fn``.
        generator: Model mapping latent vectors of length ``noise`` to images
            (presumably the same shape as the training images — they are
            concatenated with them along axis 0).
        noise: Length of the latent vectors sampled for the generator.
    """

    def __init__(self, discriminator, generator, noise):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.noise = noise
        self.gen_loss_tracker = tf.keras.metrics.Mean(name='generator_loss')
        self.disc_loss_tracker = tf.keras.metrics.Mean(name='discriminator_loss')
        self.kl = tf.keras.metrics.KLDivergence()

    @property
    def metrics(self):
        """Trackers that Keras resets at the start of each epoch / evaluate().

        Listing them here lets ``fit`` call ``reset_state()`` on them
        automatically; otherwise the logged values can become running means
        over the whole training run instead of per-epoch values.
        """
        return [self.disc_loss_tracker, self.gen_loss_tracker, self.kl]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store the optimizers and the loss shared by both networks.

        Args:
            d_optimizer: Optimizer applied to the discriminator's weights.
            g_optimizer: Optimizer applied to the generator's weights.
            loss_fn: Callable ``loss_fn(labels, predictions)`` used for both
                the discriminator and the generator objective.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: A batch of real images, or a tuple whose first
                element is that batch (any further elements are ignored).

        Returns:
            Dict with the current per-epoch ``d_loss``, ``g_loss`` and
            ``KL Divergence`` tracker values.
        """
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        batch_size = tf.shape(real_images)[0]

        # ---- Discriminator update -----------------------------------------
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.noise))
        # training=True: without an explicit flag (and with no enclosing call
        # context) Keras runs dropout / batch-norm layers in inference mode.
        generated_images = self.generator(random_latent_vectors, training=True)
        combined_images = tf.concat([generated_images, real_images], axis=0)
        # Labels follow the concat order above: fakes -> 0, reals -> 1.
        labels = tf.concat(
            [tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], axis=0
        )
        # Add small uniform label noise in [0, 0.05) (a common GAN
        # stabilisation trick).
        labels += 0.05 * tf.random.uniform(tf.shape(labels))
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images, training=True)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # ---- Generator update ---------------------------------------------
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.noise))
        # The generator is scored against "all real" labels, so minimising
        # the loss means fooling the discriminator.
        misleading_labels = tf.ones((batch_size, 1))
        with tf.GradientTape() as tape:
            generated_images = self.generator(random_latent_vectors, training=True)
            predictions = self.discriminator(generated_images, training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)
        # Gradients are taken w.r.t. the generator only, so the
        # discriminator's weights are left unchanged by this step.
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_weights)
        )

        # ---- Monitoring ---------------------------------------------------
        self.disc_loss_tracker.update_state(d_loss)
        self.gen_loss_tracker.update_state(g_loss)
        # NOTE(review): KLDivergence clips both inputs to [epsilon, 1], so if
        # images are scaled to [-1, 1] every negative pixel is flattened.
        # Treat this as a rough drift signal rather than a true divergence.
        self.kl.update_state(y_true=real_images, y_pred=generated_images)
        return {
            "d_loss": self.disc_loss_tracker.result(),
            "g_loss": self.gen_loss_tracker.result(),
            "KL Divergence": self.kl.result(),
        }
class GANMonitor(tf.keras.callbacks.Callback):
    """Callback that plots a fixed grid of generator samples during training.

    The same latent vectors are reused for every plot, so successive grids
    show how the generator's output for identical inputs evolves.

    Args:
        num_img: Number of images to plot (laid out 5 per row).
        noise: Latent-vector length; must match the generator's input size.
        patience: Plot at the start of every ``patience``-th epoch
            (epochs 0, patience, 2*patience, ...).
        vmin, vmax: Value range of the generator output; images are rescaled
            from [vmin, vmax] to [0, 1] before plotting.
    """

    def __init__(self, num_img=10, noise=128, patience=10, vmin=0, vmax=1):
        # Initialise the base Callback state before adding our own attributes.
        super().__init__()
        self.num_img = num_img
        self.noise = noise
        self.patience = patience
        self.vmin = vmin
        self.vmax = vmax
        # Fixed latent inputs so every plot is comparable across epochs.
        self.constant_noise = tf.random.normal(
            shape=(self.num_img, self.noise))

    def generate_plot(self):
        """Render ``num_img`` samples from the fixed latent vectors in a grid."""
        generated_images = self.model.generator(self.constant_noise)
        # Rescale from [vmin, vmax] to [0, 1], the float range imshow expects.
        generated_images = (generated_images - self.vmin) / (self.vmax - self.vmin)
        row_size = int(np.ceil(self.num_img / 5))
        fig = plt.figure(figsize=(10, 2 * row_size), tight_layout=True)
        for i in range(self.num_img):
            ax = fig.add_subplot(row_size, 5, i + 1)
            ax.imshow(generated_images[i])
            ax.axis('off')
        plt.show()

    def on_epoch_begin(self, epoch, logs=None):
        """Plot samples at the start of every ``patience``-th epoch."""
        if epoch % self.patience == 0:
            self.generate_plot()

    def on_train_end(self, logs=None):
        """Plot a final grid and save the trained model's weights.

        Keras invokes this hook with ``logs`` only, so there is no ``epoch``
        parameter.
        """
        self.generate_plot()
        # A Callback has no save_weights of its own; save the model this
        # callback is attached to.
        self.model.save_weights('Full Train')
# ---- Training hyper-parameters ---------------------------------------------
BATCH_SIZE = 64
AUTO = tf.data.AUTOTUNE
EPOCHS = 100
BUFFER_SIZE = 1024
# Latent-vector length; the single source of truth for the generator input
# size and the monitor's fixed noise, so the two cannot drift apart.
NOISE = 128

callbacks = [
    # vmin/vmax match the [-1, 1] pixel scaling applied to x_train below.
    GANMonitor(num_img=15, noise=NOISE, patience=5, vmin=-1, vmax=1),
]

# Scale pixels from [0, 255] to [-1, 1] (x / 127.5 - 1), done in place to
# avoid extra full-size copies.
# NOTE(review): assumes x_train (defined elsewhere) holds values in [0, 255].
x_train = x_train.astype('float32')
x_train /= (255 / 2)
x_train -= 1

# Alias kept for code that refers to ``latent_dim``; always equal to NOISE.
latent_dim = NOISE
Editor is loading...
Leave a Comment