
For most tasks, you can control the training algorithm details using the `trainingOptions` and `trainNetwork` functions. If the `trainingOptions` function does not provide the options you need for your task (for example, a custom learning rate schedule), then you can define your own custom training loop using a `dlnetwork` object. A `dlnetwork` object allows you to train a network specified as a layer graph using automatic differentiation.

To specify the same options as the `trainingOptions` function, use these examples as a guide:

Training Option | `trainingOptions` Argument | Example |
---|---|---|
Adam solver | `solverName` (`'adam'`) | Adaptive Moment Estimation (ADAM) |
RMSProp solver | `solverName` (`'rmsprop'`) | Root Mean Square Propagation (RMSProp) |
SGDM solver | `solverName` (`'sgdm'`) | Stochastic Gradient Descent with Momentum (SGDM) |
Learn rate | `'InitialLearnRate'` | Learn Rate |
Learn rate schedule | `'LearnRateSchedule'`, `'LearnRateDropPeriod'`, `'LearnRateDropFactor'` | Piecewise Learn Rate Schedule |
Training progress | `'Plots'` | Plots |
Verbose output | `'Verbose'`, `'VerboseFrequency'` | Verbose Output |
Mini-batch size | `'MiniBatchSize'` | Mini-Batch Size |
Number of epochs | `'MaxEpochs'` | Number of Epochs |
Validation | `'ValidationData'`, `'ValidationFrequency'`, `'ValidationPatience'` | Validation |
L2 regularization | `'L2Regularization'` | L2 Regularization |
Gradient clipping | `'GradientThresholdMethod'`, `'GradientThreshold'` | Gradient Clipping |
Single CPU or GPU training | `'ExecutionEnvironment'` | Single CPU or GPU Training |
Checkpoints | `'CheckpointPath'` | Checkpoints |

To specify the solver, use the `adamupdate`, `rmspropupdate`, and `sgdmupdate` functions for the update step in your training loop. To implement your own custom solver, update the learnable parameters using the `dlupdate` function.

To update your network parameters using Adam, use the `adamupdate` function. Specify the gradient decay and the squared gradient decay factors using the corresponding input arguments.
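As a rough picture of what the two decay factors control, the following plain-MATLAB sketch applies one Adam step to a small numeric vector. It assumes the standard Adam update with bias correction; the parameter and gradient values are made up, and the code illustrates the update rule rather than reproducing the `adamupdate` implementation. In a real loop you would call `adamupdate` instead, passing `[]` for the moving averages on the first iteration.

```matlab
% Real call, for reference (dlnet, gradients come from your training loop):
% [dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,gradients, ...
%     averageGrad,averageSqGrad,iteration,learnRate,gradDecay,sqGradDecay);

learnRate = 0.01;
gradDecay = 0.9;        % decay of the moving average of the gradient
sqGradDecay = 0.999;    % decay of the moving average of the squared gradient
epsilon = 1e-8;         % small offset that avoids division by zero

params = [1 2 3];              % hypothetical learnable parameters
grad = [0.1 -0.2 0.3];         % hypothetical gradients
averageGrad = zeros(size(params));
averageSqGrad = zeros(size(params));
iteration = 1;

% Update the moving averages of the gradient and the squared gradient
averageGrad = gradDecay*averageGrad + (1 - gradDecay)*grad;
averageSqGrad = sqGradDecay*averageSqGrad + (1 - sqGradDecay)*grad.^2;

% Correct the bias caused by starting both averages at zero
mHat = averageGrad / (1 - gradDecay^iteration);
vHat = averageSqGrad / (1 - sqGradDecay^iteration);

% Take the Adam step
params = params - learnRate * mHat ./ (sqrt(vHat) + epsilon);
disp(params)
```

On the first iteration the bias correction makes each element move by almost exactly `learnRate` against the sign of its gradient, whatever the gradient's magnitude, so the result is close to `[0.99 2.01 2.99]`.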

To update your network parameters using RMSProp, use the `rmspropupdate` function. Specify the denominator offset (epsilon) value using the corresponding input argument.
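A comparable sketch for RMSProp shows where `epsilon` enters: it is added to the square root of the running mean square before dividing. This assumes the standard RMSProp rule without bias correction; the values are illustrative, and in practice `rmspropupdate` maintains `averageSqGrad` for you.

```matlab
% Real call, for reference:
% [dlnet,averageSqGrad] = rmspropupdate(dlnet,gradients,averageSqGrad, ...
%     learnRate,sqGradDecay,epsilon);

learnRate = 0.01;
sqGradDecay = 0.9;     % decay of the moving average of the squared gradient
epsilon = 1e-8;        % denominator offset

params = [1 2 3];              % hypothetical learnable parameters
grad = [0.1 -0.2 0.3];         % hypothetical gradients
averageSqGrad = zeros(size(params));

% Update the moving average of the squared gradient
averageSqGrad = sqGradDecay*averageSqGrad + (1 - sqGradDecay)*grad.^2;

% Divide each gradient element by the root of its running mean square
params = params - learnRate * grad ./ (sqrt(averageSqGrad) + epsilon);
disp(params)
```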

To update your network parameters using SGDM, use the `sgdmupdate` function. Specify the momentum using the corresponding input argument.
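For SGDM, the momentum term reuses part of the previous step. This sketch assumes the common velocity formulation and holds a hypothetical gradient fixed for two iterations, so the second step is `(1 + momentum)` times the first. It illustrates the rule; in a training loop, `sgdmupdate` stores the velocity for you.

```matlab
% Real call, for reference:
% [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum);

learnRate = 0.01;
momentum = 0.9;

params = [1 2 3];              % hypothetical learnable parameters
grad = [0.1 -0.2 0.3];         % held constant to make the effect of momentum visible
velocity = zeros(size(params));

for iteration = 1:2
    % Blend the previous step with a new gradient step
    velocity = momentum*velocity - learnRate*grad;
    params = params + velocity;
end
disp(params)
```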

To specify the learn rate, use the learn rate input arguments of the `adamupdate`, `rmspropupdate`, and `sgdmupdate` functions.

To easily adjust the learn rate or use it for custom learn rate schedules, set the initial learn rate before the custom training loop.

```
learnRate = 0.01;
```

To automatically drop the learn rate during training using a piecewise learn rate schedule, multiply the learn rate by a given drop factor after a specified interval.

To easily specify a piecewise learn rate schedule, create the variables `learnRate`, `learnRateSchedule`, `learnRateDropFactor`, and `learnRateDropPeriod`, where `learnRate` is the initial learn rate, `learnRateSchedule` contains either `"piecewise"` or `"none"`, `learnRateDropFactor` is a scalar in the range [0, 1] that specifies the factor for dropping the learn rate, and `learnRateDropPeriod` is a positive integer that specifies the number of epochs between learn rate drops.

```
learnRate = 0.01;
learnRateSchedule = "piecewise";
learnRateDropPeriod = 10;
learnRateDropFactor = 0.1;
```

Inside the training loop, at the end of each epoch, drop the learn rate when the `learnRateSchedule` option is `"piecewise"` and the current epoch number is a multiple of `learnRateDropPeriod`. Set the new learn rate to the product of the learn rate and the learn rate drop factor.

```
if learnRateSchedule == "piecewise" && mod(epoch,learnRateDropPeriod) == 0
    learnRate = learnRate * learnRateDropFactor;
end
```
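To check the schedule without training anything, this sketch runs the same drop rule over 30 epochs and records the learn rate used in each epoch. `maxEpochs` and `learnRateHistory` are names introduced here for illustration.

```matlab
learnRate = 0.01;
learnRateSchedule = "piecewise";
learnRateDropPeriod = 10;
learnRateDropFactor = 0.1;
maxEpochs = 30;

learnRateHistory = zeros(1,maxEpochs);   % learn rate used during each epoch
for epoch = 1:maxEpochs
    learnRateHistory(epoch) = learnRate;
    % (the mini-batch loop would use learnRate here)

    % At the end of the epoch, apply the piecewise drop
    if learnRateSchedule == "piecewise" && mod(epoch,learnRateDropPeriod) == 0
        learnRate = learnRate * learnRateDropFactor;
    end
end
disp(learnRateHistory([1 10 11 20 21 30]))
```

Epochs 1 to 10 use 0.01, epochs 11 to 20 use 0.001, and epochs 21 to 30 use 0.0001, because the drop happens after epochs 10, 20, and 30.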

To plot the training loss and accuracy during training, calculate the mini-batch loss and either the accuracy or the root-mean-squared-error (RMSE) in the model gradients function and plot them using an animated line.

To easily specify that the plot should be on or off, create the variable `plots` that contains either `"training-progress"` or `"none"`. To also plot validation metrics, use the same options `validationData` and `validationFrequency` described in Validation.

```
plots = "training-progress";
validationData = {XValidation, YValidation};
validationFrequency = 50;
```

Before training, initialize the animated lines using the `animatedline` function. For classification tasks, create a plot for the training accuracy and the training loss. Also initialize animated lines for validation metrics when validation data is specified.

```
if plots == "training-progress"
    figure
    subplot(2,1,1)
    lineAccuracyTrain = animatedline;
    ylabel("Accuracy")
    subplot(2,1,2)
    lineLossTrain = animatedline;
    xlabel("Iteration")
    ylabel("Loss")
    if ~isempty(validationData)
        subplot(2,1,1)
        lineAccuracyValidation = animatedline;
        subplot(2,1,2)
        lineLossValidation = animatedline;
    end
end
```

For regression tasks, adjust the code by changing the variable names and labels so that it initializes plots for the training and validation RMSE instead of the training and validation accuracy.

Inside the training loop, at the end of an iteration, update the plot so that it includes the appropriate metrics for the network. For classification tasks, add points corresponding to the mini-batch accuracy and the mini-batch loss. If the validation data is nonempty, and the current iteration is either 1 or a multiple of the validation frequency option, then also add points for the validation data.

```
if plots == "training-progress"
    addpoints(lineAccuracyTrain,iteration,accuracyTrain)
    addpoints(lineLossTrain,iteration,lossTrain)
    if ~isempty(validationData) && (iteration == 1 || mod(iteration,validationFrequency) == 0)
        addpoints(lineAccuracyValidation,iteration,accuracyValidation)
        addpoints(lineLossValidation,iteration,lossValidation)
    end
end
```

`accuracyTrain` and `lossTrain` correspond to the mini-batch accuracy and loss calculated in the model gradients function. For regression tasks, use the mini-batch RMSE instead of the mini-batch accuracy.

**Tip**

The `addpoints` function requires the data points to have type `double`. To extract numeric data from `dlarray` objects, use the `extractdata` function. To collect data from a GPU, use the `gather` function.

To learn how to compute validation metrics, see Validation.
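The plotting logic can be exercised without a network. This sketch draws synthetic loss curves into an off-screen figure and adds validation points only on iteration 1 and every `validationFrequency` iterations. The loss values and the placeholder `validationData` are made up, and `getpoints` is used only to inspect what was plotted.

```matlab
figure('Visible','off');   % off-screen figure so the sketch also runs without a display
lineLossTrain = animatedline;
lineLossValidation = animatedline('LineStyle','--','Marker','o');
xlabel("Iteration")
ylabel("Loss")

validationData = {1};      % hypothetical placeholder: any nonempty value enables validation
validationFrequency = 50;
numIterations = 200;

for iteration = 1:numIterations
    lossTrain = exp(-iteration/50);                % synthetic training loss
    addpoints(lineLossTrain,iteration,double(lossTrain))

    if ~isempty(validationData) && (iteration == 1 || mod(iteration,validationFrequency) == 0)
        lossValidation = 1.1*exp(-iteration/60);   % synthetic validation loss
        addpoints(lineLossValidation,iteration,double(lossValidation))
    end
end
drawnow

% Inspect which iterations received points
[xTrain,~] = getpoints(lineLossTrain);
[xValidation,~] = getpoints(lineLossValidation);
disp(xValidation)
```

With 200 iterations and a frequency of 50, the validation line receives points at iterations 1, 50, 100, 150, and 200, while the training line receives a point every iteration.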

To display the training loss and accuracy during training in a verbose table, calculate the mini-batch loss and either the accuracy (for classification tasks) or the RMSE (for regression tasks) in the model gradients function and display them using the `disp` function.

To easily specify that the verbose table should be on or off, create the variables `verbose` and `verboseFrequency`, where `verbose` is `true` or `false` and `verboseFrequency` specifies how many iterations between printing verbose output. To display validation metrics, use the same options `validationData` and `validationFrequency` described in Validation.

```
verbose = true;
verboseFrequency = 50;
validationData = {XValidation, YValidation};
validationFrequency = 50;
```

Before training, display the verbose output table headings and initialize a timer using the `tic` function.

```
disp("|======================================================================================================================|")
disp("|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |")
disp("|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |")
disp("|======================================================================================================================|")
start = tic;
```

For regression tasks, adjust the code so that it displays the training and validation RMSE instead of the training and validation accuracy.

Inside the training loop, at the end of an iteration, print the verbose output when the `verbose` option is `true` and it is either the first iteration or the iteration number is a multiple of `verboseFrequency`.

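```
if verbose && (iteration == 1 || mod(iteration,verboseFrequency) == 0)
    D = duration(0,0,toc(start),'Format','hh:mm:ss');

    % Leave the validation columns blank when validation did not run on this iteration
    if isempty(validationData) || ~(iteration == 1 || mod(iteration,validationFrequency) == 0)
        accuracyValidation = "";
        lossValidation = "";
    end

    % pad requires text input, so convert each value with string
    % (extract numeric values from dlarray objects with extractdata first)
    disp("| " + ...
        pad(string(epoch),7,'left') + " | " + ...
        pad(string(iteration),11,'left') + " | " + ...
        pad(string(D),14,'left') + " | " + ...
        pad(string(accuracyTrain),12,'left') + " | " + ...
        pad(string(accuracyValidation),12,'left') + " | " + ...
        pad(string(lossTrain),12,'left') + " | " + ...
        pad(string(lossValidation),12,'left') + " | " + ...
        pad(string(learnRate),15,'left') + " |")
end
```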

For regression tasks, adjust the code so that it displays the training and validation RMSE instead of the training and validation accuracy.

When training is finished, print the last border of the verbose table.

```
disp("|======================================================================================================================|")
```

To learn how to compute validation metrics, see Validation.
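The columns of the verbose table line up only if each `pad` width plus its separators matches the header. This sketch formats one row from hypothetical metric values and compares its length with the header and with a 120-character border built with `repmat`.

```matlab
% Hypothetical metric values for one printed row
epoch = 1;
iteration = 50;
D = duration(0,0,5,'Format','hh:mm:ss');
accuracyTrain = 0.8125;
accuracyValidation = "";    % blank when validation did not run on this iteration
lossTrain = 0.5312;
lossValidation = "";
learnRate = 0.01;

header = "|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |";
border = "|" + string(repmat('=',1,118)) + "|";

row = "| " + ...
    pad(string(epoch),7,'left') + " | " + ...
    pad(string(iteration),11,'left') + " | " + ...
    pad(string(D),14,'left') + " | " + ...
    pad(string(accuracyTrain),12,'left') + " | " + ...
    pad(string(accuracyValidation),12,'left') + " | " + ...
    pad(string(lossTrain),12,'left') + " | " + ...
    pad(string(lossValidation),12,'left') + " | " + ...
    pad(string(learnRate),15,'left') + " |";

disp(border)
disp(header)
disp(row)
disp(border)
```

The pad widths (7, 11, 14, 12, 12, 12, 12, 15) add up to 95 characters, and the separators add 25 more, which gives the 120-character width of the border.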

How you set the mini-batch size depends on the format of the data or the type of datastore you use.

To easily specify the mini-batch size, create a variable `miniBatchSize`.

```
miniBatchSize = 128;
```

For data in an image datastore, before training, set the `ReadSize` property of the datastore to the mini-batch size.

```
imds.ReadSize = miniBatchSize;
```

For data in an augmented image datastore, before training, set the `MiniBatchSize` property of the datastore to the mini-batch size.

```
augimds.MiniBatchSize = miniBatchSize;
```

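For in-memory data, during training at the start of each iteration, read the observations directly from the array. Here, `i` is the index of the mini-batch within the current epoch; indexing with the overall `iteration` count instead would run past the end of the array after the first epoch.

```
idx = ((i - 1)*miniBatchSize + 1):(i*miniBatchSize);
X = XTrain(:,:,:,idx);
```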
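For in-memory data, a common pattern is to compute the number of complete mini-batches per epoch and index the fourth (observation) dimension with the per-epoch index. This sketch uses a random 4-D array as hypothetical image data and, as an assumption, drops the final partial mini-batch rather than processing it.

```matlab
XTrain = rand(28,28,1,1000);   % hypothetical in-memory images, observations along dimension 4
miniBatchSize = 128;

numObservations = size(XTrain,4);
numIterationsPerEpoch = floor(numObservations/miniBatchSize);   % complete mini-batches only

batchSizes = zeros(1,numIterationsPerEpoch);
for i = 1:numIterationsPerEpoch
    % Indices of the i-th mini-batch within the current epoch
    idx = ((i - 1)*miniBatchSize + 1):(i*miniBatchSize);
    X = XTrain(:,:,:,idx);
    batchSizes(i) = size(X,4);
end
fprintf("%d mini-batches of %d observations\n", numIterationsPerEpoch, miniBatchSize)
```

With 1000 observations and a mini-batch size of 128, there are 7 complete mini-batches; the remaining 104 observations are skipped in this sketch.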

Specify the maximum number of epochs for training in the outer `for` loop of the training loop.

To easily specify the maximum number of epochs, create the variable `maxEpochs` that contains the maximum number of epochs.

```
maxEpochs = 30;
```

In the outer `for` loop of the training loop, specify to loop over the range 1, 2, …, `maxEpochs`.

```
for epoch = 1:maxEpochs
    % Loop over the mini-batches of this epoch here.
end
```
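Putting the epoch loop and the mini-batch loop together, this toy sketch fits a straight line with plain mini-batch gradient descent, shuffling the data each epoch and counting iterations across epochs. The data, the two-parameter model, and the learn rate are hypothetical; a real custom loop would compute gradients with `dlfeval` and `dlgradient` and update a `dlnetwork`.

```matlab
rng(0)                                   % reproducible shuffling and noise
numObservations = 256;
XTrain = rand(1,numObservations);        % hypothetical inputs in [0, 1]
YTrain = 3*XTrain + 1 + 0.01*randn(1,numObservations);   % targets from y = 3x + 1 plus noise

w = 0;                                   % model parameters
b = 0;
learnRate = 0.5;
miniBatchSize = 32;
maxEpochs = 30;
numIterationsPerEpoch = floor(numObservations/miniBatchSize);

iteration = 0;                           % counts iterations across all epochs
for epoch = 1:maxEpochs
    order = randperm(numObservations);   % new order each epoch
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        idx = order(((i - 1)*miniBatchSize + 1):(i*miniBatchSize));
        X = XTrain(idx);
        Y = YTrain(idx);

        % Gradients of the mean squared error with respect to w and b
        err = w*X + b - Y;
        gradW = 2*mean(err.*X);
        gradB = 2*mean(err);

        w = w - learnRate*gradW;
        b = b - learnRate*gradB;
    end
end
fprintf("iterations: %d, w = %.3f, b = %.3f\n", iteration, w, b)
```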

To validate your network during training, set aside a held-out validation set and evaluate how well the network performs on that data.

To easily specify validation options, create the variables `validationData` and `validationFrequency`, where `validationData` contains the validation data or is empty and `validationFrequency` specifies how many iterations between validating the network.

```
validationData = {XValidation,YValidation};
validationFrequency = 50;
```

During the training loop, after updating the network parameters, test how well the network performs on the held-out validation set using the `predict` function. Validate the network only when validation data is specified and it is either the first iteration or the current iteration is a multiple of the `validationFrequency` option.

```
if ~isempty(validationData) && (iteration == 1 || mod(iteration,validationFrequency) == 0)
    % Predict on the validation predictors and compute the loss
    dlYPredValidation = predict(dlnet,dlXValidation);
    lossValidation = crossentropy(softmax(dlYPredValidation), YValidation);

    % Convert the predicted scores to class labels
    [~,idx] = max(dlYPredValidation);
    labelsPredValidation = classNames(idx);

    accuracyValidation = mean(labelsPredValidation == labelsValidation);
end
```

`YValidation` is a dummy variable (a one-hot encoded array) corresponding to the labels in `classNames`. To calculate the accuracy, convert `YValidation` to an array of labels, `labelsValidation`.

For regression tasks, adjust the code so that it calculates the validation RMSE instead of the validation accuracy.
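To show the conversion from dummy variables to labels, this sketch computes a validation accuracy from a hypothetical score matrix and one-hot targets, both laid out with one row per class and one column per observation (the layout assumed here). The last lines show an RMSE alternative for regression with made-up values.

```matlab
classNames = ["cat" "dog" "bird"];

% Hypothetical network scores: one row per class, one column per observation
scores = [0.7 0.2 0.1;
          0.2 0.1 0.6;
          0.1 0.7 0.3];

% One-hot ("dummy") targets with the same layout
YValidation = [1 0 1;
               0 0 0;
               0 1 0];

% Convert scores and targets to labels using the row index of each column maximum
[~,idxPred] = max(scores,[],1);
[~,idxTrue] = max(YValidation,[],1);
labelsPredValidation = classNames(idxPred);
labelsValidation = classNames(idxTrue);

accuracyValidation = mean(labelsPredValidation == labelsValidation);

% Regression alternative: root-mean-squared error between predictions and targets
YPredRegression = [2.0 3.5 1.0];
YTrueRegression = [2.0 3.0 2.0];
rmseValidation = sqrt(mean((YPredRegression - YTrueRegression).^2,'all'));

fprintf("accuracy = %.4f, RMSE = %.4f\n", accuracyValidation, rmseValidation)
```

The predicted labels are "cat", "bird", and "dog" against true labels "cat", "bird", and "cat", so two of three observations are correct.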

To stop training early when the loss on the held-out validation data stops decreasing, use a flag to break out of the training loops.

To easily specify the validation patience (the number of times that the validation loss can be larger than or equal to the previously smallest loss before network training stops), create the variable `validationPatience`.

```
validationPatience = 5;
```

Before training, initialize the variables `earlyStop` and `validationLosses`, where `earlyStop` is a flag to stop training early and `validationLosses` contains the losses to compare. Initialize the early stopping flag with `false` and the array of validation losses with `inf`.

```
earlyStop = false;
if isfinite(validationPatience)
    validationLosses = inf(1,validationPatience);
end
```

Inside the training loop, in the loop over mini-batches, add the `earlyStop` flag to the loop condition.

```
while hasdata(ds) && ~earlyStop
    % Read a mini-batch, compute the gradients, and update the network here.
end
```

During the validation step, append the new validation loss to the array `validationLosses`. If the first element of the array is the smallest, then set the `earlyStop` flag to `true`. Otherwise, remove the first element.

```
if isfinite(validationPatience)
    validationLosses = [validationLosses lossValidation];
    if min(validationLosses) == validationLosses(1)
        earlyStop = true;
    else
        validationLosses(1) = [];
    end
end
```
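The patience rule can be traced with a fixed sequence of validation losses. With `validationPatience = 3`, this sketch stops on the sixth validation step, because the best loss (0.7, at step 3) is followed by three losses that do not improve on it. The loss values, `lossHistory`, and `stopStep` are hypothetical names and data introduced for illustration.

```matlab
validationPatience = 3;
lossHistory = [1.0 0.8 0.7 0.72 0.71 0.75 0.9];   % hypothetical validation losses

earlyStop = false;
validationLosses = inf(1,validationPatience);
stopStep = NaN;

step = 0;
while step < numel(lossHistory) && ~earlyStop
    step = step + 1;
    lossValidation = lossHistory(step);

    % Same rule as above: stop if the oldest loss in the window is still the smallest
    validationLosses = [validationLosses lossValidation];
    if min(validationLosses) == validationLosses(1)
        earlyStop = true;
        stopStep = step;
    else
        validationLosses(1) = [];
    end
end
fprintf("stopped at validation step %d\n", stopStep)
```

In a real loop, also check `earlyStop` in the outer epoch loop (for example, by breaking out of it) so that training does not keep iterating over empty epochs.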

To apply L2 regularization to the weights, use the `dlupdate` function.

To easily specify the L2 regularization factor, create the variable `l2Regularization` that contains the L2 regularization factor.

```
l2Regularization = 0.0001;
```

During training, after computing the model gradients, for each of the weight parameters, add the product of the L2 regularization factor and the weights to the computed gradients using the `dlupdate` function. To update only the weight parameters, extract the parameters with name `"Weights"`.

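```
% Select the rows of the learnables table that hold weights
idx = dlnet.Learnables.Parameter == "Weights";
% Add l2Regularization*w to the gradient of each selected weight array
gradients(idx,:) = dlupdate(@(g,w) g + l2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:));
```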

After adding the L2 regularization term to the gradients, update the network parameters.
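Adding `l2Regularization*w` to a weight gradient is equivalent to adding the penalty `(l2Regularization/2)*sum(w(:).^2)` to the loss, because the gradient of that penalty with respect to `w` is `l2Regularization*w`. The sketch below applies the same update with `cellfun` on plain arrays as a stand-in for `dlupdate` on the learnables table, so it runs without a network; the parameter names and values are hypothetical.

```matlab
l2Regularization = 0.0001;

% Hypothetical stand-in for the Parameter and Value columns of dlnet.Learnables and the gradients
parameterNames = ["Weights"; "Bias"];
learnableValues = {[0.5 -1.0; 2.0 0.0]; [0.1; 0.2]};
gradientValues = {[0.01 0.02; 0.03 0.04]; [0.5; 0.6]};

% Add l2Regularization*w to the gradients of the weight parameters only
isWeights = parameterNames == "Weights";
gradientValues(isWeights) = cellfun(@(g,w) g + l2Regularization*w, ...
    gradientValues(isWeights), learnableValues(isWeights), 'UniformOutput', false);

disp(gradientValues{1})
```

The bias gradient is left unchanged, which matches the selection by the `"Weights"` parameter name above.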

To clip the gradients, use the `dlupdate` function.

To easily specify gradient clipping options, create the variables `gradientThresholdMethod` and `gradientThreshold`, where `gradientThresholdMethod` contains `"global-l2norm"`, `"l2norm"`, or `"absolute-value"`, and `gradientThreshold` is a positive scalar containing the threshold or `inf`.

```
gradientThresholdMethod = "global-l2norm";
gradientThreshold = 2;
```

Create functions named `thresholdGlobalL2Norm`, `thresholdL2Norm`, and `thresholdAbsoluteValue` that apply the `"global-l2norm"`, `"l2norm"`, and `"absolute-value"` threshold methods, respectively.

For the `"global-l2norm"` option, the function operates on all gradients of the model.

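```
function gradients = thresholdGlobalL2Norm(gradients,gradientThreshold)
% Gradients of a dlnetwork are a table whose Value column holds the
% gradient arrays; a cell array of gradients is also accepted.
isTable = istable(gradients);
if isTable
    values = gradients.Value;
else
    values = gradients;
end

% Compute the L2 norm over all gradients of the model
globalL2Norm = 0;
for i = 1:numel(values)
    globalL2Norm = globalL2Norm + sum(values{i}(:).^2);
end
globalL2Norm = sqrt(globalL2Norm);

% Rescale all gradients by the same factor if the norm exceeds the threshold
if globalL2Norm > gradientThreshold
    normScale = gradientThreshold / globalL2Norm;
    for i = 1:numel(values)
        values{i} = values{i} * normScale;
    end
end

if isTable
    gradients.Value = values;
else
    gradients = values;
end
end
```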

For the `"l2norm"` and `"absolute-value"` options, the functions operate on each gradient independently.

```
function gradients = thresholdL2Norm(gradients,gradientThreshold)
gradientNorm = sqrt(sum(gradients(:).^2));
if gradientNorm > gradientThreshold
    gradients = gradients * (gradientThreshold / gradientNorm);
end
end
```

```
function gradients = thresholdAbsoluteValue(gradients,gradientThreshold)
gradients(gradients > gradientThreshold) = gradientThreshold;
gradients(gradients < -gradientThreshold) = -gradientThreshold;
end
```

During training, after computing the model gradients, apply the appropriate gradient clipping method to the gradients. Because the `"global-l2norm"` option requires all the model gradients, apply the `thresholdGlobalL2Norm` function directly to the gradients. For the `"l2norm"` and `"absolute-value"` options, update the gradients independently using the `dlupdate` function.

```
switch gradientThresholdMethod
    case "global-l2norm"
        gradients = thresholdGlobalL2Norm(gradients, gradientThreshold);
    case "l2norm"
        gradients = dlupdate(@(g) thresholdL2Norm(g, gradientThreshold),gradients);
    case "absolute-value"
        gradients = dlupdate(@(g) thresholdAbsoluteValue(g, gradientThreshold),gradients);
end
```

After applying the gradient threshold operation, update the network parameters.
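The three threshold methods can be compared on plain numeric arrays. This sketch re-expresses them with anonymous functions and `cellfun` instead of `dlupdate` and the local functions above (an assumption made so that it runs as a single script without a network); the gradient values and threshold are made up.

```matlab
gradientThreshold = 3.5;
gradients = {[3 -4], [0 12]};   % hypothetical gradients: L2 norms 5 and 12, global norm 13

% "global-l2norm": rescale all gradients by one common factor if the global norm is too large
globalL2Norm = sqrt(sum(cellfun(@(g) sum(g(:).^2), gradients)));
globalClipped = gradients;
if globalL2Norm > gradientThreshold
    globalClipped = cellfun(@(g) g*(gradientThreshold/globalL2Norm), gradients, ...
        'UniformOutput', false);
end

% "l2norm": rescale each gradient independently if its own norm is too large
% (an all-zero gradient gives a factor of min(1,Inf) = 1 and stays unchanged)
clipL2 = @(g) g * min(1, gradientThreshold/sqrt(sum(g(:).^2)));
l2Clipped = cellfun(clipL2, gradients, 'UniformOutput', false);

% "absolute-value": clamp each element to [-gradientThreshold, gradientThreshold]
clipAbs = @(g) max(min(g, gradientThreshold), -gradientThreshold);
absClipped = cellfun(clipAbs, gradients, 'UniformOutput', false);

disp(l2Clipped{1})
disp(absClipped{1})
```

Only `"global-l2norm"` preserves the relative scale between the gradients. `"absolute-value"` can change the direction of a gradient, as the first array shows: `[3 -4]` becomes `[3 -3.5]`, while `"l2norm"` scales it to `[2.1 -2.8]`.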

By default, the software performs calculations using only the CPU. To train on a single GPU, convert the data to `gpuArray` objects. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Support by Release (Parallel Computing Toolbox).

To easily specify the execution environment, create the variable `executionEnvironment` that contains either `"cpu"`, `"gpu"`, or `"auto"`.

```
executionEnvironment = "auto";
```

During training, after reading a mini-batch, check the execution environment option and convert the data to a `gpuArray` if necessary. The `canUseGPU` function checks for usable GPUs.

```
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end
```
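A minimal sketch of the device check, assuming a hypothetical mini-batch `X`: with `executionEnvironment = "cpu"` the short-circuit `&&` skips `canUseGPU` and the code never calls `gpuArray`, so it runs without Parallel Computing Toolbox. A result computed on the GPU is brought back with `gather` before converting it to `double`, in line with the tip on `addpoints`.

```matlab
executionEnvironment = "cpu";    % try "auto" on a machine with a supported GPU
X = rand(28,28,1,16);            % hypothetical mini-batch

useGPU = (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu";
if useGPU
    X = gpuArray(X);             % requires Parallel Computing Toolbox
end

lossValue = mean(X(:).^2);       % stand-in for a metric computed on the mini-batch
if useGPU
    lossValue = gather(lossValue);   % copy the result back to host memory
end
lossValue = double(lossValue);
fprintf("useGPU = %d, lossValue = %.4f\n", useGPU, lossValue)
```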

To save checkpoint networks during training, save the network using the `save` function.

To easily specify whether checkpoints should be switched on, create the variable `checkpointPath` that contains the folder for the checkpoint networks or is empty.

```
checkpointPath = fullfile(tempdir,"checkpoints");
```

If the checkpoint folder does not exist, then before training, create the checkpoint folder.

```
if ~isempty(checkpointPath) && ~exist(checkpointPath,"dir")
    mkdir(checkpointPath)
end
```

During training, at the end of an epoch, save the network in a MAT file. Specify a file name containing the current iteration number, date, and time.

```
if ~isempty(checkpointPath)
    D = datestr(now,'yyyy_mm_dd__HH_MM_SS');
    filename = "dlnet_checkpoint__" + iteration + "__" + D + ".mat";
    % Save into the checkpoint folder rather than the current folder
    save(fullfile(checkpointPath,filename),"dlnet")
end
```
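To try the naming and saving logic without a trained network, this sketch saves a stand-in struct named `dlnet` into a hypothetical `checkpoints_demo` folder under `tempdir` and then loads the most recently written checkpoint. The stand-in struct and the folder name are assumptions for illustration; a saved `dlnetwork` variable is loaded back from the MAT file in the same way.

```matlab
checkpointPath = fullfile(tempdir,"checkpoints_demo");   % hypothetical folder
if ~isempty(checkpointPath) && ~exist(checkpointPath,"dir")
    mkdir(checkpointPath)
end

dlnet = struct('Weights',magic(3));   % hypothetical stand-in for a dlnetwork object
iteration = 120;

% Save into the checkpoint folder with the iteration number, date, and time in the name
D = datestr(now,'yyyy_mm_dd__HH_MM_SS');
filename = "dlnet_checkpoint__" + iteration + "__" + D + ".mat";
save(fullfile(checkpointPath,filename),"dlnet")

% Later: find and load the most recently written checkpoint
files = dir(fullfile(checkpointPath,"dlnet_checkpoint__*.mat"));
[~,newest] = max([files.datenum]);
S = load(fullfile(files(newest).folder,files(newest).name),"dlnet");
restored = S.dlnet;
disp(files(newest).name)
```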

`dlnet` is the `dlnetwork` object to be saved.

**See Also**

`adamupdate` | `rmspropupdate` | `sgdmupdate` | `dlupdate` | `dlarray` | `dlgradient` | `dlfeval` | `dlnetwork`

**Related Topics**

- Define Custom Training Loops, Loss Functions, and Networks
- Define Model Gradients Function for Custom Training Loop
- Train Network Using Custom Training Loop
- Train Network Using Model Function
- Make Predictions Using dlnetwork Object
- Make Predictions Using Model Function
- Initialize Learnable Parameters for Model Function
- Update Batch Normalization Statistics in Custom Training Loop
- Update Batch Normalization Statistics Using Model Function
- Train Generative Adversarial Network (GAN)
- List of Functions with dlarray Support