Untitled
unknown
python
a year ago
2.5 kB
0
Indexable
Never
def upsample(self, filters, size, apply_dropout=False):
    """Build a 2x spatial upsampling block.

    Layer order: Conv2DTranspose -> BatchNormalization -> [Dropout(0.5)] -> ReLU.

    Args:
        filters: Number of output channels of the transposed convolution.
        size: Kernel size of the transposed convolution.
        apply_dropout: If True, insert ``Dropout(0.5)`` after batch norm.

    Returns:
        A ``tf.keras.Sequential`` that doubles height and width
        (``strides=2``, ``padding='same'``) and outputs ``filters`` channels.
    """
    # Weights drawn from N(mean=0, stddev=0.02).
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    # use_bias=False: a constant bias would be cancelled by the mean
    # subtraction in the BatchNormalization that follows.
    result.add(
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                        padding='same',
                                        kernel_initializer=initializer,
                                        use_bias=False))
    result.add(tf.keras.layers.BatchNormalization())
    if apply_dropout:
        result.add(tf.keras.layers.Dropout(0.5))
    result.add(tf.keras.layers.ReLU())
    return result


def downsample(self, filters, size, apply_batchnorm=True):
    """Build a 2x spatial downsampling block.

    Layer order: Conv2D -> [BatchNormalization] -> LeakyReLU.

    Args:
        filters: Number of output channels of the convolution.
        size: Kernel size of the convolution.
        apply_batchnorm: If True, add ``BatchNormalization`` after the conv.

    Returns:
        A ``tf.keras.Sequential`` that halves height and width (``strides=2``,
        ``padding='same'``, i.e. output size is ceil(input / 2)) and outputs
        ``filters`` channels.
    """
    # Weights drawn from N(mean=0, stddev=0.02).
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    # NOTE(review): use_bias=False is kept even when apply_batchnorm=False,
    # so that block has no bias term at all. Changing it would alter the
    # weight layout (and checkpoint compatibility), so it is left as-is.
    result.add(
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               kernel_initializer=initializer,
                               use_bias=False))
    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())
    result.add(tf.keras.layers.LeakyReLU())
    return result


def build_generator(self):
    """Build the U-Net (pix2pix-style) generator model.

    The input of shape ``self.noise_shape`` is processed as follows:
    1. It goes through five stride-2 downsampling blocks.
    2. It then goes through four stride-2 upsampling blocks. Each
       upsampling output is concatenated with the encoder activation of
       the same resolution (skip connection).
    3. A final stride-2 ``Conv2DTranspose`` with ``tanh`` activation
       restores the input resolution with 3 output channels.

    ``self.noise_shape`` must be rank 3 for the Conv2D layers, presumably
    ``(H, W, C)`` channels-last (TODO confirm). H and W should be multiples
    of 32. Otherwise the ceil-halving of the ``padding='same'`` stride-2
    convolutions makes the skip concatenations see mismatched shapes, or
    makes the output size differ from the input size.

    Returns:
        The generator ``Model``. Its summary is printed as a side effect.
    """
    # NOTE(review): ``Input`` and ``Model`` are not defined in this snippet;
    # presumably imported at file level (e.g. from tf.keras) - confirm.
    inputs = Input(shape=self.noise_shape, name='generator_noise')

    # Encoder: each block halves H and W.
    down_stack = [
        self.downsample(64, 3, apply_batchnorm=False),  # (batch, H/2,  W/2,  64)
        self.downsample(128, 3),                        # (batch, H/4,  W/4,  128)
        self.downsample(256, 3),                        # (batch, H/8,  W/8,  256)
        self.downsample(512, 3),                        # (batch, H/16, W/16, 512)
        self.downsample(512, 3),                        # (batch, H/32, W/32, 512)
    ]

    # Decoder: each block doubles H and W. The shapes in the comments are
    # AFTER concatenating with the matching skip connection in the loop below.
    up_stack = [
        self.upsample(512, 4, apply_dropout=True),  # (batch, H/16, W/16, 1024)
        self.upsample(512, 4, apply_dropout=True),  # (batch, H/8,  W/8,  768)
        self.upsample(256, 4, apply_dropout=True),  # (batch, H/4,  W/4,  384)
        self.upsample(128, 4, apply_dropout=True),  # (batch, H/2,  W/2,  192)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    # Final upsampling to full resolution: (batch, H, W, 3), values in [-1, 1].
    last = tf.keras.layers.Conv2DTranspose(3, 4, strides=2, padding='same',
                                           kernel_initializer=initializer,
                                           activation='tanh')

    x = inputs

    # Downsampling through the model, remembering every activation.
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # The deepest activation is the bottleneck, not a skip. Pair the others
    # with the decoder from deepest to shallowest.
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections.
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    x = last(x)

    generator = Model(inputs=inputs, outputs=x, name='generator')
    generator.summary()
    return generator