Untitled

mail@pastecode.io avatar
unknown
python
a month ago
2.1 kB
1
Indexable
Never
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.layers import Dense, Conv2D, Flatten
from tensorflow.keras.models import Sequential
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def load_train(path):
    # указываем загрузчику, что валидация содержит 
    # 25% случайных объектов
    datagen = ImageDataGenerator(validation_split=0.25)

    datagen_flow = datagen.flow_from_directory(
        # папка, в которой хранится датасет
        '/datasets/fruits_small/',
        # к какому размеру приводить изображения
        target_size=(150, 150), 
        # размер батча
        batch_size=16,
        # в каком виде выдавать метки классов
        class_mode='sparse',
        # фиксируем генератор случайных чисел (от англ. random seed)
        seed=12345)
    # англ. индексы классов
#     print(datagen_flow.class_indices)
    return datagen_flow



def create_model(input_shape):
    model = Sequential()
    model.add(Conv2D(filters=4, 
                     kernel_size=(3, 3), 
                     input_shape=(150, 150, 3), 
                     activation='relu', 
                     padding='same'))
    model.add(Flatten())
    model.add(Dense(units=12, activation='softmax'))
    model.compile(loss='sparse_categorical_crossentropy', optimizer='SGD', metrics=['acc'])

    return model

def train_model(model, train_data, test_data, batch_size=None, epochs=10,
                steps_per_epoch=None, validation_steps=None):
    model.fit(train_data,
              validation_data=test_data,
              batch_size=batch_size, epochs=epochs,
              steps_per_epoch=steps_per_epoch,
              validation_steps=validation_steps,
              verbose=2)
    return model
Leave a Comment