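# Untitled
# (paste-site metadata captured together with the script; kept as comments so the file parses)
#  avatar
# unknown
# plain_text
# 4 years ago
# 8.6 kB
# 6
# Indexable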
from matplotlib import pyplot as plt
import tensorflow as tf

import tensorflow.keras as keras
import tensorflow.keras.layers as layers

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras import losses
import numpy as np
import os
from keras.layers import LeakyReLU

# Reproducibility switches (currently disabled): uncomment to fix TF / NumPy seeds.
# tf.random.set_seed(100)
# np.random.seed(100)

# Autoencoder hyper-parameters.
latent_dim = 64      # length of the code produced by the encoder
noise_sigma = 0.35   # std of the Gaussian noise added to AE training inputs

# Which parts of the script run.
train_AE = False     # True: train the AE and checkpoint it; False: load saved weights
TASK2, TASK3, TASK4 = False, True, False

# Task 3 switches: TRAIN_GAN=True trains and saves a GAN, False loads one and interpolates.
TRAIN_GAN, TRAIN_CONDITIONAL_GAN = True, False

# Number of labelled training examples used by Task 2.
sml_train_size = 50

# Load MNIST, scale pixels to [0, 1], add a channel axis and zero-pad each 28x28
# digit by 2 pixels per side, giving arrays of shape (-1, 32, 32, 1).
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = (
    np.pad(
        np.reshape(imgs, (len(imgs), 28, 28, 1)).astype('float32') / 255.0,
        ((0, 0), (2, 2), (2, 2), (0, 0)),
    )
    for imgs in (x_train, x_test)
)
print(x_train.shape)
print(x_test.shape)

# Labels become 10-way one-hot float32 vectors.
y_train, y_test = (
    keras.utils.to_categorical(labels, num_classes=10, dtype='float32')
    for labels in (y_train, y_test)
)

# Encoder: four stride-2 'same' convolutions shrink the 32x32x1 input to 2x2x96
# (32 -> 16 -> 8 -> 4 -> 2). That is flattened to 2*2*96 = 384 values and
# projected linearly to a `latent_dim`-dimensional code.
encoder = Sequential()
encoder.add(layers.Conv2D(16, (4, 4), strides=(2, 2), activation='relu', padding='same', input_shape=(32, 32, 1)))
encoder.add(layers.Conv2D(32, (3, 3), strides=(2, 2), activation='relu', padding='same'))
encoder.add(layers.Conv2D(64, (3, 3), strides=(2, 2), activation='relu', padding='same'))
encoder.add(layers.Conv2D(96, (3, 3), strides=(2, 2), activation='relu', padding='same'))
encoder.add(layers.Reshape((2 * 2 * 96,)))
encoder.add(layers.Dense(latent_dim))

# Decoder: mirror of the encoder.
# - The `latent_dim` code is expanded back to 2x2x96.
# - Four stride-2 transposed convolutions then upsample it to 32x32x1.
# - The final sigmoid keeps pixels in [0, 1], matching the scaled inputs and
#   the binary cross-entropy loss below.
# (Layer definitions are unchanged so previously saved checkpoints still load.)
decoder = Sequential()
decoder.add(layers.Dense(2 * 2 * 96, activation='relu', input_shape=(latent_dim,)))
decoder.add(layers.Reshape((2, 2, 96)))
decoder.add(layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), activation='relu', padding='same'))
decoder.add(layers.Conv2DTranspose(32, (3, 3), strides=(2, 2), activation='relu', padding='same'))
decoder.add(layers.Conv2DTranspose(16, (4, 4), strides=(2, 2), activation='relu', padding='same'))
decoder.add(layers.Conv2DTranspose(1, (4, 4), strides=(2, 2), activation='sigmoid', padding='same'))

# Full autoencoder: image -> encoder -> decoder -> reconstruction.
autoencoder = keras.Model(encoder.inputs, decoder(encoder.outputs))
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

checkpoint_path = "model_save/cp.ckpt"

if train_AE:
    # Make sure the checkpoint directory exists before the callback writes to it.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                     save_weights_only=True)
    # Denoising autoencoder: inputs are corrupted with Gaussian noise (std
    # `noise_sigma`), while targets and validation data stay clean.
    # Noise is drawn directly as float32. np.random.randn would first build a
    # float64 array the size of the whole training set (~490 MB) that Keras
    # then casts back to float32.
    autoencoder.fit(x_train + noise_sigma * np.random.default_rng().standard_normal(x_train.shape,
                                                                                   dtype=np.float32),
                    x_train,
                    epochs=15,
                    batch_size=128,
                    shuffle=True,
                    validation_data=(x_test, x_test),
                    callbacks=[cp_callback])
else:
    # Reuse weights from a previous training run.
    autoencoder.load_weights(checkpoint_path)

# Reconstruct the test set via encoder -> decoder. This is exactly what
# autoencoder.predict(x_test) computes, so that extra full inference pass
# (whose result used to be overwritten immediately) is no longer run.
latent_codes = encoder.predict(x_test)
decoded_imgs = decoder.predict(latent_codes)

# Show test images 1..n (top row) above their reconstructions (bottom row).
n = 10
plt.figure(figsize=(20, 4))
for i in range(1, n + 1):
    for row, img in enumerate((x_test[i], decoded_imgs[i])):
        ax = plt.subplot(2, n, i + row * n)
        ax.imshow(img.reshape(32, 32), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()

# Your code starts here:
if TASK2:
    # Task 2: classify digits using only the first `sml_train_size` labelled
    # training images, evaluated on the full test set.

    # (a) Small MLP classifier trained on the encoder's latent codes.
    classifier = Sequential([
        layers.Dense(128, activation='relu', input_shape=(latent_dim,)),
        layers.Dense(128, activation='relu'),
        layers.Dense(64, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ])
    classifier.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    train_codes = encoder.predict(x_train[:sml_train_size])
    test_codes = encoder.predict(x_test)
    classifier.fit(
        train_codes, y_train[:sml_train_size],
        epochs=200, batch_size=16, shuffle=True,
        validation_data=(test_codes, y_test),
    )

    # (b) The same encoder + classifier architecture, trained end-to-end on raw
    # images. clone_model creates new layers with freshly initialised weights,
    # so this network does not start from the trained encoder/classifier.
    full_cls_enc = keras.models.clone_model(encoder)
    full_cls_cls = keras.models.clone_model(classifier)
    full_cls = keras.Model(full_cls_enc.inputs, full_cls_cls(full_cls_enc.outputs))
    full_cls.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    full_cls.fit(
        x_train[:sml_train_size], y_train[:sml_train_size],
        epochs=100, batch_size=16, shuffle=True,
        validation_data=(x_test, y_test),
    )

# Task 3 Functions:


def get_gen_model(noise_shape, latent_dim):
    """Build the GAN generator: an MLP mapping a noise vector to a latent code.

    :param noise_shape: length of the 1-D noise input. Either an int or anything
        ``int()`` accepts, e.g. the 0-d numpy array the Task 3 driver passes.
    :param latent_dim: length of the produced code.
    :return: an uncompiled ``Sequential`` model mapping
        (batch, noise_shape) -> (batch, latent_dim).
    """
    generator = Sequential()
    # tf.keras' LeakyReLU (via `layers`) rather than the standalone-keras import,
    # so every component of the model comes from the same Keras implementation.
    generator.add(layers.Dense(1024, activation=layers.LeakyReLU(alpha=0.2),
                               input_shape=(int(noise_shape),)))
    generator.add(layers.Dense(1024, activation=layers.LeakyReLU(alpha=0.2)))
    generator.add(layers.Dense(512, activation=layers.LeakyReLU(alpha=0.2)))
    # Linear output, like the encoder's final Dense layer that produces real codes.
    generator.add(layers.Dense(latent_dim))
    return generator


def get_disc_model(latent_dim):
    """Build and compile the GAN discriminator.

    The model takes a latent code and outputs a sigmoid probability. As trained
    by ``GAN_train``, label 1 means a code produced by the encoder (real) and
    label 0 means one produced by the generator (fake).

    :param latent_dim: length of the input code.
    :return: a ``Sequential`` model compiled with binary cross-entropy and
        Adam(learning_rate=2e-4, beta_1=0.5).
    """
    discriminator = Sequential()
    discriminator.add(layers.Dense(128, activation=layers.LeakyReLU(alpha=0.2),
                                   input_shape=(latent_dim,)))
    discriminator.add(layers.Dropout(0.4))
    discriminator.add(layers.Dense(128, activation=layers.LeakyReLU(alpha=0.2)))
    discriminator.add(layers.Dropout(0.4))
    discriminator.add(layers.Dense(1, activation="sigmoid"))
    # `learning_rate`, not the deprecated `lr` alias that newer Keras releases reject.
    opt = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    discriminator.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    discriminator.summary()
    return discriminator


def build_gan(gen, disc):
    """Stack generator and discriminator into the model used for generator updates.

    ``disc.trainable`` is set to False *before* the combined model is compiled,
    so training the combined model updates only the generator's weights.
    ``disc`` itself was compiled earlier, while it was still trainable, so
    ``disc.train_on_batch`` keeps updating the discriminator.

    :param gen: generator model (noise -> latent code).
    :param disc: compiled discriminator model (latent code -> probability).
    :return: the compiled ``Sequential([gen, disc])`` model.
    """
    disc.trainable = False
    GAN = Sequential()
    GAN.add(gen)
    GAN.add(disc)
    # `learning_rate`, not the deprecated `lr` alias that newer Keras releases reject.
    opt = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    GAN.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    GAN.summary()
    return GAN


def _show_gan_samples(gen, noise_size):
    """Decode and plot generator outputs for three probe noise vectors.

    The probes are an all-ones vector, an all-zeros vector and a fresh
    uniform[0, 1) sample. They are shown side by side in one figure.
    """
    probes = tf.concat([tf.ones([1, noise_size]),
                        tf.zeros([1, noise_size]),
                        tf.random.uniform(shape=[1, noise_size])], axis=0)
    images = decoder.predict(gen(probes, training=False))
    _, axes = plt.subplots(1, 3, figsize=(6, 2))
    for ax, img, title in zip(axes, images, ("ones", "zeros", "uniform")):
        ax.imshow(np.squeeze(img), cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    plt.show()


def GAN_train(epoch_num, batch_size, noise_size):
    """Train a GAN whose generator maps Gaussian noise to autoencoder latent codes.

    Each step works as follows:
    - The discriminator sees ``batch_size // 2`` real codes (encoder outputs for
      random training images, label 1) and as many generated codes (label 0).
    - The generator is then updated through the frozen discriminator on
      ``batch_size`` noise vectors labelled 1.
    After every epoch, the losses are printed and decoded probe samples are plotted.

    :param epoch_num: number of passes; each pass has len(x_train) // batch_size steps.
    :param batch_size: generator batch size (must be >= 2 and <= len(x_train)).
    :param noise_size: length of the generator's noise input.
    :return: the combined generator+discriminator model from ``build_gan``.
    :raises ValueError: if ``batch_size`` is out of range, which would leave an
        epoch with no training steps.
    """
    noise_size = int(noise_size)
    half_batch = batch_size // 2
    batch_per_epoch = x_train.shape[0] // batch_size
    if half_batch == 0 or batch_per_epoch == 0:
        raise ValueError(
            f"batch_size must be between 2 and {x_train.shape[0]}, got {batch_size}")
    print(f"{batch_per_epoch} batches per epoch")

    gen = get_gen_model(noise_size, latent_dim)
    disc = get_disc_model(latent_dim)
    GAN = build_gan(gen, disc)

    # The encoder is not trained here, so encode the training set once. This
    # replaces shuffling all of x_train (~245 MB) and re-encoding it every batch;
    # sampling codes without replacement draws from the same distribution.
    real_codes = encoder.predict(x_train)
    rng = np.random.default_rng()

    # Discriminator targets: first half real (1), second half generated (0).
    disc_labels = np.vstack([np.ones((half_batch, 1)), np.zeros((half_batch, 1))])
    # Generator targets: fakes labelled "real" so the generator learns to fool disc.
    gan_labels = np.ones(batch_size)

    for epoch in range(epoch_num):
        for _ in range(batch_per_epoch):
            real_samples = real_codes[rng.choice(len(real_codes), half_batch, replace=False)]
            # Direct call instead of predict(): no per-call overhead or progress bar.
            fake_samples = gen(tf.random.normal(shape=[half_batch, noise_size]),
                               training=False).numpy()
            X = np.vstack([real_samples, fake_samples])
            disc_loss, disc_acc = disc.train_on_batch(X, disc_labels)

            noise = tf.random.normal(shape=[batch_size, noise_size])
            gan_loss, _ = GAN.train_on_batch(noise, gan_labels)
        print("\nEPOCH {} gan_loss = {}, disc_loss = {}, disc_acc = {}".format(
            epoch + 1, gan_loss, disc_loss, disc_acc))
        _show_gan_samples(gen, noise_size)
    return GAN

def GAN_interpolate(GAN, n_steps=10):
    """Compare interpolation in the AE latent space with interpolation in GAN noise space.

    - AE row: two random test images are encoded, and their codes are linearly
      interpolated and decoded.
    - GAN row: two Gaussian noise vectors are linearly interpolated, each point
      is mapped to a latent code by the generator, and the code is decoded.
    Both rows are plotted in one figure.

    Previously this was an empty stub that returned None, while the Task 3
    driver unpacks two values from it.

    :param GAN: combined model as built by ``build_gan`` (``Sequential([gen, disc])``)
        or loaded from disk. Its first layer is used as the generator.
    :param n_steps: number of interpolation points, including both endpoints.
    :return: ``(AE_interpolation, latent_interpolation)``, the decoded images of
        each path as arrays of shape ``(n_steps, 32, 32, 1)``.
    :raises ValueError: if ``n_steps < 2``.
    """
    if n_steps < 2:
        raise ValueError(f"n_steps must be >= 2, got {n_steps}")
    gen = GAN.layers[0]
    rng = np.random.default_rng()
    # Column of mixing weights 0..1; broadcasting turns two endpoint vectors into a path.
    alphas = np.linspace(0.0, 1.0, n_steps, dtype="float32")[:, None]

    # Autoencoder latent space: path between the codes of two distinct test images.
    pair = rng.choice(len(x_test), 2, replace=False)
    codes = encoder.predict(x_test[pair])
    AE_interpolation = decoder.predict((1.0 - alphas) * codes[0] + alphas * codes[1])

    # GAN noise space: path between two Gaussian noise vectors, pushed through the generator.
    noise_dim = int(gen.input_shape[-1])
    z = rng.standard_normal((2, noise_dim), dtype=np.float32)
    z_path = (1.0 - alphas) * z[0] + alphas * z[1]
    latent_interpolation = decoder.predict(gen.predict(z_path))

    _, axes = plt.subplots(2, n_steps, figsize=(2 * n_steps, 4), squeeze=False)
    for row, (label, images) in enumerate((("AE", AE_interpolation),
                                           ("GAN", latent_interpolation))):
        for col in range(n_steps):
            axes[row, col].imshow(images[col].reshape(32, 32), cmap="gray")
            axes[row, col].axis("off")
        axes[row, 0].set_title(label)
    plt.show()
    return AE_interpolation, latent_interpolation

# Task 3 Constants

noise_size = 64   # length of the generator's Gaussian noise input (plain int, not a 0-d ndarray)
batch_size = 256  # generator batch; GAN_train gives the discriminator half real, half fake
epoch_num = 100
# (Batches per epoch are derived inside GAN_train from len(x_train) // batch_size.)


if TASK3:
    GAN_path = 'model_save/gan_cp'
    if TRAIN_GAN:
        GAN_model = GAN_train(epoch_num, batch_size, noise_size)
        GAN_model.save(GAN_path)
    else:
        # Perform the interpolation task with a previously trained and saved GAN.
        GAN_model = tf.keras.models.load_model(GAN_path)
        # GAN_interpolate requires the model; it was previously called with no argument.
        AE_interpolation, latent_interpolation = GAN_interpolate(GAN_model)



# Editor is loading...  (paste-site residue, commented out so the file parses)