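# Untitled
# (pastecode.io page metadata below, kept as comments so the file parses as Python)
# mail@pastecode.io avatar
# unknown
# plain_text
# 7 months ago
# 2.6 kB
# 1
# Indexable
# Never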
import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt

# Define the generator model
def build_generator(latent_dim):
    """Build the generator network.

    A small MLP: a 128-unit ReLU hidden layer, then a 784-unit sigmoid layer
    whose output is reshaped into a single-channel 28x28 image, so every
    generated pixel lies in (0, 1).

    Args:
        latent_dim: Length of the input noise vector.

    Returns:
        An uncompiled ``Sequential`` model mapping ``(batch, latent_dim)``
        inputs to ``(batch, 28, 28, 1)`` images.
    """
    hidden = Dense(128, input_dim=latent_dim, activation='relu')
    pixels = Dense(784, activation='sigmoid')  # 784 == 28 * 28
    to_image = Reshape((28, 28, 1))
    return Sequential([hidden, pixels, to_image])

# Define the discriminator model
def build_discriminator():
    """Build the discriminator network.

    Flattens a 28x28x1 image and passes it through a 128-unit ReLU layer into
    a single sigmoid unit. The training loop labels real images 1 and fakes 0,
    so the output is the model's score in (0, 1) that the input is real.

    Returns:
        An uncompiled ``Sequential`` model; the caller is responsible for
        compiling it before training it directly.
    """
    return Sequential([
        Flatten(input_shape=(28, 28, 1)),
        Dense(128, activation='relu'),
        Dense(1, activation='sigmoid'),
    ])

# Define the combined GAN model
def build_gan(generator, discriminator):
    """Stack ``generator`` and ``discriminator`` into one model for generator updates.

    Side effect: sets ``discriminator.trainable = False`` on the shared
    discriminator object before stacking. A model compiled from the result
    therefore only updates the generator's weights. If the discriminator must
    still be trained through its own ``train_on_batch``, compile it before
    calling this function (see the Keras notes on ``trainable`` and
    ``compile()``).

    Args:
        generator: Model mapping noise vectors to images.
        discriminator: Model mapping images to a real/fake score.

    Returns:
        An uncompiled ``Sequential`` model: noise -> generator -> discriminator.
    """
    discriminator.trainable = False  # freeze before the combined model is built
    stacked = Sequential([generator, discriminator])
    return stacked

# Create the GAN
latent_dim = 100
generator = build_generator(latent_dim)
discriminator = build_discriminator()
# The discriminator is trained directly with train_on_batch below, so it must be
# compiled. Compile it BEFORE build_gan() sets discriminator.trainable = False:
# tf.keras records the trainable state at compile time, so compiling afterwards
# would leave the discriminator frozen for its own updates too.
discriminator.compile(loss='binary_crossentropy', optimizer=Adam())
gan = build_gan(generator, discriminator)
# build_gan() has frozen the discriminator, so this model updates only the generator.
gan.compile(loss='binary_crossentropy', optimizer=Adam())

# Load and preprocess the dataset (MNIST): scale pixels to [0, 1] to match the
# generator's sigmoid output, and add a trailing channel axis to match the
# discriminator's (28, 28, 1) input.
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_train = np.expand_dims(x_train, axis=-1)

# Training loop. NOTE: each "epoch" here is a single mini-batch update of the
# discriminator followed by one of the generator, not a full pass over x_train.
epochs = 10000
batch_size = 128

# Labels never change between iterations, so build them once.
y_real = np.ones((batch_size, 1))
y_fake = np.zeros((batch_size, 1))
y_combined = np.concatenate((y_real, y_fake))  # real batch first, then fakes
y_misleading = np.ones((batch_size, 1))  # fakes labelled "real" for the generator step

for epoch in range(epochs):
    # Select a random batch of real images (sampled with replacement).
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_images = x_train[idx]

    # Generate a batch of fake images. verbose=0 suppresses the progress bar
    # predict() would otherwise print on every iteration.
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    fake_images = generator.predict(noise, verbose=0)

    # Combine real and fake images in the same order as y_combined.
    x_combined = np.concatenate((real_images, fake_images))

    # Train the discriminator. The trainable flag is toggled explicitly so the
    # right weights are updated on Keras versions that read the flag when a
    # model is first trained rather than at compile time (NOTE(review):
    # observed with Keras 3 — confirm against the installed version). On
    # tf.keras 2.x the compiled state takes precedence, so these assignments
    # are harmless there.
    discriminator.trainable = True
    d_loss = discriminator.train_on_batch(x_combined, y_combined)

    # Train the generator (via the GAN model) with the discriminator frozen.
    discriminator.trainable = False
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    g_loss = gan.train_on_batch(noise, y_misleading)

    # Print progress
    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, Discriminator Loss: {d_loss}, Generator Loss: {g_loss}")

        # Generate and display (not save) one example image.
        noise = np.random.normal(0, 1, (1, latent_dim))
        generated_image = generator.predict(noise, verbose=0)
        plt.imshow(generated_image[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.show()
# Leave a Comment  (pastecode.io page residue)