Plato trainers use the same composition model as clients and servers. Every
ComposableTrainer instance wires a small set of interchangeable strategies,
letting you swap behaviour without subclassing:
LossCriterionStrategy computes the objective.
OptimizerStrategy builds and updates the optimiser.
TrainingStepStrategy runs the forward/backward pass.
LRSchedulerStrategy adjusts learning rates.
ModelUpdateStrategy maintains auxiliary state (control variates, fine-tuning).
DataLoaderStrategy creates train/test loaders.
TestingStrategy evaluates the model.
Strategies share state through TrainingContext, which mirrors the trainer’s
model, optimiser, device, round counters, and an extensible state dictionary.
Quick Start
```python
from plato.trainers.composable import ComposableTrainer

# Default stack: sensible strategies for supervised learning.
trainer = ComposableTrainer(model=my_model_fn)

# Mix and match to customise behaviour.
from plato.trainers.strategies import AdamOptimizerStrategy
from plato.trainers.strategies.algorithms import FedProxLossStrategy

fedprox_trainer = ComposableTrainer(
    model=my_model_fn,
    loss_strategy=FedProxLossStrategy(mu=0.01),
    optimizer_strategy=AdamOptimizerStrategy(lr=1e-3),
)
```
Pass trainer=fedprox_trainer when instantiating clients or servers to reuse the
same strategy stack in every round.
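For example, a minimal sketch assuming the standard plato.clients.simple and plato.servers.fedavg entry points:

```python
from plato.clients import simple
from plato.servers import fedavg

# Both the client and the server reuse the same strategy stack.
client = simple.Client(trainer=fedprox_trainer)
server = fedavg.Server(trainer=fedprox_trainer)
server.run(client)
```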
Strategy Extension Points
LossCriterionStrategy: add regularisers or alternate objectives; pull round
metadata from context when needed.
OptimizerStrategy: build custom optimisers or parameter groups; return a
ready-to-use optimiser instance.
TrainingStepStrategy: implement bespoke loops (LG-FedAvg, gradient clipping);
keep tensors on device and reuse the supplied loss_criterion.
LRSchedulerStrategy: wire warmup or timm schedulers by overriding
create_scheduler and optional lifecycle hooks; see the sketch after this list.
ModelUpdateStrategy: persist control variates or personalised heads in
context.state.
DataLoaderStrategy: control sampling, augmentation, or worker config while
honouring batch sizes from the config.
TestingStrategy: customise evaluation logic and return scalar metrics for
downstream logging.
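As an illustration of these extension points, here is a minimal sketch of an LRSchedulerStrategy that wires a warmup-plus-cosine schedule. The base-class import path and the exact create_scheduler signature are assumptions modelled on the loss-strategy example later on this page:

```python
import torch

from plato.trainers.strategies.base import LRSchedulerStrategy, TrainingContext


class WarmupCosineSchedulerStrategy(LRSchedulerStrategy):
    """Linear warmup followed by cosine annealing (sketch)."""

    def __init__(self, warmup_epochs=5, total_epochs=100):
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs

    # Assumed signature: the strategy receives the optimiser and the context.
    def create_scheduler(self, optimizer, context: TrainingContext):
        warmup = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.1, total_iters=self.warmup_epochs
        )
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.total_epochs - self.warmup_epochs
        )
        return torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup, cosine], milestones=[self.warmup_epochs]
        )
```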
Structured evaluators are layered after the testing strategy. In other
words, TestingStrategy still returns the trainer's scalar metric (accuracy,
perplexity, loss, and so on), and an optional [evaluation] section can then
run a named benchmark adapter such as Lighteval or Nanochat CORE. See
Evaluators for that layer.
Each concrete strategy inherits optional setup/teardown hooks. To fire
callback events from within a strategy, hold a reference to the trainer and
call trainer.callback_handler.call_event(...) directly. The
TrainingContext passed to strategies does not carry a callback_handler
attribute; only ClientContext (for client strategies) does.
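A minimal sketch of that pattern follows; the trainer constructor argument and the event name are illustrative assumptions:

```python
import torch.nn as nn

from plato.trainers.strategies.base import LossCriterionStrategy


class EventEmittingLossStrategy(LossCriterionStrategy):
    """Fires a callback event each time the loss is computed (sketch)."""

    def __init__(self, trainer, base_loss_fn=None):
        # Hold a trainer reference so events can be fired from the strategy.
        self.trainer = trainer
        self._criterion = base_loss_fn or nn.CrossEntropyLoss()

    def compute_loss(self, outputs, labels, context):
        loss = self._criterion(outputs, labels)
        # "on_loss_computed" is a hypothetical event name; a callback with a
        # matching method would receive it.
        self.trainer.callback_handler.call_event("on_loss_computed", self.trainer, loss)
        return loss
```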
Composing Trainers
ComposableTrainer accepts either concrete strategy instances or None for the defaults. You can start from plato.trainers.basic.Trainer (which simply wraps the defaults) and override only the pieces you need:
```python
from plato.trainers.composable import ComposableTrainer
from plato.trainers.strategies.training_step import GradientClippingStepStrategy


class ClippedTrainer(ComposableTrainer):
    def __init__(self, *, model=None, callbacks=None, max_norm=1.0):
        super().__init__(
            model=model,
            callbacks=callbacks,
            training_step_strategy=GradientClippingStepStrategy(max_norm=max_norm),
            # All other strategies default to their standard implementations.
        )
```
See the references under plato.trainers.strategies for ready-made options
such as Scaffold, FedProx, FedDyn, and personalised-FL adaptation strategies.
FedNova is a server-side aggregation algorithm and lives under
plato.servers.strategies, not the trainer strategies.
TrainingContext Attributes
In addition to mirroring the trainer's model, optimiser, device, and round
counters, TrainingContext exposes:
config: the training configuration dictionary for the current round.
state: a plain dictionary for cross-strategy coordination at runtime.
Note that optimizer, lr_scheduler, and run_history are attributes of
ComposableTrainer itself, not of TrainingContext. The active data loader
is stored at context.state["train_loader"] during training. A metadata
dictionary exists on ClientContext (for client strategies) but not on
TrainingContext.
Prefer context.state for sharing transient values between strategies, and
trainer.run_history when you need to read or update per-epoch metrics from
callbacks.
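For instance, a minimal sketch of a strategy publishing a value into context.state for later strategies to consume (the key name and the base-class import path are assumptions):

```python
from plato.trainers.strategies.base import ModelUpdateStrategy, TrainingContext


class AnchorSnapshotStrategy(ModelUpdateStrategy):
    """Stashes a snapshot of the incoming global weights (sketch)."""

    def setup(self, context: TrainingContext):
        # "anchor_weights" is an illustrative key; a later strategy (for
        # example, a proximal loss) can read it back from context.state.
        context.state["anchor_weights"] = {
            name: param.detach().clone()
            for name, param in context.model.named_parameters()
        }
```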
Structured Evaluators and Trainer State
When [evaluation] is configured, ComposableTrainer.test_model(...) calls the
configured evaluator after the testing strategy finishes. The evaluator stores
its structured payload in TrainingContext.state, which is then consumed by the
server logger.
Important details:
summary metrics become evaluation_* CSV columns automatically;
detailed Lighteval task metrics are flattened into additional evaluation_*
columns;
when trainer.max_concurrency causes testing to run in a spawned subprocess,
Plato persists and restores the evaluator state so those metrics still reach
the parent server process.
Writing a Custom Strategy
As a complete example, the following loss strategy adds L2 regularization on top of a base criterion:

```python
import torch
import torch.nn as nn

from plato.trainers.strategies.base import LossCriterionStrategy, TrainingContext


class MyCustomLossStrategy(LossCriterionStrategy):
    """
    Custom loss strategy with L2 regularization.

    This strategy adds L2 regularization to the base loss.

    Args:
        weight: Regularization weight (default: 0.01)
        base_loss_fn: Base loss function (default: CrossEntropyLoss)

    Example:
        >>> strategy = MyCustomLossStrategy(weight=0.01)
        >>> trainer = ComposableTrainer(loss_strategy=strategy)
    """

    def __init__(self, weight=0.01, base_loss_fn=None):
        self.weight = weight
        self.base_loss_fn = base_loss_fn
        self._criterion = None

    def setup(self, context: TrainingContext):
        """Initialize the loss criterion."""
        if self.base_loss_fn is None:
            self._criterion = nn.CrossEntropyLoss()
        else:
            self._criterion = self.base_loss_fn

    def compute_loss(self, outputs, labels, context):
        """Compute the loss with L2 regularization."""
        # Base loss
        base_loss = self._criterion(outputs, labels)

        # L2 regularization over all model parameters
        l2_reg = 0.0
        for param in context.model.parameters():
            l2_reg += torch.norm(param, p=2)

        return base_loss + self.weight * l2_reg
```
Customizing the Training Loop Using Callbacks
For infrastructure changes, such as logging, recording metrics, and stopping the training loop early, we tend to customize the training loop using callbacks instead. The advantage of callbacks is that a list of them can be passed to the trainer when it is initialized, and they will be called in the order they appear in that list. This helps when related features need to be grouped into separate callback classes.
Within the implementation of these callback methods, one can access additional information about the training loop by using the trainer instance. For example, trainer.sampler can be used to access the sampler used by the train dataloader, trainer.trainloader can be used to access the current train dataloader, and trainer.current_epoch can be used to access the current epoch number.
To use callbacks, subclass the TrainerCallback class in plato.callbacks.trainer and override the methods below, then pass the callbacks to the trainer when it is initialized, or call trainer.add_callbacks after initialization. For built-in trainers whose initialization the user cannot access, trainer callbacks can also be passed to the client through the trainer_callbacks parameter, and they will be delivered to the trainer later. Examples can be found in examples/callbacks.
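For example (the LoggingCallback class name here is illustrative):

```python
import logging

from plato.callbacks.trainer import TrainerCallback
from plato.trainers.composable import ComposableTrainer


class LoggingCallback(TrainerCallback):
    def on_train_run_start(self, trainer, config):
        logging.info("[Client #%d] Training is starting.", trainer.client_id)


# Pass callbacks at initialization time...
trainer = ComposableTrainer(model=my_model_fn, callbacks=[LoggingCallback()])
# ...or register additional ones afterwards.
trainer.add_callbacks([LoggingCallback()])
```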
on_train_run_start()
def on_train_run_start(self, trainer, config)
Override this method to complete additional tasks before the training loop starts.
trainer the trainer instance that activated this callback upon the occurrence of the corresponding event.
config the configuration settings used in the training loop. It corresponds directly to the trainer section in the configuration file.
Example:
```python
def on_train_run_start(self, trainer, config):
    logging.info(
        "[Client #%d] Loading the dataset with size %d.",
        trainer.client_id,
        len(list(trainer.sampler)),
    )
```
on_train_run_end()
def on_train_run_end(self, trainer, config)
Override this method to complete additional tasks after the training loop ends.
trainer the trainer instance that activated this callback upon the occurrence of the corresponding event.
config the configuration settings used in the training loop. It corresponds directly to the trainer section in the configuration file.
Example:
```python
def on_train_run_end(self, trainer, config):
    logging.info("[Client #%d] Completed the training loop.", trainer.client_id)
```
on_train_epoch_start()
def on_train_epoch_start(self, trainer, config)
Override this method to complete additional tasks at the starting point of each training epoch.
trainer the trainer instance that activated this callback upon the occurrence of the corresponding event.
config the configuration settings used in the training loop. It corresponds directly to the trainer section in the configuration file.
Example:
```python
def on_train_epoch_start(self, trainer, config):
    logging.info(
        "[Client #%d] Started training epoch %d.",
        trainer.client_id,
        trainer.current_epoch,
    )
```
on_train_epoch_end()
def on_train_epoch_end(self, trainer, config)
Override this method to complete additional tasks at the end of each training epoch.
trainer the trainer instance that activated this callback upon the occurrence of the corresponding event.
config the configuration settings used in the training loop. It corresponds directly to the trainer section in the configuration file.
Example:
```python
def on_train_epoch_end(self, trainer, config):
    logging.info(
        "[Client #%d] Finished training epoch %d.",
        trainer.client_id,
        trainer.current_epoch,
    )
```
on_train_step_start()
def on_train_step_start(self, trainer, config, batch)
Override this method to complete additional tasks at the beginning of each step within a training epoch.
trainer the trainer instance that activated this callback upon the occurrence of the corresponding event.
config the configuration settings used in the training loop. It corresponds directly to the trainer section in the configuration file.
batch the index of the current batch of data that has just been processed in the current step.
Example:
```python
def on_train_step_start(self, trainer, config, batch):
    logging.info(
        "[Client #%d] Started training epoch %d batch %d.",
        trainer.client_id,
        trainer.current_epoch,
        batch,
    )
```
Accessing and Customizing the Run History During Training
An instance of the plato.trainers.tracking.RunHistory class, available as self.run_history, stores any number of performance metrics during the training process, keeping one iterable list of values for each metric. By default, it stores the average loss value of each epoch.
The run history in the trainer can also be accessed by the client, using self.trainer.run_history, and it can be read, updated, or reset in hook or callback methods. For example, in implementations of algorithms such as Oort, a per-step loss value needs to be stored by calling update_metric() in the train_step_end() hook.
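A minimal sketch of that pattern, assuming the basic.Trainer hook signature train_step_end(self, config, batch=None, loss=None); the metric name is illustrative:

```python
from plato.trainers import basic


class OortStyleTrainer(basic.Trainer):
    """Records every training step's loss in the run history (sketch)."""

    def train_step_end(self, config, batch=None, loss=None):
        super().train_step_end(config, batch=batch, loss=loss)
        # Append this step's loss under a metric name of our choosing.
        self.run_history.update_metric("train_step_loss", loss.item())
```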
Customizing the Training Loop Using Subclassing
When the strategy pattern is no longer a good fit, it is also possible to customize the training or testing procedure by subclassing and overriding hook methods. To customize the training loop this way, subclass the basic.Trainer class in plato.trainers and override the following hook methods: