Untitled
unknown
plain_text
10 months ago
1.7 kB
3
Indexable
% Fine-tune a pre-trained MobileNetV2 to classify the face images found in
% the folders Face1, Face2 and Face3, then save the trained network to
% trainedFaceNet.mat (variable name: netTransfer).
% Requires Deep Learning Toolbox and the MobileNetV2 support package.

% Load images and create an image datastore; each image is labelled with
% the name of the folder that contains it.
faceData = imageDatastore({'Face1', 'Face2', 'Face3'}, ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');

% Count the number of unique labels (= number of output classes).
labelCount = numel(unique(faceData.Labels));

% Split the data into training (80% of each label) and validation sets.
[trainData, valData] = splitEachLabel(faceData, 0.8, 'randomize');

% Load a pre-trained MobileNetV2 network and read its expected input size
% from the input layer instead of hard-coding it.
net = mobilenetv2;
inputSize = net.Layers(1).InputSize;

% Resize images on the fly and replicate grayscale images to 3 channels so
% every image matches the network input layer. A resize-only ReadFcn
% would leave grayscale images with a single channel.
augTrainData = augmentedImageDatastore(inputSize(1:2), trainData, ...
    'ColorPreprocessing', 'gray2rgb');
augValData = augmentedImageDatastore(inputSize(1:2), valData, ...
    'ColorPreprocessing', 'gray2rgb');

% Create a layer graph from the network.
lgraph = layerGraph(net);

% Locate the last fully connected layer and the classification output
% layer; fail early if the layer positions are not what this script expects.
oldFCLayer = lgraph.Layers(end-2);
oldOutputLayer = lgraph.Layers(end);
assert(isa(oldFCLayer, 'nnet.cnn.layer.FullyConnectedLayer'), ...
    'Expected a fully connected layer at position end-2.');
assert(isa(oldOutputLayer, 'nnet.cnn.layer.ClassificationOutputLayer'), ...
    'Expected a classification output layer at the end of the network.');

% New fully connected layer sized for our classes; higher learn-rate
% factors let the new head learn faster than the transferred layers.
newFCLayer = fullyConnectedLayer(labelCount, 'Name', 'new_fc', ...
    'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10);

% Replace (rather than add alongside) the old head: replaceLayer keeps the
% existing connections, so new_fc feeds the original softmax layer. The old
% classification layer is replaced too, because it is fixed to the
% ImageNet classes; the new one infers its classes from the training data.
% Adding a second softmax/classification branch instead would leave the
% graph with two output layers, which trainNetwork cannot train.
lgraph = replaceLayer(lgraph, oldFCLayer.Name, newFCLayer);
lgraph = replaceLayer(lgraph, oldOutputLayer.Name, ...
    classificationLayer('Name', 'new_classoutput'));

% Set training options.
options = trainingOptions('sgdm', ...
    'MiniBatchSize', 10, ...
    'MaxEpochs', 10, ...
    'InitialLearnRate', 0.001, ...
    'ValidationData', augValData, ...
    'ValidationFrequency', 10, ...
    'Verbose', false, ...
    'Plots', 'training-progress');

% Train the network.
netTransfer = trainNetwork(augTrainData, lgraph, options);

% Save the trained network.
save('trainedFaceNet.mat', 'netTransfer');
Editor is loading...
Leave a Comment