Untitled

 avatar
unknown
python
4 years ago
775 B
4
Indexable
def get_closest_number_to_digit(digit, G, n_samples=100000, top_k=10,
                                image_shape=(32, 32)):
    """Plot the generated samples the classifier scores highest for ``digit``.

    Draws ``n_samples`` noise vectors uniformly from
    ``[-noise_bound, noise_bound)``, maps them through the generator ``G``
    to latent vectors, scores those with the module-level ``classifier`` and
    keeps the ``top_k`` vectors with the largest score in column ``digit``
    of the classifier output. The kept vectors are decoded with the
    module-level ``decoder`` and shown side by side in one row of grayscale
    subplots. Matplotlib's global default colormap is not modified.

    Relies on names defined outside this function: ``noise_bound``,
    ``noise_dim``, ``classifier``, ``decoder``, ``np`` and ``plt``.

    Args:
        digit: Column of the classifier output to rank by (presumably the
            class index of the target digit -- confirm against the
            classifier's output layout).
        G: Generator model exposing ``predict(noise)``.
        n_samples: Number of noise vectors to sample. Defaults to 100000.
        top_k: Number of best-scoring samples to plot. Defaults to 10.
        image_shape: Shape each decoded output is reshaped to before
            plotting; each decoded output must have exactly that many
            elements. Defaults to ``(32, 32)``.

    Returns:
        ``(fig, axes)`` as returned by ``plt.subplots(1, n, ...)`` where
        ``n`` is the number of decoded images. Images appear in ascending
        score order, so the best match is in the last subplot.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        # Guard: a slice of ``[-0:]`` would silently select every sample.
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    noise = np.random.uniform(-noise_bound, noise_bound,
                              size=[n_samples, noise_dim])
    latent_vectors = np.array(G.predict(noise))
    results = np.array(classifier.predict(latent_vectors))
    digit_results = results[:, digit]
    # argsort is ascending, so the last ``top_k`` indices score highest.
    indices = np.argsort(digit_results)[-top_k:]
    generated_images = decoder.predict(latent_vectors[indices])

    n = len(generated_images)
    fig, axes = plt.subplots(1, n, figsize=(20, 4))
    # plt.subplots returns a bare Axes (not an array) when n == 1, so
    # normalise for iteration while still returning ``axes`` unchanged.
    for ax, image in zip(np.atleast_1d(axes), generated_images):
        # Set the colormap per image rather than calling plt.gray() in the
        # loop, which re-set the global default colormap every iteration.
        ax.imshow(image.reshape(image_shape), cmap="gray")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    return fig, axes
Editor is loading...