# Paste metadata captured with this file (not code): title "Untitled",
# language "unknown"/plain_text, posted a year ago, 2.8 kB, 7 views, indexable.
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping
# Define directories. Each is read by flow_from_directory, which expects one
# subdirectory per class inside it.
train_dir = 'data/train'
validation_dir = 'data/validation'

# Image dimensions every input is resized to, and images per batch.
img_width, img_height = 640, 640
batch_size = 1

# Training data: rescale pixels to [0, 1] and apply light random augmentation
# (shear, zoom, horizontal flip) so the model sees more variation.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
)
# Validation data is only rescaled (no augmentation), so metrics are computed
# on the images as they are.
validation_datagen = ImageDataGenerator(rescale=1./255)

# Keras' target_size is (height, width). class_mode='binary' yields 0/1 labels,
# assigned to the class subdirectories in alphanumeric order
# (see train_generator.class_indices).
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='binary'
)
# shuffle=False keeps validation batches in a fixed order, so evaluate/predict
# outputs line up with validation_generator.filenames and .classes.
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='binary',
    shuffle=False,
)
# Build the model: three Conv2D + MaxPooling2D stages, then a dense classifier
# head ending in a single sigmoid unit (the probability of label 1).
#
# NOTE(review): with the current 640x640 inputs the feature map reaching
# Flatten is 78x78x128 = 778,752 values, so Dense(512) alone holds ~399M
# weights (~1.6 GB in float32, plus Adam's two moment tensors per weight).
# If memory is a problem, consider smaller inputs or GlobalAveragePooling2D
# in place of Flatten.
model = Sequential([
    # Keras image tensors are channels-last: (height, width, channels).
    Conv2D(32, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(512, activation='relu'),
    Dropout(0.5),
    Dense(1, activation='sigmoid')
])

# binary_crossentropy pairs with the single sigmoid output and the 0/1 labels
# produced by class_mode='binary'.
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)
# Early stopping to avoid overfitting: stop once val_loss has not improved for
# 5 consecutive epochs. restore_best_weights=True asks Keras to put back the
# weights from the best val_loss epoch.
# NOTE(review): in some Keras versions the restore only happens when early
# stopping actually triggers, not when all 50 epochs run. Confirm this before
# assuming the saved model holds the best weights.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Train the model. steps_per_epoch / validation_steps are deliberately left
# unset: Keras then uses len(generator) == ceil(samples / batch_size), so every
# image, including a final partial batch, is used once per epoch.
# `samples // batch_size` would silently drop that partial batch, and it would
# yield 0 steps if a split held fewer images than one batch.
history = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=50,
    callbacks=[early_stopping]
)

# Save the model (the .h5 extension selects the HDF5 format).
model.save('fish_detection_model.h5')

# Evaluate the model on the validation set. This set also drove early stopping,
# so these numbers are an optimistic estimate, not an independent test score.
loss, accuracy = model.evaluate(validation_generator)
print(f"Validation loss: {loss}")
print(f"Validation accuracy: {accuracy}")
# Function to predict if an image contains a fish
def predict_image(image_path, threshold=0.5):
    """Classify a single image file as 'Fish' or 'Not Fish'.

    The image is loaded, resized to the training resolution and rescaled to
    [0, 1], the same preprocessing the validation generator applies. It is
    then scored by the module-level ``model``.

    Args:
        image_path: Path of the image file to classify.
        threshold: Decision threshold on the model's sigmoid output. Scores
            strictly above it are reported as 'Fish'. Defaults to 0.5.

    Returns:
        'Fish' if the model's score exceeds ``threshold``, otherwise 'Not Fish'.
    """
    from tensorflow.keras.preprocessing import image

    # load_img's target_size is (height, width), matching the generators.
    img = image.load_img(image_path, target_size=(img_height, img_width))
    # Same 1/255 rescaling the ImageDataGenerators apply.
    img_tensor = image.img_to_array(img) / 255.
    # Add a leading batch axis: the model consumes batches of images.
    img_tensor = np.expand_dims(img_tensor, axis=0)
    score = float(model.predict(img_tensor)[0][0])
    # NOTE(review): the sigmoid score is the probability of label 1, and
    # flow_from_directory numbers the class subdirectories in alphanumeric
    # order. Reporting score > threshold as 'Fish' assumes the fish folder sorts
    # second. With folder names like 'fish' / 'not_fish' the mapping is
    # inverted, so check train_generator.class_indices.
    return 'Fish' if score > threshold else 'Not Fish'
# Example usage: classify one sample image with the trained model and print
# the resulting label.
sample_image_path = 'data/test/fish.jpg'
sample_label = predict_image(sample_image_path)
print(sample_label)
# (Page residue from the paste site, not code: "Editor is loading...",
# "Leave a Comment".)