Untitled
unknown
plain_text
a year ago
2.8 kB
3
Indexable
import os

# Set before importing TensorFlow so the setting is seen when the library
# initialises (disables oneDNN custom operations).
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Define directories. flow_from_directory expects one subdirectory per class
# inside each of these.
train_dir = 'data/train'
validation_dir = 'data/validation'

# Image dimensions in pixels. Keras takes sizes in (height, width) order for
# both target_size and input_shape, so they are always passed that way below.
img_width, img_height = 640, 640
batch_size = 1

# Training images are augmented on the fly; validation images are only
# rescaled so they match what the model sees at inference time.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)
validation_datagen = ImageDataGenerator(rescale=1./255)

# NOTE(review): flow_from_directory assigns class indices in alphanumeric
# order of the subdirectory names. predict_image() treats index 1 as 'Fish';
# confirm this with train_generator.class_indices.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='binary',
)
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='binary',
)

# Build the model: three conv/pool stages, then a dense classifier ending in a
# single sigmoid unit for the binary fish / not-fish decision.
# NOTE(review): with 640x640 input the Flatten output has 78*78*128 = 778,752
# features, so Dense(512) alone holds roughly 400M weights. Global pooling or
# a smaller input size would shrink this drastically; the architecture is left
# unchanged here.
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(512, activation='relu'),
    Dropout(0.5),
    Dense(1, activation='sigmoid'),
])

model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Early stopping to avoid overfitting; the weights from the epoch with the
# lowest validation loss are restored at the end of training.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Train the model. len() of a Keras directory iterator is
# ceil(samples / batch_size), so every image is used each epoch and the step
# count cannot round down to zero when a split has fewer images than
# batch_size (samples // batch_size could).
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    epochs=50,
    callbacks=[early_stopping],
)

# Save the model. NOTE: recent Keras versions treat .h5 (HDF5) as a legacy
# format and recommend .keras; the filename is kept for existing loaders.
model.save('fish_detection_model.h5')

# Evaluate the model.
# NOTE(review): this is the same split EarlyStopping used to choose the
# restored weights, so the score is optimistic; a held-out test split would
# give an unbiased estimate.
loss, accuracy = model.evaluate(validation_generator)
print(f"Validation loss: {loss}")
print(f"Validation accuracy: {accuracy}")


def predict_image(image_path, threshold=0.5):
    """Classify a single image file as 'Fish' or 'Not Fish'.

    The image is resized and rescaled the same way as the training data and
    then scored by the module-level ``model`` trained above.

    Args:
        image_path: Path to the image file to classify.
        threshold: Sigmoid score above which the image is labelled 'Fish'.
            Defaults to 0.5, the original hard-coded cut-off.

    Returns:
        'Fish' if the model's score is greater than ``threshold``,
        otherwise 'Not Fish'.
    """
    img = keras_image.load_img(image_path, target_size=(img_height, img_width))
    # Match the training pipeline's rescale=1./255.
    img_tensor = keras_image.img_to_array(img) / 255.
    # Add the batch dimension the model expects: (1, height, width, 3).
    img_tensor = np.expand_dims(img_tensor, axis=0)
    prediction = model.predict(img_tensor)
    # NOTE(review): assumes class index 1 (score near 1.0) is the fish class;
    # verify with train_generator.class_indices.
    return 'Fish' if prediction[0][0] > threshold else 'Not Fish'


# Example usage
print(predict_image('data/test/fish.jpg'))
Editor is loading...
Leave a Comment