
Customized Client Training Loops

SCAFFOLD

SCAFFOLD is a synchronous federated learning algorithm that performs server aggregation with control variates to better handle statistical heterogeneity. It is widely cited and frequently used as a baseline in the federated learning literature. In this example, two processors, ExtractControlVariatesProcessor and SendControlVariateProcessor, are attached to the client through a callback class, ScaffoldCallback; they exchange control variates between the clients and the server. Each client also persists its own control variate to a file so that it remains available across rounds of local optimization.

cd examples/customized_client_training
uv run scaffold/scaffold.py -c scaffold/scaffold_MNIST_lenet5.toml

Reference: Karimireddy et al., "SCAFFOLD: Stochastic Controlled Averaging for Federated Learning," in Proc. International Conference on Machine Learning (ICML), 2020.

Alignment with the paper

The callbacks wire $\Delta c_i$ through the payload exactly as Algorithm 1 prescribes: clients attach their control-variate deltas in examples/customized_client_training/scaffold/scaffold_callback.py:33-82, the server strips them off and averages $c \leftarrow c + \frac{1}{m} \sum_i \Delta c_i$ in examples/customized_client_training/scaffold/scaffold_server.py:34-53, and the updated server control variate is sent back in the next payload.

On each client round, plato/trainers/strategies/algorithms/scaffold_strategy.py:190-345 applies the correction $w \leftarrow w - \eta (g + c - c_i)$ after every optimizer step, and recomputes $c_i^{\text{new}} = c - (x_{\text{local}} - x_{\text{global}}) / (\eta \tau)$ before emitting $\Delta c_i$, mirroring the Option 2 formula that Karimireddy et al. (2020) derive for SCAFFOLD control variates.

Because the paper was released without official source code, the Plato example persists the same state transitions defined in Algorithm 1 via examples/customized_client_training/scaffold/scaffold_client.py:23-64, yielding the message flow (server sends $c$, clients return $\Delta c_i$) required for the theoretical convergence guarantees.
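The per-round arithmetic above can be sketched with plain scalars. This is a toy illustration of the update rules, not Plato's API; `grad_fn`, `lr`, and `num_steps` are hypothetical stand-ins for the local gradient oracle, the learning rate $\eta$, and the local step count $\tau$:

```python
# Scalar sketch of a SCAFFOLD client round (illustrative, not Plato code).

def scaffold_local_round(w_global, c, c_i, grad_fn, lr=0.1, num_steps=5):
    """Run tau corrected local steps, then recompute the client control
    variate and the delta shipped back to the server."""
    w = w_global
    for _ in range(num_steps):
        g = grad_fn(w)
        w = w - lr * (g + c - c_i)  # corrected step: w <- w - lr*(g + c - c_i)
    # Option-2-style recomputation used by the example:
    # c_i_new = c - (x_local - x_global) / (lr * tau)
    c_i_new = c - (w - w_global) / (lr * num_steps)
    delta_c_i = c_i_new - c_i  # attached to the payload by the callback
    return w, c_i_new, delta_c_i

def aggregate_control_variates(c, deltas):
    """Server side: c <- c + (1/m) * sum of client deltas."""
    return c + sum(deltas) / len(deltas)
```

With `c = c_i = 0` the corrected step reduces to plain SGD, which matches the paper's observation that SCAFFOLD generalizes FedAvg.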


FedProx

To better handle system heterogeneity, the FedProx algorithm adds a proximal term to the objective minimized during local training on the clients. It is widely cited and frequently used as a baseline in the federated learning literature.

cd examples/customized_client_training
uv run fedprox/fedprox.py -c fedprox/fedprox_MNIST_lenet5.toml

Reference: T. Li, A. K. Sahu, M. Zaheer, M. Sanjabi, A. Talwalkar, V. Smith. "Federated Optimization in Heterogeneous Networks," in Proc. Machine Learning and Systems (MLSys), 2020.

Alignment with the paper

plato/trainers/strategies/algorithms/fedprox_strategy.py:111-193 snapshots the global iterate $w^t$ at round start and augments the loss with $(\mu/2)\,\lVert w - w^t \rVert^2$, which is the FedProx objective $h_k(w; w^t) = F_k(w) + (\mu/2)\lVert w - w^t \rVert^2$ defined in Section 3 of Li et al. (2020). Autograd therefore produces the perturbed-gradient step without requiring a bespoke optimizer.

The config-aware wrapper FedProxLossStrategyFromConfig (plato/trainers/strategies/algorithms/fedprox_strategy.py:208-247) reads $\mu$ from the same knobs (clients.proximal_term_penalty_constant / algorithm.fedprox_mu) that the paper exposes in Algorithms 1 and 2, so experiments reproduce the authors' hyperparameter schedules.

The reference TensorFlow release (litian96/FedProx/flearn/optimizer/pgd.py#L27-L92) applies an identical perturbation, computing $g + \mu (w - w^t)$ before the gradient step; Plato mirrors that logic in PyTorch by letting the proximal penalty backpropagate through the loss term, yielding a line-for-line correspondence with Perturbed Gradient Descent.
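In scalar form, differentiating the proximal term by hand shows why autograd and the reference pgd optimizer coincide. A minimal sketch, assuming a one-dimensional parameter; `f` and `grad_f` are illustrative stand-ins for the local loss $F_k$ and its gradient:

```python
# Scalar sketch of the FedProx objective and its perturbed gradient.

def fedprox_loss(w, w_t, f, mu):
    """h_k(w; w_t) = F_k(w) + (mu/2) * (w - w_t)**2."""
    return f(w) + 0.5 * mu * (w - w_t) ** 2

def fedprox_grad(w, w_t, grad_f, mu):
    # Differentiating the proximal term yields g + mu * (w - w_t),
    # exactly the perturbation applied by the reference pgd optimizer.
    return grad_f(w) + mu * (w - w_t)

def fedprox_step(w, w_t, grad_f, mu, lr=0.1):
    """One Perturbed Gradient Descent step."""
    return w - lr * fedprox_grad(w, w_t, grad_f, mu)
```

Setting `mu = 0` recovers plain SGD, which is the FedAvg special case noted in the paper.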


FedDyn

FedDyn dynamically updates each participating device's regularizer in every round of training, providing communication savings while addressing data heterogeneity in federated learning.

cd examples/customized_client_training
uv run feddyn/feddyn.py -c feddyn/feddyn_MNIST_lenet5.toml

Reference: Acar, D.A.E., Zhao, Y., Navarro, R.M., Mattina, M., Whatmough, P.N. and Saligrama, V. "Federated learning based on dynamic regularization," Proceedings of International Conference on Learning Representations (ICLR), 2021.

Alignment with the paper

The loss strategy plato/trainers/strategies/algorithms/feddyn_strategy.py:148-205 evaluates $L_k(w) + \alpha \langle w,\, -w_{\text{global}} + h_k \rangle + (\alpha/2) \lVert w - w_{\text{global}} \rVert^2$, exactly the dynamic-regularization objective introduced in Section 3 of Acar et al. (2021), with _get_alpha_coefficient reproducing the client-weighted scaling discussed beneath that formulation.

FedDynUpdateStrategy.on_train_end (plato/trainers/strategies/algorithms/feddyn_strategy.py:284-317) updates the cumulative gradient state via $h_k \leftarrow h_k + (w_k - w_{\text{global}})$ before persisting it, which is the recursion that Algorithm 1 relies on to couple successive local solutions.

The authors' PyTorch implementation (alpemreacar/FedDyn/utils_methods.py#L286-L399) performs the same bookkeeping: after every client run it accumulates curr_model_par - cld_mdl_param into local_param_list and averages the corrected weights, demonstrating that Plato's composable trainer follows the released reference code step for step.
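The objective and the state recursion above can be sketched with scalars (a toy illustration with one-dimensional parameters, not Plato code; `local_loss` is a hypothetical stand-in for $L_k$):

```python
# Scalar sketch of the FedDyn dynamic-regularization objective and state update.

def feddyn_loss(w, w_global, h_k, local_loss, alpha):
    """L_k(w) + alpha * <w, -w_global + h_k> + (alpha/2) * ||w - w_global||^2,
    with the inner product reduced to a scalar multiplication."""
    inner = w * (-w_global + h_k)
    return local_loss(w) + alpha * inner + 0.5 * alpha * (w - w_global) ** 2

def feddyn_update_state(h_k, w_k, w_global):
    """Recursion applied in on_train_end: h_k <- h_k + (w_k - w_global)."""
    return h_k + (w_k - w_global)
```

With `alpha = 0` the regularizer vanishes and the objective reduces to the plain local loss, matching the FedAvg special case.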


MOON

MOON (Model-Contrastive Federated Learning) enhances standard FedAvg by adding a model-level contrastive regularizer. Each client augments the shared model with a projection head, clones the incoming global model as a positive anchor, and reuses a small buffer of its historical checkpoints as negatives. The server still performs sample-weighted averaging but records a short history of global states for downstream analysis or warm restarts.

cd examples/server_aggregation/moon/
uv run moon.py -c moon_MNIST_lenet5.toml

Key configuration parameters:

  • algorithm.mu: Weight assigned to the contrastive term (default: 5.0).
  • algorithm.temperature: Softmax temperature applied to cosine similarities (default: 0.5).
  • algorithm.history_size: Number of historical local models cached per client as negatives (default: 2).
  • trainer.model_name: Name used for checkpointing the projection-ready backbone (default: moon_lenet5).
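Assembled from the defaults listed above, the relevant fragment of a configuration such as moon_MNIST_lenet5.toml might look like the following (an illustrative excerpt, not the full config file):

```toml
[algorithm]
mu = 5.0            # weight of the model-contrastive term
temperature = 0.5   # softmax temperature over cosine similarities
history_size = 2    # past local checkpoints cached as negatives

[trainer]
model_name = "moon_lenet5"  # projection-ready backbone used for checkpointing
```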

Reference: Qinbin Li, Bingsheng He, Dawn Song. “Model-Contrastive Federated Learning,” in Proc. CVPR, 2021.

Alignment with the paper

Here’s how Plato's implementation lines up with Li et al. (CVPR 2021) and the authors’ reference implementation:

  • Projection head & representations – moon_model.py:31-79 implements the LeNet-style backbone plus a two-layer projection head, returning both logits and L2-normalised embeddings. The paper’s Eq. (3) (and typical contrastive-learning practice) calls for that projection step; the public repo’s simple CNN head even hints at it (they keep the projection MLP commented out). So keeping the projection in our model is faithful and helps the cosine similarities stay well behaved.

  • Local training objective – moon_trainer.py:26-152 combines the supervised cross-entropy with the temperature-scaled contrastive loss exactly like Eq. (1): positives come from the frozen global model, negatives from the stored local-history models, using the same $\mu$ and $\tau$ hyper-parameters exposed in the config (moon_MNIST_lenet5.toml:41-45). This mirrors train_net_fedcon in the reference implementation, which also weights the contrastive term by $\mu$ and uses CrossEntropy on logits built from cosine similarities.

  • Historical model buffer – the client keeps a FIFO queue of past local checkpoints (moon_client.py:21-64), equivalent to model_buffer_size in the paper and the author's reference implementation; that buffer is fed into the trainer through the strategy context so MOON always has negatives available.

  • Server aggregation – the server still performs sample-weighted FedAvg (moon_server.py:12-35, moon_server_strategy.py:19-63), matching the MOON design which leaves the aggregation rule unchanged. The extra global-history deque is bookkeeping-only.

  • Shared architecture – moon.py:8-15 now instantiates MoonModel once and passes it into both the client and server (model=model). That guarantees the projection-enabled architecture is shared exactly, as required for the contrastive comparisons.

The only intentional deviation is that we L2-normalise the projection outputs before computing cosine similarities (moon_model.py:76-79), which the paper assumes implicitly and improves stability. Aside from that, the workflow, hyper-parameters, and loss all line up with the CVPR paper and the publicly released PyTorch reference.
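The model-contrastive term can be sketched with plain lists as embeddings (a stdlib-only illustration of the temperature-scaled loss over cosine similarities, not the Plato trainer code):

```python
import math

def _cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def moon_contrastive_loss(z_local, z_global, z_negatives, temperature=0.5):
    """-log( exp(sim(z, z_glob)/T) /
             (exp(sim(z, z_glob)/T) + sum_prev exp(sim(z, z_prev)/T)) ),
    i.e. cross-entropy over cosine-similarity logits with the global
    embedding as the positive and historical embeddings as negatives."""
    pos = math.exp(_cosine(z_local, z_global) / temperature)
    negs = sum(math.exp(_cosine(z_local, z) / temperature)
               for z in z_negatives)
    return -math.log(pos / (pos + negs))
```

The loss shrinks as the local embedding aligns with the global one and grows as it drifts toward a historical checkpoint, which is exactly the drift-penalizing behavior MOON relies on.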


FedMoS

FedMoS is a communication-efficient federated learning framework that couples a double momentum-based update with adaptive client selection to jointly mitigate client drift and the variance it introduces.

cd examples/customized_client_training
uv run fedmos/fedmos.py -c fedmos/fedmos_MNIST_lenet5.toml

Reference: X. Wang, Y. Chen, Y. Li, X. Liao, H. Jin and B. Li, "FedMoS: Taming Client Drift in Federated Learning with Double Momentum and Adaptive Selection," IEEE INFOCOM 2023.

Alignment with the paper

plato/trainers/strategies/algorithms/fedmos_strategy.py:104-205 implements the FedMoS double-momentum update by first computing $d_t = g_t + (1 - a)(d_{t-1} - g_{t-1})$ and then stepping $w \leftarrow (1 - \mu) w - \eta d_t + \mu w_{\text{global}}$; these are the same recursions described in Algorithm 1 of Wang et al. (2023).

The training loop enforces the paper's sequencing: FedMosStepStrategy.training_step (plato/trainers/strategies/algorithms/fedmos_strategy.py:487-538) calls update_momentum() immediately after backward() and passes the cached global model from FedMosUpdateStrategy.on_train_start (plato/trainers/strategies/algorithms/fedmos_strategy.py:329-347) into the optimizer step so the proximal pull uses the broadcast parameters from the server.

The official repository (Distributed-Learning-Networking-Group/FedMoS/optimizers/fedoptimizer.py#L27-L92) mirrors the same gradient-difference momentum and proximal correction, confirming the one-to-one correspondence between the Plato optimizer and the authors' release.
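The two recursions reduce to a few lines in scalar form (an illustrative sketch, not the Plato optimizer; `a`, `mu`, and `lr` stand in for the momentum coefficient, proximal weight $\mu$, and learning rate $\eta$):

```python
# Scalar sketch of the FedMoS double-momentum update and proximal step.

def fedmos_momentum(g_t, g_prev, d_prev, a):
    """d_t = g_t + (1 - a) * (d_{t-1} - g_{t-1}):
    a gradient-difference (STORM-style) momentum recursion."""
    return g_t + (1.0 - a) * (d_prev - g_prev)

def fedmos_step(w, d_t, w_global, lr, mu):
    """w <- (1 - mu) * w - lr * d_t + mu * w_global:
    momentum descent plus a proximal pull toward the broadcast model."""
    return (1.0 - mu) * w - lr * d_t + mu * w_global
```

With `a = 1` the momentum collapses to the raw gradient, and with `mu = 0` the step collapses to plain SGD, so both knobs interpolate between FedMoS and FedAvg-style local training.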