# Untitled
# unknown
# python
# 5 years ago
# 775 B
# 6
# Indexable
def get_closest_number_to_digit(digit, G, *, num_samples=100000, top_k=10,
                                image_shape=(32, 32)):
    """Plot the generated samples that the classifier scores highest for ``digit``.

    Draws ``num_samples`` uniform noise vectors in ``[-noise_bound, noise_bound)``.
    It maps them through ``G`` to latent vectors and scores every latent vector
    with the module-level ``classifier``. It keeps the ``top_k`` vectors with the
    largest value in column ``digit`` of the classifier output. Those are decoded
    with the module-level ``decoder`` and shown side by side in a single row.

    Relies on names defined elsewhere in the module: ``np``, ``plt``,
    ``noise_bound``, ``noise_dim``, ``classifier`` and ``decoder``.

    Args:
        digit: Column of the classifier output to rank samples by. This is
            presumably the digit class index; confirm against the classifier.
        G: Model whose ``predict`` maps a ``(num_samples, noise_dim)`` noise
            batch to latent vectors.
        num_samples: Number of noise vectors to draw (previously hard-coded
            as 100000). Must be >= 1.
        top_k: Number of best-scoring samples to display (previously
            hard-coded as 10). Must be >= 1.
        image_shape: Shape each decoded image is reshaped to for display
            (previously hard-coded as ``(32, 32)``).

    Returns:
        ``(fig, axes)``, where ``axes`` is a 1-D array of ``Axes``, one per
        displayed image. The images run from lowest to highest score among
        the kept samples, because ``np.argsort`` sorts in ascending order.

    Raises:
        ValueError: If ``num_samples`` or ``top_k`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    # Guard explicitly: with top_k == 0 the slice below would be [-0:],
    # which selects *every* sample instead of none.
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    noise = np.random.uniform(-noise_bound, noise_bound,
                              size=[num_samples, noise_dim])
    latent_vectors = np.asarray(G.predict(noise))
    results = np.asarray(classifier.predict(latent_vectors))
    digit_results = results[:, digit]

    # argsort is ascending, so the highest-scoring samples are at the end.
    indices = np.argsort(digit_results)[-top_k:]

    # Only the selected latent vectors are decoded, not the whole batch.
    generated_images = decoder.predict(latent_vectors[indices])
    n = len(generated_images)

    # squeeze=False always yields a 2-D array of Axes. Without it, n == 1
    # returns a bare Axes that cannot be indexed. Row 0 is the same 1-D
    # array that plt.subplots(1, n) returns for n > 1.
    fig, axes = plt.subplots(1, n, figsize=(20, 4), squeeze=False)
    axes = axes[0]
    for ax, image in zip(axes, generated_images):
        # A per-image colormap replaces plt.gray(), which mutated the global
        # default colormap on every loop iteration.
        ax.imshow(image.reshape(image_shape), cmap="gray")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    return fig, axes
# Editor is loading...