nvflare.app_common.workflows.scatter_and_gather module
- class ScatterAndGather(min_clients: int = 1, num_rounds: int = 5, start_round: int = 0, wait_time_after_min_received: int = 10, aggregator_id='aggregator', persistor_id='persistor', shareable_generator_id='shareable_generator', train_task_name='train', train_timeout: int = 0, ignore_result_error: bool = False, allow_empty_global_weights: bool = False, task_check_period: float = 0.5, persist_every_n_rounds: int = 1, snapshot_every_n_rounds: int = 1)
Bases: Controller
The controller for ScatterAndGather Workflow.
The ScatterAndGather workflow defines FederatedAveraging on all clients. The model persistor (persistor_id) loads the initial global model, which is sent to all clients. After local training, each client sends back its updated weights, which are aggregated by the aggregator (aggregator_id). The shareable generator (shareable_generator_id) converts the aggregated weights to a shareable and converts a shareable back to weights. The model persistor also saves the model after training.
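The scatter, train, gather, and aggregate cycle described above can be sketched in plain Python. This is a toy illustration only, not the NVFlare implementation: `local_train`, `aggregate`, and `scatter_and_gather` are hypothetical helpers, the weights are a single-entry dict, and the sample-weighted average is an assumed aggregation rule.

```python
from typing import Dict, List, Tuple

Weights = Dict[str, float]


def local_train(global_weights: Weights, shift: float, n_samples: int) -> Tuple[Weights, int]:
    """Hypothetical stand-in for client training: shifts every weight by a constant."""
    updated = {name: value + shift for name, value in global_weights.items()}
    return updated, n_samples


def aggregate(results: List[Tuple[Weights, int]]) -> Weights:
    """Sample-weighted average of the client weights (assumed aggregation rule)."""
    total = sum(n for _, n in results)
    names = results[0][0].keys()
    return {name: sum(w[name] * n for w, n in results) / total for name in names}


def scatter_and_gather(initial: Weights, clients: List[Tuple[float, int]], num_rounds: int) -> Weights:
    # In the real workflow the initial weights come from the persistor.
    global_weights = dict(initial)
    for _ in range(num_rounds):
        # Scatter: every client starts from the same global weights.
        results = [local_train(global_weights, shift, n) for shift, n in clients]
        # Gather: combine the client updates into the next global weights.
        global_weights = aggregate(results)
    return global_weights


# Two toy clients given as (shift, sample count). Round 1 yields 2.5, round 2 yields 5.0.
final = scatter_and_gather({"w": 0.0}, clients=[(1.0, 10), (3.0, 30)], num_rounds=2)
print(final)
```

The client with more samples pulls the average toward its own update, which is why the result after one round is 2.5 rather than the unweighted mean of 2.0.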
- Parameters:
min_clients (int, optional) – Min number of clients in training. Defaults to 1.
num_rounds (int, optional) – The total number of training rounds. Defaults to 5.
start_round (int, optional) – Start round for training. Defaults to 0.
wait_time_after_min_received (int, optional) – Time to wait before beginning aggregation after contributions from min_clients have been received. Defaults to 10.
aggregator_id (str, optional) – ID of the aggregator component. Defaults to “aggregator”.
persistor_id (str, optional) – ID of the persistor component. Defaults to “persistor”.
shareable_generator_id (str, optional) – ID of the shareable generator. Defaults to “shareable_generator”.
train_task_name (str, optional) – Name of the train task. Defaults to “train”.
train_timeout (int, optional) – Time to wait for clients to do local training. Defaults to 0.
ignore_result_error (bool, optional) – Whether this controller can proceed if a client result has errors. Defaults to False.
allow_empty_global_weights (bool, optional) – Whether to allow empty global weights. Some pipelines can have empty global weights in the first round, so that clients start training from scratch without any global info. Defaults to False.
task_check_period (float, optional) – Interval for checking the status of tasks. Defaults to 0.5.
persist_every_n_rounds (int, optional) – Persist the global model every n rounds. Defaults to 1. If n is 0, periodic persisting is disabled.
snapshot_every_n_rounds (int, optional) – Persist the server state every n rounds. Defaults to 1. If n is 0, periodic snapshots are disabled.
- Raises:
TypeError – when any of the input arguments does not have the correct type
ValueError – when any of the input arguments is out of range
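As an illustration of how these arguments are usually supplied, the sketch below builds a workflow entry as a Python dict and serializes it to JSON. The argument names and defaults come from the signature above. The `id`, `path`, and `args` layout of the entry is an assumption about the job's server configuration file, and the `id` value is a hypothetical label. The three component IDs must match components defined elsewhere in the same configuration.

```python
import json

# Keyword arguments mirror the constructor signature documented above.
sag_args = {
    "min_clients": 2,
    "num_rounds": 5,
    "start_round": 0,
    "wait_time_after_min_received": 10,
    "aggregator_id": "aggregator",
    "persistor_id": "persistor",
    "shareable_generator_id": "shareable_generator",
    "train_task_name": "train",
    "train_timeout": 0,
}

# Assumed layout of one workflow entry; "scatter_and_gather" is a hypothetical id.
workflow_entry = {
    "id": "scatter_and_gather",
    "path": "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather",
    "args": sag_args,
}

server_config = {"workflows": [workflow_entry]}
config_text = json.dumps(server_config, indent=2)
print(config_text)
```

Arguments left out of `args` fall back to the defaults listed in the parameter descriptions.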
- control_flow(abort_signal: Signal, fl_ctx: FLContext) → None
This is the control logic for the RUN.
NOTE: this is running in a separate thread, and its life is the duration of the RUN.
- Parameters:
fl_ctx – the FL context
abort_signal – the abort signal. If triggered, this method stops waiting and returns to the caller.
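The interaction between a long-running control loop and an abort signal can be sketched as follows. `ToySignal` and `toy_control_flow` are hypothetical stand-ins written for this example; they are not the NVFlare `Signal` class or the real `control_flow`, and the per-round work is reduced to recording the round number.

```python
import threading
from typing import List


class ToySignal:
    """Hypothetical stand-in for an abort signal backed by a thread-safe event."""

    def __init__(self) -> None:
        self._event = threading.Event()

    def trigger(self) -> None:
        self._event.set()

    @property
    def triggered(self) -> bool:
        return self._event.is_set()


def toy_control_flow(abort_signal: ToySignal, num_rounds: int, completed: List[int]) -> None:
    for round_number in range(num_rounds):
        # Check the signal before starting each round and return to the caller if set.
        if abort_signal.triggered:
            return
        completed.append(round_number)


# Normal run in a separate thread: all rounds complete.
done: List[int] = []
worker = threading.Thread(target=toy_control_flow, args=(ToySignal(), 3, done))
worker.start()
worker.join()

# Signal triggered before the loop starts: no round is executed.
aborted = ToySignal()
aborted.trigger()
skipped: List[int] = []
toy_control_flow(aborted, 3, skipped)
```

A real controller would also check the signal while waiting on clients, not only between rounds.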
- get_persist_state(fl_ctx: FLContext) → dict
Generate data from state to be persisted.
- Parameters:
fl_ctx – FLContext
- Returns:
A dict of serializable data to persist
- handle_event(event_type: str, fl_ctx: FLContext)
Called when events are fired.
- Parameters:
event_type (str) – the type of the fired event; can be any event type, including AppEventType and EventType
fl_ctx (FLContext) – FLContext information with current event type
- process_result_of_unknown_task(client: Client, task_name, client_task_id, result: Shareable, fl_ctx: FLContext) → None
Process result when no task is found for it.
This is called when a result submission is received from a client, but no standing task can be found for it in the task queue.
This could happen when:
- the client’s submission is too late
- the task is already completed
- the Controller lost the task, e.g. the Server is restarted
- Parameters:
client – the client that the result comes from
task_name – the name of the task
client_task_id – ID of the task
result – the result from the client
fl_ctx – the FL context that comes with the client’s submission
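The routing idea behind this callback can be sketched with a toy task table. Everything here is hypothetical and simplified: `standing_tasks`, `receive_result`, and the rule that a task is closed after its first result are assumptions made for the example, not NVFlare behavior.

```python
from typing import Any, Dict, List, Tuple

# Toy task queue: task ID mapped to task name.
standing_tasks: Dict[str, str] = {"task-1": "train"}
accepted: List[Tuple[str, str]] = []
unknown: List[Tuple[str, str]] = []


def process_result_of_unknown_task(client: str, task_name: str, client_task_id: str, result: Any) -> None:
    """Toy handler: only records which submissions had no standing task."""
    unknown.append((client, client_task_id))


def receive_result(client: str, task_name: str, client_task_id: str, result: Any) -> None:
    if client_task_id in standing_tasks:
        accepted.append((client, client_task_id))
        # In this toy, a task is completed as soon as one result arrives.
        del standing_tasks[client_task_id]
    else:
        process_result_of_unknown_task(client, task_name, client_task_id, result)


receive_result("site-1", "train", "task-1", {"w": 1.0})  # standing task found
receive_result("site-2", "train", "task-1", {"w": 2.0})  # task already completed
receive_result("site-3", "train", "task-9", {"w": 3.0})  # task never known here
```

After these three calls, only the first submission is accepted; the other two reach the unknown-task handler.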
- restore(state_data: dict, fl_ctx: FLContext)
Restore the state from persisted data.
- Parameters:
state_data – serialized persist data
fl_ctx – FLContext
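Taken together, get_persist_state and restore form a save and reload round trip. The sketch below shows that pattern on a hypothetical `ToyController`; the three state fields are assumptions chosen for illustration, the `fl_ctx` argument is omitted, and JSON stands in for whatever serialization the server actually uses.

```python
import json
from typing import Any, Dict


class ToyController:
    """Hypothetical stand-in; not the NVFlare ScatterAndGather class."""

    def __init__(self, start_round: int = 0, num_rounds: int = 5) -> None:
        self.start_round = start_round
        self.num_rounds = num_rounds
        self.current_round = start_round

    def get_persist_state(self) -> Dict[str, Any]:
        # Return only plain, serializable values.
        return {
            "current_round": self.current_round,
            "start_round": self.start_round,
            "num_rounds": self.num_rounds,
        }

    def restore(self, state_data: Dict[str, Any]) -> None:
        # Rebuild the in-memory state from the persisted dict.
        self.current_round = state_data["current_round"]
        self.start_round = state_data["start_round"]
        self.num_rounds = state_data["num_rounds"]


before = ToyController(num_rounds=5)
before.current_round = 3  # pretend three rounds have already run
snapshot = json.dumps(before.get_persist_state())

after = ToyController()  # fresh instance, e.g. after a server restart
after.restore(json.loads(snapshot))
```

Keeping the persisted dict limited to plain values is what lets the state survive a serialize and deserialize step unchanged.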
- start_controller(fl_ctx: FLContext) → None
Starts the controller.
This method is called at the beginning of the RUN.
- Parameters:
fl_ctx – the FL context. You can use this context to access services provided by the framework. For example, you can get Command Register from it and register your admin command modules.
- stop_controller(fl_ctx: FLContext)
Stops the controller.
This method is called right before the RUN is ended.
- Parameters:
fl_ctx – the FL context. You can use this context to access services provided by the framework. For example, you can get Command Register from it and unregister your admin command modules.