nvflare.app_common.ccwf.swarm_client_ctl module
- class Gatherer(task_data: Shareable, fl_ctx: FLContext, for_round: int, executor: ClientSideController, aggregator: Aggregator, metric_comparator: MetricComparator, all_clients: list, trainers: list, min_responses_required: int, wait_time_after_min_resps_received: float, timeout, max_concurrent_submissions: int = 1)
Bases: FLComponent
Init FLComponent.
FLComponent is the base class of all FL components; executors, controllers, responders, filters, aggregators, and widgets are all FLComponents.
FLComponents can handle and fire events, and they provide various methods for logging.
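The Gatherer's constructor arguments suggest how it decides that collection for one round is finished. The following is a minimal, hypothetical sketch of that rule, assuming gathering ends when every trainer has responded, when min_responses_required responses have arrived and wait_time_after_min_resps_received seconds have passed since then, or when the overall timeout expires. GatherState and done_reason are invented names for illustration, not NVFlare code, and times are passed in explicitly to keep the sketch deterministic.

```python
from dataclasses import dataclass, field
from typing import Optional, Set


@dataclass
class GatherState:
    """Hypothetical per-round bookkeeping, loosely modeled on the Gatherer arguments."""

    trainers: Set[str]                        # clients expected to submit this round
    min_responses_required: int
    wait_time_after_min_resps_received: float
    timeout: Optional[float]                  # overall limit in seconds; None = no limit
    start_time: float                         # when gathering began
    received: Set[str] = field(default_factory=set)
    min_reached_time: Optional[float] = None

    def add_result(self, client: str, now: float) -> None:
        # Only results from expected trainers count toward the totals.
        if client not in self.trainers:
            return
        self.received.add(client)
        if self.min_reached_time is None and len(self.received) >= self.min_responses_required:
            self.min_reached_time = now

    def done_reason(self, now: float) -> Optional[str]:
        """Return why gathering can stop, or None if it should keep waiting."""
        if self.received == self.trainers:
            return "all_responses"
        if (
            self.min_reached_time is not None
            and now - self.min_reached_time >= self.wait_time_after_min_resps_received
        ):
            return "min_responses_and_wait"
        if self.timeout is not None and now - self.start_time >= self.timeout:
            return "timeout"
        return None


state = GatherState({"site-1", "site-2", "site-3"}, 2, 10.0, 60.0, start_time=0.0)
state.add_result("site-1", now=2.0)
state.add_result("site-2", now=5.0)
print(state.done_reason(now=12.0))  # None: minimum reached at t=5, grace period runs to t=15
print(state.done_reason(now=15.0))  # min_responses_and_wait
```

The grace period trades latency for completeness: a slow trainer can still contribute if it answers within the wait window, but a single straggler cannot stall the round indefinitely.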
- class SwarmClientController(task_name_prefix='swarm', learn_task_name='train', persistor_id='persistor', shareable_generator_id='shareable_generator', aggregator_id='aggregator', metric_comparator_id=None, learn_task_check_interval=1.0, learn_task_abort_timeout=5.0, learn_task_ack_timeout=10, learn_task_timeout=None, final_result_ack_timeout=10, min_responses_required: int = 1, wait_time_after_min_resps_received: float = 10.0, request_to_submit_result_max_wait=None, request_to_submit_result_msg_timeout=5.0, request_to_submit_result_interval: float = 1.0, max_concurrent_submissions: int = 1, memory_gc_rounds: int = 1, cuda_empty_cache: bool = False)
Bases: ClientSideController
Constructor of a SwarmClientController object.
- Parameters:
task_name_prefix – prefix of task names. All CCWF task names are prefixed with this.
learn_task_name – name for the Learning Task (LT)
persistor_id – ID of the persistor component
shareable_generator_id – ID of the shareable generator component
aggregator_id – ID of the aggregator
metric_comparator_id – ID of the metric comparator used to determine the best model. If not specified, the default NumberMetricComparator is used.
learn_task_check_interval – interval for checking for an incoming Learning Task (LT)
learn_task_ack_timeout – timeout for sending the LT to other client(s)
learn_task_timeout – max time allowed for a training task
final_result_ack_timeout – timeout for sending final result to participating clients
learn_task_abort_timeout – time to wait for the LT to become stopped after aborting it
min_responses_required – minimum number of responses required for aggregation
wait_time_after_min_resps_received – how long to wait after the minimum number of responses (but not all responses) has been received.
request_to_submit_result_max_wait – maximum time to wait for permission from the aggregation client. If permission is not received within this period, the training result is not submitted. If this value is not specified (None), the training client keeps trying indefinitely.
request_to_submit_result_msg_timeout – timeout for the “submission request” message. Because the submission request is a tiny message, this timeout can be small.
request_to_submit_result_interval – interval between requests to submit result.
max_concurrent_submissions – max number of concurrent submissions allowed on the aggregation client.
memory_gc_rounds – run gc.collect() + malloc_trim on the aggregator every N FL rounds. Defaults to 1 (every round) to match legacy behavior where gc.collect() was called unconditionally after each trainer submission. Set to 0 to disable.
cuda_empty_cache – also call torch.cuda.empty_cache() during aggregator-side cleanup. In swarm learning the aggregator runs on the same client as the trainer, so GPU memory may be relevant. Defaults to False.
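As a rough illustration of how the constructor parameters above might be supplied, here is a hypothetical executor entry for a client job configuration, written as a Python dict. It assumes the component is referenced by its module path through "path" and "args" keys, and that the job routes every task name matching "swarm_*" (the default task_name_prefix) to this controller. The values are placeholders, not recommendations; omitted parameters fall back to their constructor defaults.

```python
import json

# Hypothetical executor entry (e.g. for a client-side job config), shown as a dict.
swarm_executor = {
    # Route every task carrying the CCWF prefix to the swarm controller (assumed pattern syntax).
    "tasks": ["swarm_*"],
    "executor": {
        "path": "nvflare.app_common.ccwf.swarm_client_ctl.SwarmClientController",
        "args": {
            "task_name_prefix": "swarm",
            "learn_task_name": "train",
            "persistor_id": "persistor",
            "shareable_generator_id": "shareable_generator",
            "aggregator_id": "aggregator",
            "min_responses_required": 2,               # placeholder value
            "wait_time_after_min_resps_received": 30.0,
            "max_concurrent_submissions": 1,           # submit results one at a time
            "request_to_submit_result_max_wait": 600.0,
            "memory_gc_rounds": 1,                     # clean up after every round
            "cuda_empty_cache": False,
        },
    },
}

print(json.dumps(swarm_executor, indent=2))
```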
Note that setting max_concurrent_submissions to 1 effectively means that all training results are submitted to the aggregation client sequentially. This lowers the resource pressure on the aggregation client but makes the overall training process longer. Because a trainer may have to wait for earlier submissions to be processed, request_to_submit_result_max_wait, if specified, should be long enough to give the aggregation client sufficient time to process training results.
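The submission throttling described in this note can be modeled with a counting semaphore. In this hypothetical sketch (SubmissionGate and request_permission are invented names, not NVFlare code), the aggregation side grants at most max_concurrent_submissions permissions at a time, and each trainer re-asks every request_to_submit_result_interval seconds until it is granted or request_to_submit_result_max_wait elapses, with None meaning retry forever. The per-message timeout (request_to_submit_result_msg_timeout) is not modeled.

```python
import threading
import time
from typing import Optional


class SubmissionGate:
    """Hypothetical aggregation-side gate allowing at most N concurrent submissions."""

    def __init__(self, max_concurrent_submissions: int = 1) -> None:
        self._slots = threading.BoundedSemaphore(max_concurrent_submissions)

    def try_grant(self) -> bool:
        # Answer a submission request immediately: True if a slot was free.
        return self._slots.acquire(blocking=False)

    def finish_submission(self) -> None:
        # Called once the submitted result has been processed; frees the slot.
        self._slots.release()


def request_permission(gate: SubmissionGate, interval: float, max_wait: Optional[float]) -> bool:
    """Trainer side: ask every `interval` seconds until granted or `max_wait` elapses.

    max_wait=None keeps retrying forever, like request_to_submit_result_max_wait=None.
    """
    start = time.monotonic()
    while True:
        if gate.try_grant():
            return True
        if max_wait is not None and time.monotonic() - start >= max_wait:
            return False  # give up: the training result is not submitted
        time.sleep(interval)


gate = SubmissionGate(max_concurrent_submissions=1)
print(request_permission(gate, interval=0.01, max_wait=0.05))  # True
print(request_permission(gate, interval=0.01, max_wait=0.05))  # False, slot still held
```

With a single slot, the second trainer only succeeds after the first submission is finished, which is why a short max_wait can cause results to be dropped under sequential submission.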
- do_learn_task(name: str, task_data: Shareable, fl_ctx: FLContext, abort_signal: Signal)
Called to perform a Learn Task. Subclasses must implement this method.
- Parameters:
name – task name
task_data – task data
fl_ctx – FL context of the task
abort_signal – abort signal for the task execution
Returns:
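Because abort_signal is meant to be checked during execution, a learn task that performs long-running local work would typically poll it between steps so the task stops promptly (the controller waits at most learn_task_abort_timeout for it to stop). The sketch below assumes a minimal stand-in signal exposing only a triggered flag and a trigger() method; StandInSignal, run_local_steps, and train_step are hypothetical names, and the real Signal class and training loop may differ.

```python
from typing import Callable, List


class StandInSignal:
    """Hypothetical stand-in for the abort signal: only a `triggered` flag."""

    def __init__(self) -> None:
        self.triggered = False

    def trigger(self) -> None:
        self.triggered = True


def run_local_steps(
    num_steps: int, train_step: Callable[[int], float], abort_signal: StandInSignal
) -> List[float]:
    """Run local training steps, checking for abort before each one."""
    losses: List[float] = []
    for step in range(num_steps):
        if abort_signal.triggered:
            break  # stop promptly so the controller is not kept waiting
        losses.append(train_step(step))
    return losses


abort = StandInSignal()
print(run_local_steps(3, lambda s: float(s), abort))  # [0.0, 1.0, 2.0]
```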
- execute(task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) → Shareable
Executes a task.
- Parameters:
task_name (str) – task name.
shareable (Shareable) – input shareable.
fl_ctx (FLContext) – fl context.
abort_signal (Signal) – signal to check during execution to determine whether this task is aborted.
- Returns:
An output shareable.
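Since every CCWF task name carries task_name_prefix, an execute implementation can route incoming tasks by that prefix. The following is a hypothetical sketch assuming names of the form <prefix>_<suffix>; the suffixes "config" and "start" and the "TASK_UNKNOWN" reply string are illustrative stand-ins, not the task names or return codes NVFlare actually uses.

```python
from typing import Callable, Dict

TASK_UNKNOWN = "TASK_UNKNOWN"  # hypothetical stand-in for an error reply


def dispatch(task_name: str, task_name_prefix: str, handlers: Dict[str, Callable[[], str]]) -> str:
    """Route '<prefix>_<suffix>' task names to the handler registered for <suffix>."""
    head = f"{task_name_prefix}_"
    if not task_name.startswith(head):
        return TASK_UNKNOWN  # not a task owned by this controller
    handler = handlers.get(task_name[len(head):])
    return handler() if handler is not None else TASK_UNKNOWN


handlers = {"config": lambda: "configured", "start": lambda: "started"}
print(dispatch("swarm_config", "swarm", handlers))  # configured
print(dispatch("cse_start", "swarm", handlers))     # TASK_UNKNOWN
```

Keeping the prefix configurable lets several client-controlled workflows share one job without their task names colliding.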
- handle_event(event_type: str, fl_ctx: FLContext)
Handles events.
- Parameters:
event_type (str) – event type fired by workflow.
fl_ctx (FLContext) – FLContext information.
- process_config(fl_ctx: FLContext)
Called to allow the subclass to process configuration properties.
Returns: None
- start_workflow(shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) → Shareable
Called for the subclass to start the workflow. This happens only on the starting_client.
- Parameters:
shareable – the initial task data (e.g. initial model weights)
fl_ctx – FL context
abort_signal – abort signal for task execution
Returns: