# Paste-site metadata captured with this file (not code):
# Untitled | unknown | plain_text | 2 years ago | 2.6 kB | 11 | Indexable
import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt
# Define the generator model
def build_generator(latent_dim):
    """Build the generator network (returned uncompiled).

    Maps a latent vector of length ``latent_dim`` through a 128-unit ReLU
    layer to 784 sigmoid outputs, reshaped into one 28x28x1 image. The
    sigmoid keeps pixel values in [0, 1].

    Args:
        latent_dim: Length of the input noise vector.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    layers = [
        Dense(128, input_dim=latent_dim, activation='relu'),
        Dense(784, activation='sigmoid'),  # 784 == 28 * 28 pixels
        Reshape((28, 28, 1)),
    ]
    return Sequential(layers)
# Define the discriminator model
def build_discriminator():
    """Build the discriminator network (returned uncompiled).

    Flattens a 28x28x1 image, passes it through a 128-unit ReLU layer and
    ends in a single sigmoid unit, giving a score in [0, 1] for binary
    cross-entropy against real (1) / fake (0) labels.

    Returns:
        An uncompiled ``Sequential`` model; compile it before calling
        ``train_on_batch`` on it directly.
    """
    return Sequential([
        Flatten(input_shape=(28, 28, 1)),
        Dense(128, activation='relu'),
        Dense(1, activation='sigmoid'),
    ])
# Define the combined GAN model
def build_gan(generator, discriminator):
    """Stack ``generator`` and ``discriminator`` into one combined model.

    Side effect: sets ``discriminator.trainable = False`` on the passed-in
    model so that training the combined model only updates the generator.
    Per the Keras guide, ``trainable`` changes take effect when a model is
    (re)compiled, so a discriminator that is also trained on its own should
    be compiled before this is called.

    Args:
        generator: Model mapping noise vectors to images.
        discriminator: Model mapping images to a real/fake score.

    Returns:
        An uncompiled ``Sequential`` model: noise -> generator -> discriminator.
    """
    # Freeze before wrapping so the combined model sees a frozen discriminator.
    discriminator.trainable = False
    stacked = Sequential([generator, discriminator])
    return stacked
# Create the GAN
latent_dim = 100
generator = build_generator(latent_dim)
discriminator = build_discriminator()
# The training loop calls discriminator.train_on_batch directly, and Keras
# refuses to train an uncompiled model, so the discriminator needs its own
# training configuration. Compile it *before* build_gan() sets
# discriminator.trainable = False: Keras documents that `trainable` takes
# effect at compile time, which keeps the discriminator trainable on its own
# while the combined model (compiled below) keeps it frozen.
# NOTE(review): newer Keras 3 releases may not snapshot `trainable` at
# compile time -- if d_loss never changes, revisit this freezing scheme.
discriminator.compile(loss='binary_crossentropy', optimizer=Adam())
gan = build_gan(generator, discriminator)
# Compiled after the freeze, so training `gan` only updates the generator.
gan.compile(loss='binary_crossentropy', optimizer=Adam())
# Load the MNIST training images; labels and the test split are discarded.
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
# Scale pixels from [0, 255] to [0.0, 1.0] (matching the generator's sigmoid
# output) and append a trailing channel axis so each image matches the
# (28, 28, 1) input shape the discriminator declares.
x_train = (x_train.astype('float32') / 255.0)[..., np.newaxis]
# Training loop.
# NOTE: each iteration trains on one randomly sampled batch, so `epochs`
# counts training steps, not full passes over the dataset.
epochs = 10000
batch_size = 128
# Target labels are identical every iteration, so build them once.
y_real = np.ones((batch_size, 1))
y_fake = np.zeros((batch_size, 1))
y_combined = np.concatenate((y_real, y_fake))
# The generator is rewarded when the frozen discriminator says "real" (1).
y_misleading = np.ones((batch_size, 1))
for epoch in range(epochs):
    # Sample a random batch of real images (indices drawn with replacement).
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_images = x_train[idx]
    # Generate a batch of fake images. verbose=0 stops predict() from
    # printing a progress bar on every single iteration.
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    fake_images = generator.predict(noise, verbose=0)
    # Train the discriminator on real (label 1) followed by fake (label 0).
    x_combined = np.concatenate((real_images, fake_images))
    d_loss = discriminator.train_on_batch(x_combined, y_combined)
    # Train the generator through the combined model with fresh noise.
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    g_loss = gan.train_on_batch(noise, y_misleading)
    # Print progress
    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, Discriminator Loss: {d_loss}, Generator Loss: {g_loss}")
# Generate and display one example image from a fresh latent sample.
noise = np.random.normal(0, 1, (1, latent_dim))
generated_image = generator.predict(noise, verbose=0)
# Drop the batch and channel axes to get a 2-D 28x28 array for imshow.
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
plt.axis('off')
plt.show()
# Leave a Comment  (paste-site page footer, not code)