Untitled
unknown
python
3 years ago
1.5 kB
5
Indexable
def _epoch_ticks(n_epochs):
    """Return integer epoch numbers, starting at 1, for at most ten x-ticks."""
    # Ceiling division keeps the step a whole number (never below 1), so
    # ticks always fall on real epochs and never past the last one.
    step = max(1, -(-n_epochs // 10))
    return list(range(1, n_epochs + 1, step))


def _plot_metric(ax, history, metric, title, ylabel):
    """Draw the train curve of ``metric`` and, if recorded, its ``val_`` curve."""
    train_values = history[metric]
    ax.plot(range(1, len(train_values) + 1), train_values, label='train')

    # Validation entries are only plotted when present, so a run without
    # validation data no longer fails with a KeyError.
    val_key = 'val_' + metric
    if val_key in history:
        val_values = history[val_key]
        ax.plot(range(1, len(val_values) + 1), val_values, label='val')

    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    ax.set_xticks(_epoch_ticks(len(train_values)))
    ax.legend(loc='best')


def plot_model_history(model_history, save_path='plot.png'):
    """Plot accuracy and loss curves side by side from a training history.

    Args:
        model_history: Object whose ``history`` attribute maps metric names
            to per-epoch value sequences. The keys ``'accuracy'`` and
            ``'loss'`` are required; ``'val_accuracy'`` and ``'val_loss'``
            are plotted when present. (Presumably a Keras ``History``
            object -- confirm against the caller.)
        save_path: File the figure is written to. Defaults to ``'plot.png'``
            (the previously hard-coded location); pass ``None`` to skip
            saving.

    Raises:
        KeyError: If ``'accuracy'`` or ``'loss'`` is missing from the history.
    """
    history = model_history.history
    fig, axs = plt.subplots(1, 2, figsize=(15, 5))

    _plot_metric(axs[0], history, 'accuracy', 'Model Accuracy', 'Accuracy')
    _plot_metric(axs[1], history, 'loss', 'Model Loss', 'Loss')

    if save_path is not None:
        fig.savefig(save_path)
    plt.show()
Editor is loading...