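Epoch before/after: Keras callback that plots a denoising model's input ("Before") and output ("After") at the end of training epochs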

mail@pastecode.io avatar
unknown
plain_text
2 months ago
1.2 kB
10
Indexable
Never
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.callbacks import Callback

class DisplayImagesCallback(Callback):
    """Keras callback that plots a model's input next to its prediction.

    Every ``display_freq`` epochs, one sample is drawn at random from
    ``test_images``, passed through ``self.model.predict``, and shown in a
    two-panel figure: the model input ("Before") and the model output
    ("After"). Because the sample is fed straight into the model,
    ``test_images`` must hold model *inputs* (e.g. the noisy images for a
    denoiser), not training targets.

    Args:
        test_images: Sized, sliceable collection of model inputs (e.g. a
            NumPy array) with the sample axis first. Must be non-empty.
        display_freq: Show a figure every ``display_freq`` epochs
            (1 = every epoch). Must be positive.

    Raises:
        ValueError: If ``display_freq`` is not positive or ``test_images``
            is empty.
    """

    def __init__(self, test_images, display_freq=1):
        super().__init__()
        # Validate eagerly. Otherwise display_freq=0 only fails with a
        # ZeroDivisionError at the end of the first epoch, and an empty
        # test_images only fails inside np.random.randint mid-training.
        if display_freq <= 0:
            raise ValueError(f"display_freq must be positive, got {display_freq!r}")
        if len(test_images) == 0:
            raise ValueError("test_images must contain at least one image")
        self.test_images = test_images
        self.display_freq = display_freq

    @staticmethod
    def _as_displayable(image):
        """Return *image* as an array, dropping a trailing size-1 channel axis."""
        image = np.asarray(image)
        # imshow accepts (H, W), (H, W, 3) or (H, W, 4) but rejects (H, W, 1),
        # the usual shape of single-channel images fed to conv layers.
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        return image

    def on_epoch_end(self, epoch, logs=None):
        """Plot a random before/after pair every ``display_freq`` epochs.

        Args:
            epoch: Zero-based epoch index supplied by Keras.
            logs: Epoch metrics supplied by Keras; not used here.
        """
        # epoch is 0-based, so +1 makes display_freq=N fire after epochs N, 2N, ...
        if (epoch + 1) % self.display_freq != 0:
            return

        random_index = np.random.randint(0, len(self.test_images))
        # Slice rather than index so the leading batch axis predict expects is kept.
        test_image = self.test_images[random_index:random_index + 1]
        # verbose=0 keeps predict from printing a progress bar in the middle
        # of the training log.
        cleaned_image = self.model.predict(test_image, verbose=0)

        fig = plt.figure(figsize=(6, 3))
        panels = (("Before", test_image[0]), ("After", cleaned_image[0]))
        for position, (title, image) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.imshow(self._as_displayable(image), cmap='gray')
            plt.title(title)
            plt.axis('on')

        plt.show()
        # Close explicitly so one figure per displayed epoch does not pile up
        # on backends where show() does not release it.
        plt.close(fig)

# Example usage. The callback feeds its images straight into the model, and
# the fit() call below trains the model to map noisy inputs to clean targets.
# So the callback must receive noisy inputs, not train_clean. Held-out
# test_noisy images show generalisation rather than memorised training samples.
display_images_callback = DisplayImagesCallback(test_noisy)

model.fit(
    train_noisy,
    train_clean,
    epochs=20,
    batch_size=32,
    validation_data=(test_noisy, test_clean),
    callbacks=[display_images_callback],
)
Leave a Comment