Split Learning Algorithms
Split Learning
Split learning aims to collaboratively train deep learning models, with the server performing a portion of the training process. In split learning, each training iteration is separated into two phases: the clients first send features extracted at a specific cut layer to the server, and the server then continues the forward pass and computes the gradients, which are sent back to the clients to complete the backward pass. Unlike federated learning, split learning clients interact with the server sequentially, and the global model is synchronized implicitly through the server-side model, which is shared and updated by all clients.
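The following is a minimal sketch of a single split learning iteration in plain PyTorch, not Plato's implementation; the two small models, the cut point, and the dummy batch are illustrative assumptions. It shows the client forwarding features up to the cut layer, the server completing the forward pass and backpropagating to the cut layer, and the cut-layer gradient being returned to the client to finish the backward pass.

```python
import torch
import torch.nn as nn

# Illustrative models: the client holds the layers up to the cut layer,
# and the server holds the remaining layers.
client_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU())
server_model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))

client_opt = torch.optim.SGD(client_model.parameters(), lr=0.01)
server_opt = torch.optim.SGD(server_model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# A dummy batch standing in for the client's local data.
x = torch.randn(32, 1, 28, 28)
y = torch.randint(0, 10, (32,))

# Phase 1: the client runs the forward pass up to the cut layer and sends
# the detached features (plus labels) to the server.
features = client_model(x)
features_sent = features.detach().requires_grad_()

# Phase 2: the server completes the forward pass, computes the loss, and
# backpropagates down to the cut layer, updating its own parameters.
server_opt.zero_grad()
logits = server_model(features_sent)
loss = loss_fn(logits, y)
loss.backward()
server_opt.step()

# The gradient at the cut layer is sent back to the client, which finishes
# the backward pass through its own layers and updates its parameters.
client_opt.zero_grad()
features.backward(features_sent.grad)
client_opt.step()
```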
To run this example, training ResNet-18 on CIFAR-10:

uv run plato.py -c configs/CIFAR10/split_learning_resnet18.toml
Reference: Vepakomma et al., "Split Learning for Health: Distributed Deep Learning without Sharing Raw Patient Data," in Proc. NeurIPS, 2018.
Split Learning for Training LLMs
This is an example of fine-tuning a Hugging Face large language model with split learning. One can fine-tune the entire model, or fine-tune it in a parameter-efficient fashion with the LoRA algorithm. The cut layer in the configuration file should be set to an integer, indicating the transformer block at which the model is cut.
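As a minimal sketch of what an integer cut layer means for a transformer model (not Plato's implementation), the example below splits the transformer blocks of a Hugging Face GPT-2 model at an assumed index, so the client would run the blocks before the cut and the server would run the rest.

```python
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
cut_layer = 6  # illustrative value; GPT-2 base has 12 transformer blocks

# model.transformer.h is the list of GPT-2 transformer blocks.
blocks = model.transformer.h
client_blocks = blocks[:cut_layer]  # executed on the client
server_blocks = blocks[cut_layer:]  # executed on the server, with the LM head

print(f"client holds {len(client_blocks)} blocks, "
      f"server holds {len(server_blocks)} blocks")
```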
To fine-tune the entire model:
cd examples/split_learning/llm_split_learning
uv run split_learning_main.py -c split_learning_wikitext2_gpt2.toml
To fine-tune with LoRA:
cd examples/split_learning/llm_split_learning
uv run split_learning_main.py -c split_learning_wikitext2_gpt2_lora.toml
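For the LoRA case, the sketch below shows one way a GPT-2 model can be wrapped with LoRA adapters using the Hugging Face peft library, so that only the low-rank adapter weights are trained; the rank, scaling factor, and target modules here are illustrative assumptions rather than the values used in the provided configuration file.

```python
from peft import LoraConfig, get_peft_model
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")

lora_config = LoraConfig(
    r=8,                        # rank of the low-rank update matrices
    lora_alpha=16,              # scaling factor applied to the LoRA updates
    target_modules=["c_attn"],  # GPT-2 attention projection to adapt
    lora_dropout=0.05,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA parameters are trainable
```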