Untitled
unknown
plain_text
a year ago
4.5 kB
8
Indexable
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Path to the UTKFace dataset (raw string so the Windows backslashes are kept
# literally). Setting the UTKFACE_PATH environment variable overrides it, so the
# script can run on another machine without editing the source.
dataset_path = os.environ.get('UTKFACE_PATH', r'C:\Users\60111\Downloads\UTKFace')

# (width, height) every image is resized to; the model's input shape is derived
# from the loaded data, so changing this one value keeps both in sync.
IMAGE_SIZE = (64, 64)


def load_and_preprocess_utkface_data(dataset_path, target_size=(64, 64)):
    """Load the ``.jpg`` images in *dataset_path* together with their age labels.

    The age label is the integer before the first underscore in each file name
    (e.g. ``"25_..."`` -> 25).

    Args:
        dataset_path: Directory containing the image files.
        target_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        ``(images, ages)``: ``images`` is a float32 array of shape
        ``(N, height, width, 3)``, RGB channel order, values scaled to
        ``[0, 1]``; ``ages`` is an integer array of shape ``(N,)``.

    Files whose name does not start with an integer, or that OpenCV cannot
    decode, are skipped; the number skipped is printed instead of aborting
    the whole load.
    """
    images = []
    ages = []
    skipped = 0
    # Sorted so the sample order -- and therefore the seeded train/val/test
    # split below -- does not depend on the filesystem's directory order.
    for file_name in sorted(os.listdir(dataset_path)):
        if not file_name.endswith('.jpg'):
            continue
        try:
            age = int(file_name.split('_')[0])
        except ValueError:
            skipped += 1
            continue
        img = cv2.imread(os.path.join(dataset_path, file_name))
        if img is None:  # cv2.imread returns None instead of raising on failure
            skipped += 1
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        img = cv2.resize(img, target_size)
        # Cast to float32 before normalizing: `/ 255.0` on uint8 would yield
        # float64, doubling the memory of the in-RAM dataset for no benefit.
        images.append(img.astype(np.float32) / 255.0)
        ages.append(age)
    if skipped:
        print(f'Skipped {skipped} unreadable or mislabelled file(s)')
    return np.array(images), np.array(ages)


# Load and preprocess images and labels
images, ages = load_and_preprocess_utkface_data(dataset_path, target_size=IMAGE_SIZE)

# Display the number of images and labels
print(f'Number of images: {len(images)}')
print(f'Number of labels: {len(ages)}')

if len(images) == 0:
    raise ValueError(f'No usable .jpg images found in {dataset_path!r}')

# Split data: 80% train; the remaining 20% is halved into validation and test.
X_train, X_temp, y_train, y_temp = train_test_split(images, ages, test_size=0.2, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

print(f'Training set size: {X_train.shape}')
print(f'Validation set size: {X_val.shape}')
print(f'Test set size: {X_test.shape}')

# Three conv -> batch-norm -> max-pool stages, then a dense head with dropout.
# The input shape comes from the data itself, so it always matches IMAGE_SIZE.
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=X_train.shape[1:]),
    BatchNormalization(),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    BatchNormalization(),
    MaxPooling2D((2, 2)),
    Conv2D(128, (3, 3), activation='relu'),
    BatchNormalization(),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(1),  # Age is a single continuous value (regression output)
])

model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mae'])
model.summary()

# Upper bound on training epochs; early stopping may end training sooner.
EPOCHS = 20
# Stop once val_loss has not improved for 5 epochs and roll back to the best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
history = model.fit(X_train, y_train, epochs=EPOCHS, validation_data=(X_val, y_val), callbacks=[early_stopping])

# Evaluate the model on the test set
test_loss, test_mae = model.evaluate(X_test, y_test)
print(f'Test MAE: {test_mae}')

# Predictions come back as shape (N, 1); flatten to (N,) to line up with y_test.
predictions = model.predict(X_test).flatten()

# Accuracy metric: percentage of test predictions within ±error_threshold years.
error_threshold = 5
abs_errors = np.abs(predictions - y_test)
accuracy_percentage = np.mean(abs_errors <= error_threshold) * 100
print(f'Accuracy (±{error_threshold} years): {accuracy_percentage:.2f}%')

# Save the model.
# NOTE(review): newer Keras releases treat .h5 (HDF5) as a legacy format and
# recommend the native .keras format -- consider switching.
model.save('age_prediction_model.h5')

# Display some sample predictions: images on the top row, bars on the bottom.
num_samples = min(5, len(X_test))  # guard against a test set smaller than 5
# Shared y-axis limit so bars above 100 are not clipped.
age_axis_max = max(100, float(np.max(y_test)), float(np.max(predictions)))
plt.figure(figsize=(15, 10))
for i in range(num_samples):
    # Display the image
    plt.subplot(2, 5, i + 1)
    plt.imshow(X_test[i])
    plt.title(f'Actual Age: {y_test[i]}')
    plt.axis('off')

    # Display the predicted age
    plt.subplot(2, 5, i + 6)
    plt.bar(['Actual', 'Predicted'], [y_test[i], predictions[i]])
    plt.ylim([0, age_axis_max])
    plt.ylabel('Age')
    plt.title(f'Predicted Age: {predictions[i]:.2f}')

plt.tight_layout()
plt.show()

# Display training history: loss on the left, MAE on the right.
plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['mae'], label='Training MAE')
plt.plot(history.history['val_mae'], label='Validation MAE')
plt.title('Training and Validation MAE')
plt.xlabel('Epochs')
plt.ylabel('MAE')
plt.legend()

plt.tight_layout()
plt.show()
Editor is loading...
Leave a Comment