nvflare.app_common.workflows.scaffold module

class Scaffold(*args, num_clients: int = 3, num_rounds: int = 5, start_round: int = 0, memory_gc_rounds: int = 0, **kwargs)[source]

Bases: BaseFedAvg

Controller for Scaffold Workflow. Note: This class is based on ModelController. Implements [SCAFFOLD](https://proceedings.mlr.press/v119/karimireddy20a.html).

Provides the implementations for the run routine, controlling the main workflow:
  • def run(self)

The parent classes provide the default implementations for other routines.

Parameters:
  • num_clients (int, optional) – The number of clients. Defaults to 3.

  • num_rounds (int, optional) – The total number of training rounds. Defaults to 5.

  • persistor_id (str, optional) – ID of the persistor component. Defaults to “persistor”.

  • ignore_result_error (bool or None, optional) – How to handle client result errors:
      - None: Dynamic mode (default). Ignore errors if min_responses is still reachable; panic otherwise.
      - False: Strict mode. Panic on any client error.
      - True: Resilient mode. Always ignore client errors.

  • allow_empty_global_weights (bool, optional) – Whether to allow empty global weights. Some pipelines have empty global weights in the first round, so that clients start training from scratch without any global information. Defaults to False.

  • memory_gc_rounds (int, optional) – Run memory cleanup (gc.collect + malloc_trim) every N rounds. Set to 0 to disable. Defaults to 0 (inherited from BaseFedAvg).
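The three ignore_result_error modes can be read as a small decision rule. The sketch below is illustrative only and is not NVFlare's implementation: the function name should_ignore_client_error and its arguments num_targets, num_failed, and min_responses are hypothetical, and it assumes that "still reachable" means the number of clients that have not failed is at least min_responses.

```python
from typing import Optional


def should_ignore_client_error(
    ignore_result_error: Optional[bool],
    num_targets: int,
    num_failed: int,
    min_responses: int,
) -> bool:
    """Return True to ignore a client error, False to panic (hypothetical helper)."""
    if ignore_result_error is True:
        # Resilient mode: always ignore client errors.
        return True
    if ignore_result_error is False:
        # Strict mode: panic on any client error.
        return False
    # Dynamic mode (None): ignore only while enough clients can still respond.
    # Assumption: "reachable" means non-failed targets >= min_responses.
    return (num_targets - num_failed) >= min_responses
```

For example, with 5 targeted clients and min_responses of 3, dynamic mode under this assumption tolerates up to two failed clients and panics on the third.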

The base controller for FedAvg Workflow. Note: This class is based on the ModelController.

Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).

A model persistor can be configured via the persistor_id argument of the ModelController. The model persistor loads the initial global model, which is sent to a list of clients. Each client sends its updated weights after local training, and these are aggregated. Next, the global model is updated. The model persistor also saves the model after training.
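The aggregation step described above can be sketched as a weighted average. This is a minimal pure-Python illustration, not the library's aggregate routine: weighted_average and the Weights alias are hypothetical names, and each client is weighted by a caller-supplied factor such as its number of local training samples (the weighting used in the FederatedAveraging paper).

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for model parameters: layer name -> flat list of values.
Weights = Dict[str, List[float]]


def weighted_average(updates: List[Tuple[Weights, float]]) -> Weights:
    """Average client weights, each scaled by its aggregation weight."""
    if not updates:
        raise ValueError("no client updates to aggregate")
    total = sum(w for _, w in updates)
    if total <= 0:
        raise ValueError("aggregation weights must sum to a positive value")
    first_params = updates[0][0]
    result: Weights = {}
    for key, values in first_params.items():
        # Weighted sum over clients for each element, then normalize.
        result[key] = [
            sum(params[key][i] * w for params, w in updates) / total
            for i in range(len(values))
        ]
    return result


# Two clients; the second holds three times as much data as the first.
new_global = weighted_average(
    [({"w": [1.0, 2.0]}, 1.0), ({"w": [3.0, 6.0]}, 3.0)]
)
```

The result of the aggregation would then replace the global model before the next round is dispatched.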

Provides the default implementations for the following routines:
  • def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel

  • def update_model(self, aggr_result)

The run routine needs to be implemented by the derived class:

  • def run(self)

Parameters:
  • num_clients (int, optional) – The number of clients. Defaults to 3. NOTE: this argument should not be here (we will remove this argument in the next release).

  • num_rounds (int, optional) – The total number of training rounds. Defaults to 5.

  • start_round (int, optional) – The starting round number. Defaults to 0.

  • memory_gc_rounds (int, optional) – Run memory cleanup (gc.collect + malloc_trim) every N rounds. Set to 0 to disable. Defaults to 0 (disabled).
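As a rough illustration of the memory_gc_rounds cadence, the sketch below runs Python's garbage collector every N rounds and does nothing when N is 0. It is not the library's code: maybe_collect is a hypothetical helper, the round counter is assumed to be 0-based with cleanup triggered after rounds N, 2N, and so on, and the malloc_trim step is omitted because it is platform-specific.

```python
import gc


def maybe_collect(current_round: int, memory_gc_rounds: int) -> bool:
    """Run gc.collect() every `memory_gc_rounds` rounds; return True if it ran."""
    if memory_gc_rounds <= 0:
        # 0 disables periodic cleanup.
        return False
    # Assumption: rounds are 0-based, so round index N-1 is the Nth round.
    if (current_round + 1) % memory_gc_rounds == 0:
        gc.collect()
        return True
    return False


# With memory_gc_rounds=2 over six rounds, cleanup runs on every second round.
rounds_with_cleanup = [r for r in range(6) if maybe_collect(r, 2)]
```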

initialize(fl_ctx)[source]

Called by the framework to initialize the Learner object. This is called before the Learner can train or validate. This is called only once.

run() → None[source]

Main run routine for the controller workflow.

scaffold_aggregate_fn(results: List[FLModel]) → FLModel[source]
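For orientation, the server-side update in the SCAFFOLD paper combines two aggregates: the clients' model differences update the global weights, and the clients' control-variate differences update the global control variate. The sketch below follows that update rule on flat lists of floats. It is not the body of scaffold_aggregate_fn, which operates on FLModel objects and may weight clients differently; scaffold_server_update and all of its parameter names are hypothetical.

```python
from typing import List, Sequence, Tuple


def scaffold_server_update(
    global_weights: Sequence[float],
    global_ctrl: Sequence[float],
    weight_diffs: List[Sequence[float]],
    ctrl_diffs: List[Sequence[float]],
    num_total_clients: int,
    global_lr: float = 1.0,
) -> Tuple[List[float], List[float]]:
    """Apply one SCAFFOLD server step to weights and control variate."""
    num_sampled = len(weight_diffs)
    if num_sampled == 0 or num_sampled != len(ctrl_diffs):
        raise ValueError("need one weight diff and one control diff per client")
    # x <- x + global_lr * mean over sampled clients of the weight differences
    new_weights = [
        x + global_lr * sum(d[i] for d in weight_diffs) / num_sampled
        for i, x in enumerate(global_weights)
    ]
    # c <- c + (1 / N) * sum over sampled clients of the control differences,
    # where N is the total number of clients, not only those sampled.
    new_ctrl = [
        c + sum(d[i] for d in ctrl_diffs) / num_total_clients
        for i, c in enumerate(global_ctrl)
    ]
    return new_weights, new_ctrl


# Two of four clients report in this round.
weights, ctrl = scaffold_server_update(
    global_weights=[1.0, 2.0],
    global_ctrl=[0.0, 0.0],
    weight_diffs=[[1.0, 0.0], [3.0, 2.0]],
    ctrl_diffs=[[0.5, 1.0], [1.5, 1.0]],
    num_total_clients=4,
)
```

Dividing the control-variate sum by the total number of clients rather than the number sampled keeps the global control variate an average over all clients when only a subset participates.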