Untitled
unknown
plain_text
a year ago
1.7 kB
8
Indexable
% Load images and create an image datastore. Images are collected from the
% Face1, Face2 and Face3 folders including their subfolders, and each image is
% labelled with the name of the folder it was found in ('foldernames').
faceData = imageDatastore({'Face1', 'Face2', 'Face3'}, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
% Define a function to resize images to the 224x224 spatial size used by MobileNetV2
resizeFcn = @(img) imresize(img, [224 224]);
% MobileNetV2 takes 3-channel input, so a single-channel (grayscale) image must
% be expanded or training fails on a size mismatch. Indexing the third
% dimension with min(1:3, nChannels) gives [1 1 1] for a 1-channel image
% (replicate it), [1 2 3] for a 3-channel image (unchanged), and keeps the first
% three channels of any image that has more.
toRGBFcn = @(img) img(:, :, min(1:3, size(img, 3)));
% Read, convert to 3 channels and resize every image as it is loaded
faceData.ReadFcn = @(loc) resizeFcn(toRGBFcn(imread(loc)));
% Count the number of unique labels (one output class per label)
labelCount = numel(unique(faceData.Labels));
% Split the data into training (80% of each label) and validation sets, with
% files assigned to each set at random
[trainData, valData] = splitEachLabel(faceData, 0.8, 'randomize');
% Load a pre-trained MobileNetV2 network
net = mobilenetv2;
% Create an editable layer graph from the network
lgraph = layerGraph(net);
% Locate the last fully connected layer and the classification output layer by
% type, rather than assuming fixed positions at the end of the layer array.
layers = lgraph.Layers;
fcIdx = find(arrayfun(@(l) isa(l, 'nnet.cnn.layer.FullyConnectedLayer'), layers), 1, 'last');
outIdx = find(arrayfun(@(l) isa(l, 'nnet.cnn.layer.ClassificationOutputLayer'), layers), 1, 'last');
if isempty(fcIdx) || isempty(outIdx)
    error('Could not find the fully connected and classification output layers to replace.');
end
oldLayer = layers(fcIdx);
% New fully connected layer with one output per face label. The learn-rate
% factors of 10 multiply the global learning rate for this layer's weights and
% biases, so the new head adapts faster than the pre-trained layers.
newFCLayer = fullyConnectedLayer(labelCount, 'Name', 'new_fc', 'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10);
% Replace the fully connected layer with the new one (connections are kept)
lgraph = replaceLayer(lgraph, oldLayer.Name, newFCLayer);
% Replace -- do not add alongside -- the classification output layer. The
% pre-trained output layer holds the original network's class list. Adding a
% second softmax/output branch from 'new_fc' would leave the graph with two
% output layers, which trainNetwork rejects. The existing softmax layer stays in
% place between 'new_fc' and the new output layer; it has no class-count-specific
% parameters. The new output layer takes its classes from the training data.
newClassLayer = classificationLayer('Name', 'new_classoutput');
lgraph = replaceLayer(lgraph, layers(outIdx).Name, newClassLayer);
% SGDM training configuration as name/value pairs: mini-batches of 10 images,
% 10 epochs, initial learning rate 1e-3, validation on valData every 10
% iterations, and a training-progress plot instead of command-window output.
trainingArgs = { ...
    'MiniBatchSize',       10, ...
    'MaxEpochs',           10, ...
    'InitialLearnRate',    1e-3, ...
    'ValidationData',      valData, ...
    'ValidationFrequency', 10, ...
    'Verbose',             false, ...
    'Plots',               'training-progress'};
options = trainingOptions('sgdm', trainingArgs{:});

% Fine-tune the modified layer graph on the training split
netTransfer = trainNetwork(trainData, lgraph, options);

% Write only the trained network variable to a MAT-file in the current folder
outFile = 'trainedFaceNet.mat';
save(outFile, 'netTransfer');
Editor is loading...
Leave a Comment