Trainer
type
The type of the trainer. The following types are available:
- basic: a basic trainer with a standard training loop.
- timm_basic: a basic trainer with the timm learning rate scheduler.
- diff_privacy: a trainer that supports local differential privacy in its training loop by adding noise to the gradients during each step of training (see the sketch after this list). It accepts the following additional settings:
  - max_physical_batch_size: the limit on the physical batch size when using the diff_privacy trainer. Default value: 128. The GPU memory usage of one process training the ResNet-18 model is around 2817 MB.
  - dp_epsilon: the total privacy budget of epsilon with the diff_privacy trainer. Default value: 10.0.
  - dp_delta: the total privacy budget of delta with the diff_privacy trainer. Default value: 1e-5.
  - dp_max_grad_norm: the maximum norm of the per-sample gradients with the diff_privacy trainer. Any gradient with a norm higher than this will be clipped to this value. Default value: 1.0.
- split_learning: a trainer that supports the split learning framework.
- self_supervised_learning: a trainer that supports personalized federated learning based on self-supervised learning.
- gan: a trainer for Generative Adversarial Networks (GANs).
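For instance, here is a minimal sketch of a diff_privacy trainer configuration in the same style as the other examples in this section. The values shown are the documented defaults; placing these keys directly under [trainer] is an assumption based on those examples.
[trainer]
type = "diff_privacy"
max_physical_batch_size = 128  # limit on the physical batch size
dp_epsilon = 10.0              # total privacy budget of epsilon
dp_delta = 1e-5                # total privacy budget of delta
dp_max_grad_norm = 1.0         # per-sample gradients are clipped to this norm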
rounds
The maximum number of training rounds.
rounds can be any positive integer.
max_concurrency
The maximum number of clients (of each edge server in cross-silo training) running concurrently on each available GPU. If this is not defined, no new processes are spawned for training.
Note
Plato will automatically use all available GPUs to maximize the concurrency of training, launching the same number of clients on every GPU. If max_concurrency is 7 and 3 GPUs are available, 21 client processes will be launched for concurrent training.
target_accuracy
The target accuracy of the global model.
target_perplexity
The target perplexity of the global Natural Language Processing (NLP) model.
epochs
The total number of epochs in local training in each communication round.
batch_size
The size of the mini-batch of data in each step (iteration) of the training loop.
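As an illustration, the general training-loop settings described above could be combined as follows. All values here are hypothetical, and their placement under [trainer] is assumed from the other examples in this section.
[trainer]
type = "basic"          # a standard training loop
rounds = 100            # maximum number of communication rounds
max_concurrency = 7     # clients running concurrently on each available GPU
target_accuracy = 0.95  # stop once the global model reaches this accuracy
epochs = 5              # local epochs per communication round
batch_size = 32         # mini-batch size in each training step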
optimizer
The type of the optimizer. The following options are supported:
- Adam
- Adadelta
- Adagrad
- AdaHessian (from the torch_optimizer package)
- AdamW
- SparseAdam
- Adamax
- ASGD
- LBFGS
- NAdam
- RAdam
- RMSprop
- Rprop
- SGD
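For example, selecting one of the optimizers above could look like this (a minimal sketch; placement under [trainer] is assumed from the other examples in this section):
[trainer]
optimizer = "SGD"  # any of the optimizer names listed above, e.g. "Adam" or "AdamW"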
lr_scheduler
The learning rate scheduler. The following learning rate schedulers are supported:
- CosineAnnealingLR
- LambdaLR
- MultiStepLR
- StepLR
- ReduceLROnPlateau
- ConstantLR
- LinearLR
- ExponentialLR
- CyclicLR
- CosineAnnealingWarmRestarts
Alternatively, all four schedulers from timm are supported if lr_scheduler is specified as timm and trainer -> type is specified as timm_basic. For example, to use the SGDR scheduler, we specify cosine as sched in its arguments (parameters -> learning_rate):
[trainer]
type = "timm_basic"
[parameters]
[parameters.learning_rate]
sched = "cosine"
min_lr = 1e-6
warmup_lr = 0.0001
warmup_epochs = 3
cooldown_epochs = 10
loss_criterion
The loss criterion. The following options are supported:
- L1Loss
- MSELoss
- BCELoss
- BCEWithLogitsLoss
- NLLLoss
- PoissonNLLLoss
- CrossEntropyLoss
- HingeEmbeddingLoss
- MarginRankingLoss
- TripletMarginLoss
- KLDivLoss
- NegativeCosineSimilarity
- NTXentLoss
- SwaVLoss
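For example (a minimal sketch; placement under [trainer] is assumed from the other examples in this section):
[trainer]
loss_criterion = "CrossEntropyLoss"  # any of the loss criteria listed above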
global_lr_scheduler
Whether the learning rate should be scheduled globally (true) or not (false).
If true, the learning rate of the first epoch in the next communication round is scheduled based on that of the last epoch in the previous communication round.
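A hedged sketch combining one of the PyTorch schedulers listed earlier with global scheduling; the values are hypothetical, and their placement under [trainer] is assumed from the other examples in this section.
[trainer]
lr_scheduler = "StepLR"     # any of the PyTorch schedulers listed above
global_lr_scheduler = true  # continue the schedule across communication rounds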
model_type
The repository from which the machine learning model should be retrieved. The following options are available:
- cnn_encoder: for generating various encoders by extracting from CNN models such as ResNet models
- general_multilayer: for generating a multi-layer perceptron using a provided configuration
- huggingface: for HuggingFace causal language models
- torch_hub: for models from PyTorch Hub
- vit: for Vision Transformer models from HuggingFace, Tokens-to-Token ViT, and Deep Vision Transformer
The name of the model should be specified below, in model_name.
Note
For vit, please replace the / in the model name from https://huggingface.co/models with @. For example, use google@vit-base-patch16-224-in21k instead of google/vit-base-patch16-224-in21k. If you do not want to use the pretrained weights, set parameters -> model -> pretrained to false, as in the following example:
[parameters]
[parameters.model]
pretrained = false
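Putting the naming rule and the pretrained flag together, a sketch for a Vision Transformer model might look like the following; placing model_type and model_name under [trainer] is an assumption based on the rest of this section.
[trainer]
model_type = "vit"
model_name = "google@vit-base-patch16-224-in21k"  # "/" replaced with "@"

[parameters]
[parameters.model]
pretrained = false  # do not load the pretrained weights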
model_name
The name of the machine learning model. The following options are available:
- lenet5
- resnet_x
- vgg_x
- dcgan
- multilayer
Note
If the model_type above specified a model repository, supply the name of the model, such as gpt2, here.
For resnet_x, x can be 18, 34, 50, 101, or 152; for vgg_x, x can be 11, 13, 16, or 19.
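For instance, retrieving the gpt2 causal language model from HuggingFace could be sketched as follows (the model and repository names come from the notes above; placement under [trainer] is assumed from the other examples in this section):
[trainer]
model_type = "huggingface"  # the model repository
model_name = "gpt2"         # the name of the model within that repository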