Epoch before/after
unknown
plain_text
2 years ago
1.2 kB
14
Indexable
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.callbacks import Callback
class DisplayImagesCallback(Callback):
    """Keras callback that plots a model's prediction next to its input.

    Every ``display_freq`` epochs, one randomly chosen sample from
    ``test_images`` is passed through ``self.model`` and shown side by side
    with the prediction ("Before" vs. "After").

    Args:
        test_images: Indexable collection of input images. One sample is
            taken with ``test_images[i:i + 1]`` so the leading batch axis is
            kept for ``model.predict``. Presumably the noisy inputs the model
            is meant to clean — confirm against the caller.
        display_freq: Show a comparison every this many epochs; must be >= 1.
        seed: Optional seed for the sample-selection RNG, for reproducible
            picks. ``None`` (the default) draws fresh OS entropy.

    Raises:
        ValueError: If ``display_freq`` is less than 1.
    """

    def __init__(self, test_images, display_freq=1, seed=None):
        super().__init__()
        if display_freq < 1:
            # A frequency of 0 would raise ZeroDivisionError at the first
            # epoch end, deep inside training; reject it (and negatives) here.
            raise ValueError(f"display_freq must be >= 1, got {display_freq!r}")
        self.test_images = test_images
        self.display_freq = display_freq
        # Private generator, so picks do not depend on or disturb NumPy's global RNG state.
        self._rng = np.random.default_rng(seed)

    @staticmethod
    def _to_displayable(image):
        """Return *image* as a NumPy array that ``plt.imshow`` accepts.

        A trailing singleton channel axis is dropped: imshow rejects
        (H, W, 1) arrays but accepts (H, W) and (H, W, 3/4).
        """
        image = np.asarray(image)
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        return image

    def on_epoch_end(self, epoch, logs=None):
        """Plot a before/after comparison on every ``display_freq``-th epoch.

        Args:
            epoch: Epoch index passed in by Keras (treated as 0-based).
            logs: Epoch metrics passed in by Keras; unused here.
        """
        # +1 so that display_freq=5 fires after the 5th, 10th, ... epoch.
        if (epoch + 1) % self.display_freq != 0:
            return
        n_images = len(self.test_images)
        if n_images == 0:
            # Nothing to sample from; skip instead of aborting training.
            return

        index = int(self._rng.integers(n_images))
        # Slice rather than index, to keep the batch axis predict() expects.
        test_image = self.test_images[index:index + 1]
        # verbose=0: avoid printing a progress bar on every display epoch.
        cleaned_image = self.model.predict(test_image, verbose=0)

        fig, (ax_before, ax_after) = plt.subplots(1, 2, figsize=(6, 3))
        for ax, image, title in (
            (ax_before, test_image[0], 'Before'),
            (ax_after, cleaned_image[0], 'After'),
        ):
            ax.imshow(self._to_displayable(image), cmap='gray')
            ax.set_title(title)
            ax.axis('on')
        plt.show()
        # Release the figure; otherwise non-interactive backends keep every
        # epoch's figure alive and memory grows over the course of training.
        plt.close(fig)
# Example usage: preview the model's output on held-out *noisy* inputs, so
# "Before" shows what the model receives and "After" shows its cleaned output.
# (Passing train_clean here would only show the model re-predicting images
# that are already clean, which says nothing about denoising quality.)
display_images_callback = DisplayImagesCallback(test_noisy)
model.fit(
    train_noisy,
    train_clean,
    epochs=20,
    batch_size=32,
    validation_data=(test_noisy, test_clean),
    callbacks=[display_images_callback],
)
Editor is loading...
Leave a Comment