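% Untitled
% Paste-page metadata (pastecode.io), commented out so MATLAB does not try to execute it:
% mail@pastecode.io avatar | unknown | plain_text | a month ago | 2.4 kB | 1 | Indexable | Never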
% Step 2: Training Script for Two Objects (two face classes)
% Fine-tunes an ImageNet-pretrained MobileNetV2 to classify images into two
% classes and saves the trained network to a MAT-file.
% Requires Deep Learning Toolbox and the MobileNet-v2 support package.
clear;
clc;

% Load pre-trained MobileNetV2 (ImageNet weights, 224x224x3 input).
% mobilenetv2() takes no 'Weights','imagenet' / 'InputSize' arguments, and a
% pretrained network object cannot be edited in place, so convert it to a
% layer graph and replace only its ImageNet-specific classification head.
pretrained = mobilenetv2;
lgraph = layerGraph(pretrained);

% Modify the network for transfer learning
numClasses = 2; % Number of classes (faces)

% Locate the last fully connected layer and the classification output layer
% by type rather than by hard-coded name.
allLayers = lgraph.Layers;
fcName = '';
outName = '';
for k = 1:numel(allLayers)
    if isa(allLayers(k), 'nnet.cnn.layer.FullyConnectedLayer')
        fcName = allLayers(k).Name;
    elseif isa(allLayers(k), 'nnet.cnn.layer.ClassificationOutputLayer')
        outName = allLayers(k).Name;
    end
end
if isempty(fcName) || isempty(outName)
    error('Could not find the classification head of MobileNetV2 to replace.');
end

% New head: numClasses outputs, with 10x learn-rate factors so the freshly
% initialised layer adapts faster than the pretrained backbone. The network's
% existing global average pooling and softmax layers are kept as-is.
lgraph = replaceLayer(lgraph, fcName, ...
    fullyConnectedLayer(numClasses, 'Name', 'fc', ...
        'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10));
lgraph = replaceLayer(lgraph, outName, classificationLayer('Name', 'classoutput'));

% Later code reads net.Layers(1).InputSize and passes net to trainNetwork;
% both work with a LayerGraph.
net = lgraph;

% Data directories for face1 and face2. Each folder holds the images of one
% face; images may sit directly inside it or in nested subfolders.
dataDir1 = fullfile(pwd, 'face1'); % Assuming data for face 1 is stored in a folder named 'face1'
dataDir2 = fullfile(pwd, 'face2'); % Assuming data for face 2 is stored in a folder named 'face2'

% Collect the image files of each face (subfolders included)
imds1 = imageDatastore(dataDir1, 'IncludeSubfolders', true);
imds2 = imageDatastore(dataDir2, 'IncludeSubfolders', true);

% Combine data from both faces
imds = imageDatastore(cat(1, imds1.Files, imds2.Files));

% Label every file by the top-level face folder it came from. 'foldernames'
% would use the immediate parent folder instead, which mislabels images kept
% in subfolders and can create extra classes. Matching against imds1.Files
% keeps the labels correct regardless of the combined datastore's file order.
isFace1 = ismember(imds.Files, imds1.Files);
labelNames = repmat({'face2'}, numel(imds.Files), 1);
labelNames(isFace1) = {'face1'};
labels = categorical(labelNames, {'face1', 'face2'});
imds.Labels = labels;

% Split data into training and validation sets (80-20 split per class)
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.8, 'randomized');

% Data augmentation (training images only).
% Vertical reflection is intentionally omitted: upside-down faces do not
% occur at inference time, so flipping along Y only adds unrealistic samples.
augmenter = imageDataAugmenter( ...
    'RandXReflection', true, ...
    'RandRotation', [-10 10]);

% Preprocess images: resize to the network's input size. 'gray2rgb' turns
% single-channel images into 3-channel ones, since MobileNetV2 expects RGB
% input; images that are already RGB are left unchanged.
inputSize = net.Layers(1).InputSize;
augmentedTrainingData = augmentedImageDatastore(inputSize(1:2), imdsTrain, ...
    'DataAugmentation', augmenter, ...
    'ColorPreprocessing', 'gray2rgb');
augmentedValidationData = augmentedImageDatastore(inputSize(1:2), imdsValidation, ...
    'ColorPreprocessing', 'gray2rgb');

% Options for training: SGD with momentum and a small base learning rate
% (1e-4). Training data is reshuffled every epoch; the default ('once')
% would reuse the same mini-batch order in every epoch.
% ValidationFrequency is measured in iterations (mini-batches), not epochs.
options = trainingOptions('sgdm', ...
    'MiniBatchSize', 32, ...
    'MaxEpochs', 10, ...
    'InitialLearnRate', 1e-4, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', augmentedValidationData, ...
    'ValidationFrequency', 10, ...
    'Verbose', true, ...
    'Plots', 'training-progress');

% Train starting from the modified network; net is overwritten with the
% trained network returned by trainNetwork.
net = trainNetwork(augmentedTrainingData, net, options);

% Save the trained model (variable 'net') to the current folder
save('faceRecognitionModel_2objects.mat', 'net');
% (pastecode.io page text: "Leave a Comment")