nvflare.app_common.workflows.scaffold module

class Scaffold(*args, enable_tensor_disk_offload: bool = False, **kwargs)[source]

Bases: BaseFedAvg

Controller for Scaffold Workflow. Note: This class is based on ModelController. Implements [SCAFFOLD](https://proceedings.mlr.press/v119/karimireddy20a.html).

Provides the implementations for the run routine, controlling the main workflow:
  • def run(self)

The parent classes provide the default implementations for other routines.

Parameters:
  • num_clients (int, optional) – The number of clients. Defaults to 3.

  • num_rounds (int, optional) – The total number of training rounds. Defaults to 5.

  • persistor_id (str, optional) – ID of the persistor component. Defaults to “persistor”.

  • ignore_result_error (bool or None, optional) – How to handle client result errors:
      - None: Dynamic mode (default). Ignore errors if min_responses is still reachable; panic otherwise.
      - False: Strict mode. Panic on any client error.
      - True: Resilient mode. Always ignore client errors.

  • allow_empty_global_weights (bool, optional) – Whether to allow empty global weights. Some pipelines have empty global weights in the first round, so that clients start training from scratch without any global information. Defaults to False.

  • memory_gc_rounds (int, optional) – Run memory cleanup (gc.collect + malloc_trim) every N rounds. Set to 0 to disable. Defaults to 0 (inherited from BaseFedAvg).

  • enable_tensor_disk_offload (bool, optional) – Download tensors to disk during FOBS streaming instead of holding them in memory, reducing server memory pressure for large models. Only applies to streamed PyTorch tensor payloads. Defaults to False.
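The three ignore_result_error modes listed above can be read as a small decision rule. The following is a minimal sketch of that rule only, not NVFlare's implementation: the helper name should_ignore_client_error and its arguments reachable_responses and min_responses are hypothetical and introduced purely for illustration.

```python
from typing import Optional


def should_ignore_client_error(
    ignore_result_error: Optional[bool],
    reachable_responses: int,
    min_responses: int,
) -> bool:
    """Hypothetical helper: return True to ignore a client error, False to panic.

    reachable_responses is assumed to be the number of valid results that
    can still arrive in the current round after this error.
    """
    if ignore_result_error is True:
        # Resilient mode: always ignore client errors.
        return True
    if ignore_result_error is False:
        # Strict mode: panic on any client error.
        return False
    # Dynamic mode (None): ignore only while min_responses is still reachable.
    return reachable_responses >= min_responses
```

For example, under these assumptions, a round that needs 2 responses and can still receive 3 would tolerate the error in dynamic mode, while the same round with only 1 reachable response would panic.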

The base controller for FedAvg Workflow. Note: This class is based on the ModelController.

Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).

A model persistor can be configured via the persistor_id argument of the ModelController. The model persistor is used to load the initial global model, which is sent to a list of clients. Each client sends its updated weights after local training, and these are aggregated. Next, the global model is updated. The model persistor also saves the model after training.
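The aggregation step in this loop is, in the FederatedAveraging paper, a sample-weighted mean of the client weights. A minimal pure-Python sketch of that idea follows. It does not use NVFlare's FLModel or aggregator classes; the function name weighted_average and the plain dict-of-lists weight format are assumptions made for illustration.

```python
from typing import Dict, List, Tuple

# Assumed toy format: parameter name -> flat list of values.
Weights = Dict[str, List[float]]


def weighted_average(updates: List[Tuple[Weights, int]]) -> Weights:
    """Average client weights, weighting each client by its sample count."""
    total = sum(num_samples for _, num_samples in updates)
    if total <= 0:
        raise ValueError("total number of samples must be positive")

    first_weights = updates[0][0]
    averaged: Weights = {}
    for name, values in first_weights.items():
        averaged[name] = [
            sum(weights[name][i] * num_samples for weights, num_samples in updates) / total
            for i in range(len(values))
        ]
    return averaged


# Two hypothetical clients with 1 and 3 local samples.
client_updates = [({"w": [1.0, 2.0]}, 1), ({"w": [3.0, 4.0]}, 3)]
new_global = weighted_average(client_updates)
```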

Provides the default implementations for the following routines:
  • def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel

  • def update_model(self, aggr_result)

The run routine needs to be implemented by the derived class:

  • def run(self)

Parameters:
  • num_clients (int, optional) – The number of clients. Defaults to 3. NOTE: this argument should not be here (we will remove this argument in the next release).

  • num_rounds (int, optional) – The total number of training rounds. Defaults to 5.

  • start_round (int, optional) – The starting round number.

  • memory_gc_rounds (int, optional) – Run memory cleanup (gc.collect + malloc_trim) every N rounds. Set to 0 to disable. Defaults to 0 (disabled).
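The memory_gc_rounds behaviour amounts to a periodic cleanup check inside the round loop. The sketch below shows one way such a check could look. It is an assumption-based illustration, not the BaseFedAvg code: the helper name maybe_collect is hypothetical, rounds are assumed to be counted from 0, and only gc.collect is called (the malloc_trim step is platform specific and is left out).

```python
import gc


def maybe_collect(current_round: int, memory_gc_rounds: int) -> bool:
    """Hypothetical helper: run gc.collect() every memory_gc_rounds rounds.

    Returns True if a collection was triggered. A value of 0 disables cleanup.
    Assumes current_round is 0-based, so cleanup runs after rounds N-1, 2N-1, ...
    """
    if memory_gc_rounds <= 0:
        return False
    if (current_round + 1) % memory_gc_rounds == 0:
        gc.collect()
        return True
    return False


# With memory_gc_rounds=2, cleanup is triggered after rounds 1 and 3.
triggered = [r for r in range(5) if maybe_collect(r, 2)]
```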

initialize(fl_ctx)[source]

Called by the framework to initialize the Learner object. This is called before the Learner can train or validate. This is called only once.

run() → None[source]

Main run routine for the controller workflow.

scaffold_aggregate_fn(results: List[FLModel]) → FLModel[source]
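This method carries no description here. For orientation, the server-side update in the SCAFFOLD paper averages the clients' model differences over the sampled clients and adds the summed control-variate differences divided by the total number of clients. The sketch below follows the paper's update rule on flat Python lists. It is not the body of scaffold_aggregate_fn, and the function name scaffold_server_update and all of its parameters are assumptions for illustration.

```python
from typing import List, Tuple


def scaffold_server_update(
    global_weights: List[float],
    global_control: List[float],
    weight_diffs: List[List[float]],
    control_diffs: List[List[float]],
    num_total_clients: int,
    global_lr: float = 1.0,
) -> Tuple[List[float], List[float]]:
    """Apply the SCAFFOLD server update from the paper (illustrative only).

    x <- x + global_lr / |S| * sum_i (y_i - x)
    c <- c + 1 / N * sum_i (c_i_new - c_i)
    """
    num_sampled = len(weight_diffs)
    if num_sampled == 0 or num_sampled != len(control_diffs):
        raise ValueError("need one weight diff and one control diff per sampled client")
    if num_total_clients < num_sampled:
        raise ValueError("num_total_clients must be at least the number of sampled clients")

    new_weights = [
        x + global_lr * sum(diff[j] for diff in weight_diffs) / num_sampled
        for j, x in enumerate(global_weights)
    ]
    new_control = [
        c + sum(diff[j] for diff in control_diffs) / num_total_clients
        for j, c in enumerate(global_control)
    ]
    return new_weights, new_control


# Two sampled clients out of four in total (made-up numbers).
weights, control = scaffold_server_update(
    global_weights=[1.0, 2.0],
    global_control=[0.0, 0.0],
    weight_diffs=[[0.5, -1.0], [1.5, 1.0]],
    control_diffs=[[0.25, 0.5], [0.75, 0.0]],
    num_total_clients=4,
)
```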