import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt

# Define the generator model: maps a latent vector to a 28x28x1 image
def build_generator(latent_dim):
    model = Sequential()
    model.add(Dense(128, input_dim=latent_dim, activation='relu'))
    model.add(Dense(784, activation='sigmoid'))
    model.add(Reshape((28, 28, 1)))
    return model

# Define the discriminator model: classifies 28x28x1 images as real (1) or fake (0)
def build_discriminator():
    model = Sequential()
    model.add(Flatten(input_shape=(28, 28, 1)))
    model.add(Dense(128, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    return model

# Define the combined GAN model: generator followed by a frozen discriminator
def build_gan(generator, discriminator):
    discriminator.trainable = False
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    return model

# Create the GAN
latent_dim = 100
generator = build_generator(latent_dim)
discriminator = build_discriminator()
# Compile the discriminator before it is frozen inside the combined model;
# otherwise discriminator.train_on_batch below fails with an "uncompiled model" error.
discriminator.compile(loss='binary_crossentropy', optimizer=Adam())
gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam())

# Load and preprocess the dataset (e.g., MNIST), scaling pixels to [0, 1]
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_train = np.expand_dims(x_train, axis=-1)

# Training loop
epochs = 10000
batch_size = 128

for epoch in range(epochs):
    # Select a random batch of real images
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_images = x_train[idx]

    # Generate a batch of fake images
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    fake_images = generator.predict(noise, verbose=0)

    # Combine real and fake images
    x_combined = np.concatenate((real_images, fake_images))

    # Labels for real (1) and fake (0) images
    y_real = np.ones((batch_size, 1))
    y_fake = np.zeros((batch_size, 1))
    y_combined = np.concatenate((y_real, y_fake))

    # Train the discriminator
    d_loss = discriminator.train_on_batch(x_combined, y_combined)

    # Train the generator (via the GAN model, discriminator frozen),
    # asking it to make the discriminator label fakes as real
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    y_misleading = np.ones((batch_size, 1))
    g_loss = gan.train_on_batch(noise, y_misleading)

    # Print progress
    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, Discriminator Loss: {d_loss}, Generator Loss: {g_loss}")

# Generate and display an example image
noise = np.random.normal(0, 1, (1, latent_dim))
generated_image = generator.predict(noise, verbose=0)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
plt.axis('off')
plt.show()
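
A minimal follow-up sketch, not part of the original listing: it saves the trained generator and samples a 5x5 grid of digits, reusing the generator, latent_dim, np, and plt names defined above. The file name gan_generator.h5 is illustrative.

# Persist the trained generator so it can be reloaded for sampling later
generator.save('gan_generator.h5')

# Sample a 5x5 grid of generated digits
rows, cols = 5, 5
noise = np.random.normal(0, 1, (rows * cols, latent_dim))
samples = generator.predict(noise, verbose=0)

fig, axes = plt.subplots(rows, cols, figsize=(5, 5))
for ax, img in zip(axes.flat, samples):
    ax.imshow(img[:, :, 0], cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()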