# Upsampling block: transposed conv (2x), BatchNorm, optional Dropout, ReLU.
def upsample(self, filters, size, apply_dropout=False):
    initializer = tf.random_normal_initializer(0., 0.02)
    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                        padding='same',
                                        kernel_initializer=initializer,
                                        use_bias=False))
    result.add(tf.keras.layers.BatchNormalization())
    if apply_dropout:
        result.add(tf.keras.layers.Dropout(0.5))
    result.add(tf.keras.layers.ReLU())
    return result
# Downsampling block: strided conv (1/2), optional BatchNorm, LeakyReLU.
def downsample(self, filters, size, apply_batchnorm=True):
    initializer = tf.random_normal_initializer(0., 0.02)
    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               kernel_initializer=initializer, use_bias=False))
    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())
    result.add(tf.keras.layers.LeakyReLU())
    return result
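# build_generator wires the blocks above into a U-Net: five stride-2
# downsampling blocks encode the noise input, four upsampling blocks plus a
# final transposed conv decode it back to an image, and Concatenate layers
# merge each decoder stage with the matching encoder feature map (the skip
# connections of the pix2pix generator).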
def build_generator(self):
    inputs = Input(shape=self.noise_shape, name='generator_noise')
    # Encoder-decoder with skip connections, following the pix2pix generator
    down_stack = [
        self.downsample(64, 3, apply_batchnorm=False),
        self.downsample(128, 3),
        self.downsample(256, 3),
        self.downsample(512, 3),
        self.downsample(512, 3),
    ]
    # Each decoder block doubles the spatial size; the exact output shapes
    # depend on self.noise_shape.
    up_stack = [
        self.upsample(512, 4, apply_dropout=True),
        self.upsample(512, 4, apply_dropout=True),
        self.upsample(256, 4, apply_dropout=True),
        self.upsample(128, 4, apply_dropout=True),
    ]
    initializer = tf.random_normal_initializer(0., 0.02)
    # Final 2x upsample to 3 output channels, squashed into [-1, 1] by tanh
    last = tf.keras.layers.Conv2DTranspose(3, 4,
                                           strides=2,
                                           padding='same',
                                           kernel_initializer=initializer,
                                           activation='tanh')
    # Downsampling through the model, collecting encoder outputs for the skips
    x = inputs
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)
    skips = reversed(skips[:-1])  # drop the bottleneck; pair deepest skips first
    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])
    x = last(x)
    generator = Model(inputs=inputs, outputs=x, name='generator')
    generator.summary()
    return generator
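# Usage sketch (assumption, not taken from this file): the enclosing class and
# its noise_shape attribute are defined elsewhere, and Input / Model are
# assumed to be imported from tensorflow.keras (e.g.
# `from tensorflow.keras import Input, Model`). With five stride-2
# downsampling blocks, the spatial dims of noise_shape should be divisible by
# 2**5 = 32, e.g. (32, 32, 1). The class name below is hypothetical:
#
#   gan = GAN()
#   generator = gan.build_generator()
#   noise = tf.random.normal([8, *gan.noise_shape])   # batch of 8 noise maps
#   fake_images = generator(noise, training=False)    # tanh output in [-1, 1]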