Servers

Customizing servers using inheritance

The common practice is to customize the server using inheritance for important features that change the state of the server. To customize the server using inheritance, subclass the fedavg.Server class in plato.servers (or fedavg_cs.Server for cross-silo federated learning) and override any of the following methods as needed:

configure()

def configure(self) -> None

Override this method to implement additional tasks for initializing and configuring the server. Make sure that super().configure() is called first.

Example:

from plato.config import Config


def configure(self) -> None:
    """Configure the server, then record the total number of training rounds."""
    super().configure()

    self.total_rounds = Config().trainer.rounds

init_trainer()

def init_trainer(self) -> None

Override this method to implement additional tasks for initializing and configuring the trainer. Make sure that super().init_trainer() is called first.

Example (from examples/knot/knot_server.py):

def init_trainer(self) -> None:
    """Load the trainer and initialize the dictionary that maps cluster IDs to client IDs."""
    super().init_trainer()

    self.algorithm.init_clusters(self.clusters)

choose_clients()

def choose_clients(self, clients_pool, clients_count)

Override this method to implement a customized client selection algorithm, choosing a subset of clients from the client pool.

clients_pool: a list of available clients for selection.

clients_count: the number of clients that need to be selected in this round.

Returns: a list of selected client IDs.
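As an illustration, the following sketch replaces random selection with a round-robin policy that walks through the pool in order, wrapping around at the end. It is written as a plain class so that it runs on its own; in Plato the method would live in a subclass of fedavg.Server. The RoundRobinServer name and its selection_round counter are hypothetical, not part of Plato.

```python
class RoundRobinServer:
    """Hypothetical stand-in; in Plato this class would subclass fedavg.Server."""

    def __init__(self):
        # Hypothetical counter of how many selections have been made so far.
        self.selection_round = 0

    def choose_clients(self, clients_pool, clients_count):
        """Select clients_count clients in round-robin order from clients_pool."""
        assert clients_count <= len(clients_pool), "Not enough clients to select."
        pool_size = len(clients_pool)
        start = (self.selection_round * clients_count) % pool_size
        self.selection_round += 1
        return [clients_pool[(start + i) % pool_size] for i in range(clients_count)]


server = RoundRobinServer()
print(server.choose_clients([1, 2, 3, 4, 5], 2))  # [1, 2]
print(server.choose_clients([1, 2, 3, 4, 5], 2))  # [3, 4]
print(server.choose_clients([1, 2, 3, 4, 5], 2))  # [5, 1]
```

Unlike random sampling, this policy guarantees that every client is selected once per pass through the pool.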

weights_received()

def weights_received(self, weights_received)

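Override this method to complete additional tasks after the updated weights have been received from the clients, such as extracting extra information that the clients sent along with their weights.

weights_received: the updated weights that have been received from the clients.

Returns: the weights to be used for aggregation.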

Example:

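def weights_received(self, weights_received):
    """
    Event called after the updated weights have been received.
    """
    # Each client is assumed to have sent a (weights, control variate) pair.
    self.control_variates_received = [weight[1] for weight in weights_received]
    # Return only the model weights, which are then aggregated.
    return [weight[0] for weight in weights_received]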

aggregate_deltas()

async def aggregate_deltas(self, updates, deltas_received)

In most cases, it is more convenient to aggregate the model deltas from the clients, because this can be performed in a framework-agnostic fashion. Override this method to aggregate the deltas received. This method is needed if aggregate_weights() (below) is not defined.

updates: the client updates received at the server.

deltas_received: the weight deltas received from the clients.
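A minimal sketch of federated averaging over deltas, weighting each client's delta by its number of training samples. It assumes that each entry of updates exposes report.num_samples and that each delta is a dictionary from parameter names to values; plain floats stand in for framework tensors so the sketch runs without PyTorch or Plato. It is shown as a module-level function, whereas the override takes self as its first parameter.

```python
import asyncio
from types import SimpleNamespace


async def aggregate_deltas(updates, deltas_received):
    """Weighted average of client deltas by number of samples (sketch)."""
    total_samples = sum(update.report.num_samples for update in updates)

    # Start from zeros with the same parameter names as the first delta.
    avg_update = {name: 0.0 for name in deltas_received[0]}

    for update, delta in zip(updates, deltas_received):
        weight = update.report.num_samples / total_samples
        for name, value in delta.items():
            avg_update[name] += value * weight
        # Yield to other tasks running on the server's event loop.
        await asyncio.sleep(0)

    return avg_update


updates = [
    SimpleNamespace(report=SimpleNamespace(num_samples=100)),
    SimpleNamespace(report=SimpleNamespace(num_samples=300)),
]
deltas = [{"w": 1.0, "b": 0.0}, {"w": 3.0, "b": 4.0}]
print(asyncio.run(aggregate_deltas(updates, deltas)))  # {'w': 2.5, 'b': 3.0}
```

Because only additions and scalar multiplications are applied to each delta, the same loop works for any tensor type that supports those operations.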

aggregate_weights()

async def aggregate_weights(self, updates, baseline_weights, weights_received)

Sometimes it is more convenient to aggregate the received model weights directly into the global model. In this case, override this method to aggregate the received weights directly into the baseline weights. This method is optional: when it is defined, the server calls it instead of aggregate_deltas(). Refer to examples/fedasync/fedasync_server.py for an example.

updates: the client updates received at the server.

baseline_weights: the current weights in the global model.

weights_received: the weights received from the clients.
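For instance, an asynchronous server might average the received weights by sample count and then interpolate between the current global model and that average. The sketch below is not the code in examples/fedasync/fedasync_server.py: the mixing parameter is a hypothetical constant, each update is assumed to expose report.num_samples, and weights are plain dictionaries of floats so that it runs without Plato.

```python
import asyncio
from types import SimpleNamespace


async def aggregate_weights(updates, baseline_weights, weights_received, mixing=0.5):
    """Mix the sample-weighted average of received weights into the baseline (sketch)."""
    total_samples = sum(update.report.num_samples for update in updates)

    averaged = {name: 0.0 for name in baseline_weights}
    for update, weights in zip(updates, weights_received):
        share = update.report.num_samples / total_samples
        for name, value in weights.items():
            averaged[name] += value * share
        # Yield to other tasks running on the server's event loop.
        await asyncio.sleep(0)

    # Interpolate between the current global model and the averaged client model.
    return {
        name: (1 - mixing) * baseline_weights[name] + mixing * averaged[name]
        for name in baseline_weights
    }


updates = [SimpleNamespace(report=SimpleNamespace(num_samples=50))]
baseline = {"w": 0.0}
received = [{"w": 2.0}]
print(asyncio.run(aggregate_weights(updates, baseline, received)))  # {'w': 1.0}
```

A mixing value of 1.0 discards the baseline and reduces to plain weighted averaging; smaller values make the global model change more slowly.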

weights_aggregated()

def weights_aggregated(self, updates)

Override this method to complete additional tasks after aggregating weights.

updates: the client updates received at the server.
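A minimal sketch that records how many client updates were aggregated in each round, for example to inspect participation after training. The aggregation_history attribute is hypothetical, and the plain class stands in for a subclass of fedavg.Server so that it runs on its own.

```python
class ParticipationTrackingServer:
    """Hypothetical stand-in; in Plato this would subclass fedavg.Server."""

    def __init__(self):
        self.aggregation_history = []  # hypothetical: one entry per round

    def weights_aggregated(self, updates):
        """Record the number of client updates aggregated in this round."""
        self.aggregation_history.append(len(updates))


server = ParticipationTrackingServer()
server.weights_aggregated(["update_a", "update_b"])
server.weights_aggregated(["update_c"])
print(server.aggregation_history)  # [2, 1]
```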

customize_server_response()

def customize_server_response(self, server_response: dict, client_id) -> dict

Override this method to return a customized server response containing any additional information.

server_response: key-value pairs (from a string to an instance) for the server response before customization.

client_id: the client ID.

Returns: the customized server response.

Example:

def customize_server_response(self, server_response: dict, client_id) -> dict:
    """
    Customizes the server response with any additional information.
    """
    server_response["pruning_amount"] = self.pruning_amount_list
    return server_response
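Because the method receives client_id, the response can also be tailored per client. In this sketch, the client_learning_rates mapping, the default_learning_rate value, and the "learning_rate" key are all hypothetical; a matching client-side customization would be needed to read the extra key.

```python
class PerClientResponseServer:
    """Hypothetical stand-in; in Plato this would subclass fedavg.Server."""

    def __init__(self, client_learning_rates, default_learning_rate=0.01):
        self.client_learning_rates = client_learning_rates
        self.default_learning_rate = default_learning_rate

    def customize_server_response(self, server_response: dict, client_id) -> dict:
        """Attach a per-client learning rate to the server response."""
        server_response["learning_rate"] = self.client_learning_rates.get(
            client_id, self.default_learning_rate
        )
        return server_response


server = PerClientResponseServer({1: 0.1})
print(server.customize_server_response({}, 1))  # {'learning_rate': 0.1}
print(server.customize_server_response({}, 2))  # {'learning_rate': 0.01}
```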

customize_server_payload()

def customize_server_payload(self, payload)

Override this method to customize the server payload before sending it to the clients.

Returns: Customized server payload to be sent to the clients.
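One way to use this hook, mirroring the weights_received() example above (where each client sends a pair), is to bundle extra server state with the model weights. The sketch assumes a hypothetical server_control_variate attribute; clients would have to be customized correspondingly to unpack the pair.

```python
class BundlingServer:
    """Hypothetical stand-in; in Plato this would subclass fedavg.Server."""

    def __init__(self):
        # Hypothetical extra state to send alongside the model weights.
        self.server_control_variate = {"w": 0.0}

    def customize_server_payload(self, payload):
        """Send the model weights together with the server's extra state."""
        return [payload, self.server_control_variate]


server = BundlingServer()
print(server.customize_server_payload({"w": 1.5}))  # [{'w': 1.5}, {'w': 0.0}]
```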

clients_selected()

def clients_selected(self, selected_clients) -> None

Override this method to complete additional tasks after clients have been selected in each round.

selected_clients: a list of client IDs that have just been selected by the server.
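A minimal sketch that counts how often each client has been selected, using collections.Counter. The selection_counts attribute is hypothetical, and the plain class stands in for a subclass of fedavg.Server.

```python
from collections import Counter


class SelectionCountingServer:
    """Hypothetical stand-in; in Plato this would subclass fedavg.Server."""

    def __init__(self):
        self.selection_counts = Counter()  # hypothetical: client ID -> times selected

    def clients_selected(self, selected_clients) -> None:
        """Update per-client selection counts after each round's selection."""
        self.selection_counts.update(selected_clients)


server = SelectionCountingServer()
server.clients_selected([1, 2])
server.clients_selected([2, 3])
print(server.selection_counts[2])  # 2
```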

clients_processed()

def clients_processed(self) -> None

Override this method to complete additional tasks after all client reports have been processed.

get_logged_items()

def get_logged_items(self) -> dict

Override this method to return items to be logged by the LogProgressCallback class in a .csv file.

Returns: a dictionary of items to be logged.

Example (from examples/knot/knot_server.py):

def get_logged_items(self):
    """Get items to be logged by the LogProgressCallback class in a .csv file."""
    logged_items = super().get_logged_items()

    clusters_accuracy = [
        self.clustered_test_accuracy[cluster_id]
        for cluster_id in range(self.num_clusters)
    ]

    clusters_accuracy = "; ".join([str(acc) for acc in clusters_accuracy])

    logged_items["clusters_accuracy"] = clusters_accuracy

    return logged_items

should_request_update()

def should_request_update(self, client_id, start_time, finish_time, client_staleness, report)

Override this method to decide whether the server should explicitly request a model update from a client, for example when that client has become too stale.

client_id: the ID of the client being considered.

start_time: the wall-clock time when the client started training.

finish_time: the wall-clock time when the client finished training.

client_staleness: the number of rounds that have elapsed since this client started training.

report: the report sent by the client.

Returns: True if the server should explicitly request an update from the client client_id; False otherwise.

Example (from servers/base.py):

def should_request_update(
    self, client_id, start_time, finish_time, client_staleness, report
):
    """Determines if an explicit request for model update should be sent to the client."""
    return client_staleness > self.staleness_bound and finish_time > self.wall_time
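A variant policy might request an update when a client is either too stale or has spent longer than a time budget on training. In this sketch the max_training_time attribute is hypothetical, the choice of combining the two conditions with "or" is purely illustrative, and the stand-in class only mirrors the method signature above.

```python
class DeadlineAwareServer:
    """Hypothetical stand-in; in Plato this would subclass a server class."""

    def __init__(self, staleness_bound, max_training_time):
        self.staleness_bound = staleness_bound
        self.max_training_time = max_training_time  # hypothetical time budget (seconds)

    def should_request_update(
        self, client_id, start_time, finish_time, client_staleness, report
    ):
        """Request an update if the client is too stale or has trained too long."""
        training_time = finish_time - start_time
        return (
            client_staleness > self.staleness_bound
            or training_time > self.max_training_time
        )


server = DeadlineAwareServer(staleness_bound=3, max_training_time=60.0)
print(server.should_request_update(1, 0.0, 30.0, 1, None))  # False
print(server.should_request_update(1, 0.0, 90.0, 1, None))  # True
print(server.should_request_update(1, 0.0, 30.0, 5, None))  # True
```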

save_to_checkpoint()

def save_to_checkpoint(self) -> None

Override this method to save additional information when the server saves checkpoints at the end of each round.
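A minimal sketch that persists extra server state with pickle. The checkpoint_dir attribute, the extra_server_state.pkl file name, and the selection_counts state are hypothetical; a real override would most likely also call super().save_to_checkpoint() so that the standard checkpoint is still written.

```python
import os
import pickle
import tempfile


class CheckpointingServer:
    """Hypothetical stand-in; in Plato this would subclass fedavg.Server."""

    def __init__(self, checkpoint_dir):
        self.checkpoint_dir = checkpoint_dir  # hypothetical location
        self.selection_counts = {}  # hypothetical extra state

    def save_to_checkpoint(self) -> None:
        """Save extra server state alongside the regular checkpoint."""
        # A real override would most likely call super().save_to_checkpoint() here.
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        path = os.path.join(self.checkpoint_dir, "extra_server_state.pkl")
        with open(path, "wb") as state_file:
            pickle.dump(self.selection_counts, state_file)


with tempfile.TemporaryDirectory() as tmp_dir:
    server = CheckpointingServer(tmp_dir)
    server.selection_counts = {1: 3}
    server.save_to_checkpoint()
    with open(os.path.join(tmp_dir, "extra_server_state.pkl"), "rb") as state_file:
        restored = pickle.load(state_file)
print(restored)  # {1: 3}
```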

training_will_start()

def training_will_start(self) -> None

Override this method to complete additional tasks before selecting clients for the first round of training.

periodic_task()

async def periodic_task(self) -> None

Override this async method to perform periodic tasks in asynchronous mode; the server calls it periodically.

wrap_up()

async def wrap_up(self) -> None

Override this method to complete additional tasks at the end of each round.

server_will_close()

def server_will_close(self) -> None

Override this method to complete additional tasks before closing the server.

Customizing servers using callbacks

For infrastructure changes, such as logging and recording metrics, we tend to customize the global training process using callbacks instead. The advantage of using callbacks is that one can pass a list of multiple callbacks to the server when it is initialized, and they will be called in the order in which they appear in that list. This helps when it is useful to group related features into separate callback classes.

Within the implementation of these callback methods, one can access additional information about the global training by using the server instance.

To use callbacks, subclass the ServerCallback class in plato.callbacks.server and override any of the following methods. Then pass the callback to the server when it is initialized, or register it afterwards by calling server.add_callbacks(). Examples can be found in examples/callbacks.
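The sketch below illustrates the ordering behavior described above with two callbacks. Since it must run without Plato, it defines a hypothetical MiniServer that dispatches on_clients_selected() to each registered callback in list order; it is not Plato's callback handler, and its add_callbacks() only mimics the name. With Plato, the callback classes would subclass ServerCallback from plato.callbacks.server instead of object.

```python
class RecordingCallback:
    """Hypothetical callback; with Plato it would subclass ServerCallback."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def on_clients_selected(self, server, selected_clients):
        self.log.append((self.name, list(selected_clients)))


class MiniServer:
    """Hypothetical stand-in that dispatches events to callbacks in list order."""

    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])

    def add_callbacks(self, callbacks):
        self.callbacks.extend(callbacks)

    def select(self, selected_clients):
        for callback in self.callbacks:
            callback.on_clients_selected(self, selected_clients)


log = []
server = MiniServer(callbacks=[RecordingCallback("first", log)])
server.add_callbacks([RecordingCallback("second", log)])
server.select([4, 7])
print(log)  # [('first', [4, 7]), ('second', [4, 7])]
```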

on_weights_received()

def on_weights_received(self, server, weights_received)

Override this method to complete additional tasks after the updated weights have been received.

weights_received: the updated weights that have been received from the clients.

on_weights_aggregated()

def on_weights_aggregated(self, server, updates)

Override this method to complete additional tasks after aggregating weights.

updates: the client updates received at the server.

Example:

def on_weights_aggregated(self, server, updates):
    logging.info("[Server #%s] Finished aggregating weights.", os.getpid())

on_clients_selected()

def on_clients_selected(self, server, selected_clients)

Override this method to complete additional tasks after clients have been selected in each round.

selected_clients: a list of client IDs that have just been selected by the server.

on_clients_processed()

def on_clients_processed(self, server)

Override this method to complete additional tasks after all client reports have been processed.

on_training_will_start()

def on_training_will_start(self, server)

Override this method to complete additional tasks before selecting clients for the first round of training.

on_server_will_close()

def on_server_will_close(self, server)

Override this method to complete additional tasks before closing the server.

Example:

def on_server_will_close(self, server):
    logging.info("[Server #%s] Closing the server.", os.getpid())
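Callbacks can also keep state across events. This sketch measures the elapsed wall-clock time between on_training_will_start() and on_server_will_close(). The TimingCallback name is hypothetical, and it subclasses object so that it runs without Plato; with Plato it would subclass ServerCallback from plato.callbacks.server.

```python
import logging
import os
import time


class TimingCallback:
    """Hypothetical callback; with Plato it would subclass ServerCallback."""

    def __init__(self):
        self.start_time = None
        self.elapsed = None

    def on_training_will_start(self, server):
        # Record a monotonic timestamp before the first round begins.
        self.start_time = time.perf_counter()

    def on_server_will_close(self, server):
        self.elapsed = time.perf_counter() - self.start_time
        logging.info(
            "[Server #%s] Training took %.2f seconds.", os.getpid(), self.elapsed
        )


callback = TimingCallback()
callback.on_training_will_start(server=None)
callback.on_server_will_close(server=None)
```

time.perf_counter() is used rather than time.time() because it is monotonic and therefore unaffected by system clock adjustments.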