nvflare.app_common.workflows.scatter_and_gather_scaffold module

class ScatterAndGatherScaffold(min_clients: int = 1000, num_rounds: int = 5, start_round: int = 0, wait_time_after_min_received: int = 10, aggregator_id='aggregator', persistor_id='persistor', shareable_generator_id='shareable_generator', train_task_name='train', train_timeout: int = 0, ignore_result_error: bool = False, task_check_period: float = 0.5, persist_every_n_rounds: int = 1, snapshot_every_n_rounds: int = 1)

Bases: ScatterAndGather

The controller for the ScatterAndGatherScaffold workflow.

The model persistor (persistor_id) loads the initial global model, which is sent to all clients. After local training, each client sends back its updated weights, which are then aggregated by the aggregator (aggregator_id). The shareable generator (shareable_generator_id) converts the aggregated weights to a shareable and a shareable back to weights. The model persistor also saves the model after training.
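The round structure described above can be illustrated with a toy, framework-free sketch. Everything in it (`local_train`, `aggregate`, `run_rounds`, the dictionary-of-floats model) is a hypothetical stand-in and not part of NVFlare; it only mirrors the load, scatter, gather, and aggregate steps and does not model anything specific to SCAFFOLD.

```python
from typing import Dict, List

Weights = Dict[str, float]


def local_train(global_weights: Weights, client_shift: float) -> Weights:
    """Hypothetical stand-in for client training: shift every weight by a fixed amount."""
    return {name: value + client_shift for name, value in global_weights.items()}


def aggregate(updates: List[Weights]) -> Weights:
    """Hypothetical stand-in for the aggregator: unweighted mean of the client weights."""
    names = updates[0].keys()
    return {name: sum(u[name] for u in updates) / len(updates) for name in names}


def run_rounds(initial: Weights, client_shifts: List[float], num_rounds: int) -> Weights:
    # Plays the role of the persistor loading the initial global model.
    global_weights = dict(initial)
    for _ in range(num_rounds):
        # Scatter: every client trains locally, starting from the current global model.
        updates = [local_train(global_weights, shift) for shift in client_shifts]
        # Gather: the returned weights are combined into the next global model.
        global_weights = aggregate(updates)
    return global_weights


final = run_rounds({"w": 0.0}, client_shifts=[1.0, 3.0], num_rounds=5)
print(final)
```

In the real workflow, the persistor, aggregator, and shareable generator are separate components looked up by their IDs; the sketch collapses them into plain functions to keep the control flow visible.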

Parameters:
  • min_clients (int, optional) – The minimum number of client responses that must be received before SAG starts to wait for wait_time_after_min_received. Note that SAG moves forward once all available clients have responded, regardless of this value. Defaults to 1000.

  • num_rounds (int, optional) – The total number of training rounds. Defaults to 5.

  • start_round (int, optional) – Start round for training. Defaults to 0.

  • wait_time_after_min_received (int, optional) – Time to wait before beginning aggregation after the minimum number of client responses has been received. Defaults to 10.

  • aggregator_id (str, optional) – ID of the aggregator component. Defaults to “aggregator”.

  • persistor_id (str, optional) – ID of the persistor component. Defaults to “persistor”.

  • shareable_generator_id (str, optional) – ID of the shareable generator. Defaults to “shareable_generator”.

  • train_task_name (str, optional) – Name of the train task. Defaults to “train”.

  • train_timeout (int, optional) – Time to wait for clients to do local training. Defaults to 0.

  • ignore_result_error (bool, optional) – Whether this controller can proceed if a client result has errors. Defaults to False.

  • task_check_period (float, optional) – Interval for checking the status of tasks. Defaults to 0.5.

  • persist_every_n_rounds (int, optional) – Persist the global model every n rounds. If n is 0, the global model is not persisted. Defaults to 1.

  • snapshot_every_n_rounds (int, optional) – Persist the server state every n rounds. If n is 0, the server state is not persisted. Defaults to 1.
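As an illustration of how these parameters might be supplied, the sketch below builds a workflow entry as a Python dictionary and prints it as JSON. It assumes the `id` / `path` / `args` layout commonly used for workflow entries in an NVFlare server configuration. The entry ID and the chosen `min_clients` value are hypothetical, while the class path and argument names come from the signature above.

```python
import json

# Hypothetical workflow entry; the argument names mirror the constructor signature.
workflow_entry = {
    "id": "scatter_and_gather_scaffold",  # hypothetical ID, any unique string
    "path": (
        "nvflare.app_common.workflows.scatter_and_gather_scaffold."
        "ScatterAndGatherScaffold"
    ),
    "args": {
        "min_clients": 2,  # illustrative value; the default is 1000
        "num_rounds": 5,
        "start_round": 0,
        "wait_time_after_min_received": 10,
        "aggregator_id": "aggregator",
        "persistor_id": "persistor",
        "shareable_generator_id": "shareable_generator",
        "train_task_name": "train",
        "train_timeout": 0,
        "persist_every_n_rounds": 1,
        "snapshot_every_n_rounds": 1,
    },
}

# Assumed top-level layout: a "workflows" list holding the entry.
config_text = json.dumps({"workflows": [workflow_entry]}, indent=2)
print(config_text)
```

The component IDs in `args` must match the IDs of the aggregator, persistor, and shareable generator components defined elsewhere in the same configuration.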

control_flow(abort_signal: Signal, fl_ctx: FLContext) → None

This is the control logic for the RUN.

NOTE: this method runs in a separate thread whose lifetime is the duration of the RUN.

Parameters:
  • fl_ctx – the FL context

  • abort_signal – the abort signal. If triggered, this method stops waiting and returns to the caller.
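The abort behaviour can be pictured with a small stand-alone sketch. `FakeSignal` and `run_until_aborted` are hypothetical stand-ins written for this illustration, not NVFlare classes; the sketch only assumes that the abort signal exposes a boolean `triggered` flag that the control loop checks before starting each round.

```python
from typing import Optional


class FakeSignal:
    """Hypothetical stand-in for the abort signal: a single boolean flag."""

    def __init__(self) -> None:
        self.triggered = False

    def trigger(self) -> None:
        self.triggered = True


def run_until_aborted(
    abort_signal: FakeSignal, num_rounds: int, abort_at: Optional[int] = None
) -> int:
    """Run up to num_rounds rounds and return how many completed before an abort."""
    completed = 0
    for current_round in range(num_rounds):
        if abort_at is not None and current_round == abort_at:
            abort_signal.trigger()  # simulate an abort arriving from outside
        if abort_signal.triggered:
            break  # stop waiting and return to the caller
        completed += 1  # placeholder for one scatter-and-gather round
    return completed


print(run_until_aborted(FakeSignal(), num_rounds=5))
print(run_until_aborted(FakeSignal(), num_rounds=5, abort_at=2))
```

Checking the flag at the top of each round keeps the loop responsive without interrupting a round midway; a real controller would typically also check it while waiting for client responses.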

start_controller(fl_ctx: FLContext) → None

Starts the controller.

This method is called at the beginning of the RUN.

Parameters:
  • fl_ctx – the FL context. You can use this context to access services provided by the framework. For example, you can get the Command Register from it and register your admin command modules.