nvflare.app_common.ccwf.swarm_client_ctl module

class Gatherer(task_data: Shareable, fl_ctx: FLContext, for_round: int, executor: ClientSideController, aggregator: Aggregator, metric_comparator: MetricComparator, all_clients: list, trainers: list, min_responses_required: int, wait_time_after_min_resps_received: float, timeout, max_concurrent_submissions: int = 1)

Bases: FLComponent

Init FLComponent.

FLComponent is the base class of all FL components: executors, controllers, responders, filters, aggregators, and widgets are all FLComponents.

FLComponents have the capability to handle and fire events and contain various methods for logging.

aggregate()
can_accept_submission(client_name: str, result: Shareable, fl_ctx: FLContext) → str
gather(client_name: str, result: Shareable, fl_ctx: FLContext) → Shareable
is_done()
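
The constructor arguments min_responses_required, wait_time_after_min_resps_received and timeout suggest how a gatherer might decide that a round is complete. The following is only a sketch under that assumption, not the library's implementation: gather_is_done and all of its bookkeeping arguments (num_received, num_expected, start_time, min_resps_received_time, now) are hypothetical names introduced for illustration.

```python
import time
from typing import Optional


def gather_is_done(
    num_received: int,
    num_expected: int,
    min_responses_required: int,
    wait_time_after_min_resps_received: float,
    timeout: Optional[float],
    start_time: float,
    min_resps_received_time: Optional[float],
    now: Optional[float] = None,
) -> bool:
    """Hypothetical completion check; not the actual Gatherer.is_done()."""
    now = time.time() if now is None else now

    # Every expected trainer has responded: nothing left to wait for.
    if num_received >= num_expected:
        return True

    # The overall deadline has passed (None is assumed to mean "no deadline").
    if timeout is not None and now - start_time > timeout:
        return True

    # The minimum was reached earlier; give stragglers a bounded grace period.
    if num_received >= min_responses_required and min_resps_received_time is not None:
        return now - min_resps_received_time > wait_time_after_min_resps_received

    return False
```

With min_responses_required=1 and a 10 second grace period, a call made 3 seconds after the first response reports "not done", while a call made 11 seconds after it reports "done" even though some trainers never answered.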
class SwarmClientController(task_name_prefix='swarm', learn_task_name='train', persistor_id='persistor', shareable_generator_id='shareable_generator', aggregator_id='aggregator', metric_comparator_id=None, learn_task_check_interval=1.0, learn_task_abort_timeout=5.0, learn_task_ack_timeout=10, learn_task_timeout=None, final_result_ack_timeout=10, min_responses_required: int = 1, wait_time_after_min_resps_received: float = 10.0, request_to_submit_result_max_wait=None, request_to_submit_result_msg_timeout=5.0, request_to_submit_result_interval: float = 1.0, max_concurrent_submissions: int = 1, memory_gc_rounds: int = 1, cuda_empty_cache: bool = False)

Bases: ClientSideController

Constructor of a SwarmClientController object.

Parameters:
  • task_name_prefix – prefix of task names. All CCWF task names are prefixed with this.

  • learn_task_name – name for the Learning Task (LT)

  • persistor_id – ID of the persistor component

  • shareable_generator_id – ID of the shareable generator component

  • aggregator_id – ID of the aggregator

  • metric_comparator_id – ID of metric comparator to be used for determining best model. If not specified, the default NumberMetricComparator is used.

  • learn_task_check_interval – interval for checking incoming Learning Task (LT)

  • learn_task_ack_timeout – timeout for sending the LT to other client(s)

  • learn_task_timeout – max time allowed for a training task

  • final_result_ack_timeout – timeout for sending final result to participating clients

  • learn_task_abort_timeout – time to wait for the LT to become stopped after aborting it

  • min_responses_required – minimum number of responses required for the aggregation

  • wait_time_after_min_resps_received – how long to wait after min responses (but not all responses) are received.

  • request_to_submit_result_max_wait – max amount of time to wait for the permission from the aggregation client. If the permission is not received within this period of time, the training result will not be submitted. If this value is not specified (None), then the training client will keep trying forever.

  • request_to_submit_result_msg_timeout – the timeout for the “submission request” message. Since a submission request is a tiny message, this timeout value should be small.

  • request_to_submit_result_interval – interval between requests to submit result.

  • max_concurrent_submissions – max number of concurrent submissions allowed on the aggregation client.

  • memory_gc_rounds – run gc.collect() + malloc_trim on the aggregator every N FL rounds. Defaults to 1 (every round) to match legacy behavior where gc.collect() was called unconditionally after each trainer submission. Set to 0 to disable.

  • cuda_empty_cache – also call torch.cuda.empty_cache() during aggregator-side cleanup. In swarm learning the aggregator runs on the same client as the trainer, so GPU memory may be relevant. Defaults to False.
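
The parameters above can be collected into a component definition. The sketch below builds one as a plain Python dict and prints it as JSON. The argument names and defaults come from the constructor signature; the "path"/"args" layout is an assumption about how the component would be declared in a job configuration, and the values marked "illustrative" are not recommendations.

```python
import json

# Keyword arguments named in the SwarmClientController signature.
# Values are the documented defaults unless marked "illustrative".
swarm_args = {
    "task_name_prefix": "swarm",
    "learn_task_name": "train",
    "persistor_id": "persistor",
    "shareable_generator_id": "shareable_generator",
    "aggregator_id": "aggregator",
    "min_responses_required": 2,  # illustrative; default is 1
    "wait_time_after_min_resps_received": 30.0,  # illustrative; default is 10.0
    "request_to_submit_result_max_wait": 600.0,  # illustrative; default is None
    "max_concurrent_submissions": 1,
    "memory_gc_rounds": 1,
    "cuda_empty_cache": False,
}

# Assumed layout: a class path plus constructor arguments.
component = {
    "path": "nvflare.app_common.ccwf.swarm_client_ctl.SwarmClientController",
    "args": swarm_args,
}

print(json.dumps(component, indent=2))
```

Arguments left out of the dict, such as metric_comparator_id, fall back to the defaults shown in the signature.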

Note that setting max_concurrent_submissions to 1 means, in practice, that all training results are submitted to the aggregation client sequentially. This lowers the resource pressure on the aggregation client but makes the overall training process longer. The value of request_to_submit_result_max_wait, if specified, should be long enough to give the aggregation client sufficient time to process training results.
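
The interaction between max_concurrent_submissions, request_to_submit_result_interval and request_to_submit_result_max_wait can be illustrated with a small self-contained simulation. This is a sketch of the idea only, not NVFlare code: SubmissionSlots and request_permission are hypothetical names, and the real controller exchanges messages between clients rather than calling a local object.

```python
import time
from typing import Callable, Optional


class SubmissionSlots:
    """Hypothetical aggregator-side counter of in-flight submissions."""

    def __init__(self, max_concurrent_submissions: int = 1):
        self.max_concurrent_submissions = max_concurrent_submissions
        self.active = set()

    def try_acquire(self, client_name: str) -> bool:
        # A client that already holds a slot keeps it.
        if client_name in self.active:
            return True
        # Refuse when every slot is taken.
        if len(self.active) >= self.max_concurrent_submissions:
            return False
        self.active.add(client_name)
        return True

    def release(self, client_name: str) -> None:
        self.active.discard(client_name)


def request_permission(
    ask: Callable[[], bool],
    interval: float = 1.0,
    max_wait: Optional[float] = None,
    clock: Callable[[], float] = time.monotonic,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Ask repeatedly until granted, or until max_wait elapses (None: forever)."""
    start = clock()
    while True:
        if ask():
            return True
        if max_wait is not None and clock() - start >= max_wait:
            return False
        sleep(interval)
```

With a single slot, a second trainer is refused until the first one releases it; if max_wait is too short for the aggregation client to finish processing the first result, the second trainer gives up and its result is never submitted.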

do_learn_task(name: str, task_data: Shareable, fl_ctx: FLContext, abort_signal: Signal)

This is called to do a Learn Task. Subclass must implement this method.

Parameters:
  • name – task name

  • task_data – task data

  • fl_ctx – FL context of the task

  • abort_signal – abort signal for the task execution

Returns:

execute(task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) → Shareable

Executes a task.

Parameters:
  • task_name (str) – task name.

  • shareable (Shareable) – input shareable.

  • fl_ctx (FLContext) – fl context.

  • abort_signal (Signal) – signal to check during execution to determine whether this task is aborted.

Returns:

An output shareable.
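
The abort_signal parameter is meant to be checked during execution. A minimal sketch of that pattern follows, using a stand-in object rather than the real Signal class: StubSignal and run_steps are hypothetical, and the only assumption made about the signal is that it exposes a boolean triggered flag.

```python
class StubSignal:
    """Hypothetical stand-in for an abort signal with a boolean flag."""

    def __init__(self):
        self.triggered = False


def run_steps(steps, abort_signal):
    """Run callables in order; stop before the next one once aborted.

    Returns the number of steps that actually ran.
    """
    completed = 0
    for step in steps:
        # Check between units of work so an abort takes effect promptly.
        if abort_signal.triggered:
            break
        step()
        completed += 1
    return completed
```

Checking the flag between short units of work, rather than once at the start, is what lets a long-running task stop soon after it is aborted.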

handle_event(event_type: str, fl_ctx: FLContext)

Handles events.

Parameters:
  • event_type (str) – event type fired by workflow.

  • fl_ctx (FLContext) – FLContext information.

process_config(fl_ctx: FLContext)

This is called to allow the subclass to process config props.

Returns: None

start_run(fl_ctx: FLContext)
start_workflow(shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) → Shareable

This is called for the subclass to start the workflow. This only happens on the starting_client.

Parameters:
  • shareable – the initial task data (e.g. initial model weights)

  • fl_ctx – FL context

  • abort_signal – abort signal for task execution

Returns: