Epoch before/after
unknown
plain_text
2 years ago
1.2 kB
17
Indexable
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.callbacks import Callback
class DisplayImagesCallback(Callback):
    """Keras callback that plots one model input next to the model's prediction.

    Every ``display_freq`` epochs a single sample is drawn at random from
    ``test_images``, run through ``self.model.predict`` and shown in a
    two-panel figure: the input on the left ("Before") and the prediction on
    the right ("After").

    Args:
        test_images: Collection of model inputs supporting ``len()`` and
            slicing along the first axis (presumably a NumPy array of shape
            ``(num_samples, ...)`` -- confirm against caller).
        display_freq: Show a figure every ``display_freq`` epochs
            (1 = after every epoch). Must be a positive integer.

    Raises:
        ValueError: If ``display_freq`` is not positive.
    """

    def __init__(self, test_images, display_freq=1):
        super().__init__()
        if display_freq <= 0:
            # A zero value would otherwise surface as a ZeroDivisionError
            # in the middle of training, at the end of the first epoch.
            raise ValueError(f"display_freq must be positive, got {display_freq!r}")
        self.test_images = test_images
        self.display_freq = display_freq

    def on_epoch_end(self, epoch, logs=None):
        """Plot a random input and its prediction every ``display_freq`` epochs.

        Args:
            epoch: Zero-based epoch index supplied by Keras.
            logs: Metrics dict supplied by Keras; not used here.
        """
        # ``epoch`` is zero-based, so +1 makes display_freq=N fire after
        # epochs N, 2N, ... (1-based) rather than after the very first one.
        if (epoch + 1) % self.display_freq != 0:
            return

        n_samples = len(self.test_images)
        if n_samples == 0:
            # np.random.randint(0, 0) raises ValueError; nothing to display.
            return

        idx = np.random.randint(0, n_samples)
        # Slice (not index) so the leading batch dimension is kept for predict().
        sample = self.test_images[idx:idx + 1]
        # verbose=0 keeps predict()'s progress bar out of the training log.
        prediction = self.model.predict(sample, verbose=0)

        fig, (ax_before, ax_after) = plt.subplots(1, 2, figsize=(6, 3))
        ax_before.imshow(sample[0], cmap='gray')
        ax_before.set_title('Before')
        ax_after.imshow(prediction[0], cmap='gray')
        ax_after.set_title('After')
        plt.show()
        # Release the figure; without this one figure per displayed epoch stays
        # alive (matplotlib warns once more than 20 are open).
        plt.close(fig)
# Example usage: after every epoch, show a random noisy test image ("Before")
# next to the model's denoised prediction ("After").
# The callback must receive the model's *inputs* (noisy images): the model is
# trained to map noisy -> clean, so feeding it the clean targets would show a
# clean image both before and after and demonstrate nothing.
# `model`, train_noisy/train_clean and test_noisy/test_clean are assumed to be
# defined earlier (not shown in this snippet).
display_images_callback = DisplayImagesCallback(test_noisy)
model.fit(
    train_noisy,
    train_clean,
    epochs=20,
    batch_size=32,
    validation_data=(test_noisy, test_clean),
    callbacks=[display_images_callback],
)
Editor is loading...
Leave a Comment