Epoch before/after
unknown
plain_text
a year ago
1.2 kB
13
Indexable
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.callbacks import Callback


class DisplayImagesCallback(Callback):
    """Keras callback that shows a model's input next to its prediction.

    Every ``display_freq`` epochs, one image is drawn at random from
    ``test_images``, run through the model being trained, and plotted side
    by side with the model's output under the titles "Before" and "After".

    Args:
        test_images: Collection of images supporting ``len()`` and slicing
            (e.g. a NumPy array). Presumably shaped ``(N, H, W)`` or
            ``(N, H, W, C)`` to match the model input -- TODO confirm.
        display_freq: Show a figure every ``display_freq`` epochs. Must be a
            positive integer.

    Raises:
        ValueError: If ``display_freq`` is less than 1.
    """

    def __init__(self, test_images, display_freq=1):
        super().__init__()
        if display_freq < 1:
            # A zero frequency would make ``% self.display_freq`` raise
            # ZeroDivisionError at the end of the first epoch.
            raise ValueError(f"display_freq must be >= 1, got {display_freq!r}")
        self.test_images = test_images
        self.display_freq = display_freq

    @staticmethod
    def _to_displayable(image):
        """Return *image* as an array ``plt.imshow`` accepts."""
        image = np.asarray(image)
        # imshow accepts (H, W), (H, W, 3) and (H, W, 4) but rejects
        # (H, W, 1), so drop a trailing single-channel axis.
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        return image

    def on_epoch_end(self, epoch, logs=None):
        """Plot a random before/after pair every ``display_freq`` epochs.

        Args:
            epoch: 0-based index of the epoch that just finished.
            logs: Metrics dict supplied by Keras; unused here.
        """
        # ``epoch`` is 0-based, so ``epoch + 1`` counts completed epochs.
        if (epoch + 1) % self.display_freq != 0:
            return
        if len(self.test_images) == 0:
            # Nothing to show; np.random.randint(0, 0) would raise and abort
            # training from inside a purely visual callback.
            return

        random_index = np.random.randint(0, len(self.test_images))
        # Slice (rather than index) so the batch axis the model expects is kept.
        test_image = self.test_images[random_index:random_index + 1]
        # verbose=0 keeps predict's progress bar out of the training log.
        cleaned_image = self.model.predict(test_image, verbose=0)

        fig = plt.figure(figsize=(6, 3))
        panels = (("Before", test_image[0]), ("After", cleaned_image[0]))
        for position, (title, image) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.imshow(self._to_displayable(image), cmap='gray')
            plt.title(title)
            plt.axis('on')
        plt.show()
        # Close explicitly: with non-interactive backends plt.show() does not
        # release the figure, so one figure per display would accumulate.
        plt.close(fig)


# Example usage. Assumes ``model``, ``train_noisy``, ``train_clean``,
# ``test_noisy`` and ``test_clean`` are defined earlier (e.g. in a previous
# notebook cell) -- they are not defined in this snippet.
# The callback's "Before" panel is the model input, so it is given the noisy
# test images (the original passed ``train_clean``, which shows clean images
# and the model's output on already-clean input).
display_images_callback = DisplayImagesCallback(test_noisy)

model.fit(
    train_noisy, train_clean,
    epochs=20,
    batch_size=32,
    validation_data=(test_noisy, test_clean),
    callbacks=[display_images_callback],
)
Editor is loading...
Leave a Comment