# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import random
import threading
import time
from nvflare.apis.controller_spec import Task
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.abstract.learnable import Learnable
from nvflare.app_common.abstract.metric_comparator import MetricComparator
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_common.ccwf.client_ctl import ClientSideController
from nvflare.app_common.ccwf.common import Constant, NumberMetricComparator, ResultType, make_task_name
from nvflare.fuel.utils.validation_utils import check_non_empty_str, check_positive_int, check_positive_number
from nvflare.security.logging import secure_format_traceback
class _TrainerStatus:
    """Bookkeeping record for a single trainer, tracked while results are gathered."""

    def __init__(self, name: str):
        """Create a fresh status record for the trainer called ``name``.

        Args:
            name: name of the trainer this record belongs to.
        """
        # A new record starts idle: nothing requested, nothing received.
        self.reply_time = None  # time at which this trainer's result was received
        self.busy = False  # True while this trainer is marked busy
        self.last_submit_req_time = None  # most recent time this trainer asked to submit its result
        self.name = name
[docs]
class Gatherer(FLComponent):
    """Collects training results for one round and aggregates them.

    A gatherer is created for a specific round (``for_round``). Trainers first ask
    for permission to submit (``can_accept_submission``), then hand in their result
    (``gather``). ``is_done`` tells the owner when gathering can stop, and
    ``aggregate`` produces the aggregated result together with the best-metric
    headers.
    """

    def __init__(
        self,
        task_data: Shareable,
        fl_ctx: FLContext,
        for_round: int,
        executor: ClientSideController,
        aggregator: Aggregator,
        metric_comparator: MetricComparator,
        all_clients: list,
        trainers: list,
        min_responses_required: int,
        wait_time_after_min_resps_received: float,
        timeout,
        max_concurrent_submissions: int = 1,
    ):
        """Set up the gatherer for one round.

        Args:
            task_data: task data of the round; its CLIENT / METRIC / ROUND headers carry
                the best result known so far (if any).
            fl_ctx: FL context kept for logging and event firing during aggregation.
            for_round: the round this gatherer collects results for.
            executor: the owning controller; its best_metric / best_round / me are read
                in ``aggregate`` and its ``update_status`` is called on logic errors.
            aggregator: aggregator that accepts contributions and produces the result.
            metric_comparator: comparator used to decide whether the local metric is better.
            all_clients: all clients of the workflow (stored, not otherwise used here).
            trainers: names of the clients expected to submit results.
            min_responses_required: minimum number of responses needed; values <= 0 or
                >= the number of trainers mean "all trainers".
            wait_time_after_min_resps_received: seconds to keep waiting after the minimum
                number of responses has arrived.
            timeout: overall gathering timeout in seconds; a falsy value disables it.
            max_concurrent_submissions: max number of trainers allowed to be busy
                submitting at the same time.
        """
        FLComponent.__init__(self)
        self.fl_ctx = fl_ctx
        self.executor = executor
        self.aggregator = aggregator
        self.metric_comparator = metric_comparator
        self.all_clients = all_clients
        self.trainers = trainers
        self.for_round = for_round
        self.trainer_statuses = {t: _TrainerStatus(t) for t in trainers}
        self.start_time = time.time()
        self.timeout = timeout
        self.max_concurrent_submissions = max_concurrent_submissions

        # Out-of-range values are normalized to "wait for every trainer".
        if min_responses_required <= 0 or min_responses_required >= len(trainers):
            min_responses_required = len(trainers)
        self.min_responses_required = min_responses_required
        self.wait_time_after_min_resps_received = wait_time_after_min_resps_received
        self.min_resps_received_time = None

        self.lock = threading.Lock()  # serializes gather() calls
        self.perm_lock = threading.Lock()  # guards the trainers' busy flags

        self.current_best_client = task_data.get_header(Constant.CLIENT)
        self.current_best_global_metric = task_data.get_header(Constant.METRIC)
        self.current_best_round = task_data.get_header(Constant.ROUND)
        if not self.current_best_client:
            self.log_info(fl_ctx, "gatherer starting from scratch")
        else:
            self.log_info(
                fl_ctx,
                f"gatherer starting with previous best result from client {self.current_best_client} "
                f"with metric {self.current_best_global_metric} "
                f"at round {self.current_best_round}",
            )

    def gather(self, client_name: str, result: Shareable, fl_ctx: FLContext) -> Shareable:
        """Process one submitted result; always clears the submitter's busy flag afterwards.

        Args:
            client_name: name of the submitting client.
            result: the submitted training result.
            fl_ctx: FL context of the submission.

        Returns:
            The reply produced by ``_do_gather``, or an EXECUTION_EXCEPTION reply if
            processing raised an exception.
        """
        with self.lock:
            try:
                return self._do_gather(client_name, result, fl_ctx)
            except Exception:
                self.log_error(fl_ctx, f"exception gathering: {secure_format_traceback()}")
                return make_reply(ReturnCode.EXECUTION_EXCEPTION)
            finally:
                # Whatever happened, the submission slot held by this client is released.
                with self.perm_lock:
                    client_status = self.trainer_statuses.get(client_name)
                    if client_status:
                        client_status.busy = False

    def can_accept_submission(self, client_name: str, result: Shareable, fl_ctx: FLContext) -> str:
        """Decide whether ``client_name`` may submit its result now.

        Args:
            client_name: name of the requesting client.
            result: the request; only its CURRENT_ROUND header is read (for logging).
            fl_ctx: FL context of the request.

        Returns:
            ReturnCode.OK if permission is (or already was) granted,
            ReturnCode.SERVICE_UNAVAILABLE if too many trainers are busy, or
            ReturnCode.MODEL_UNRECOGNIZED if the client is not a trainer.
        """
        with self.perm_lock:
            result_round = result.get_header(AppConstants.CURRENT_ROUND)
            client_status = self.trainer_statuses.get(client_name)
            if not client_status:
                self.log_error(
                    fl_ctx, f"submission request from {client_name} for round {result_round}, but it is not a trainer"
                )
                return ReturnCode.MODEL_UNRECOGNIZED

            client_status.last_submit_req_time = time.time()
            if client_status.busy:
                # we already granted permission
                self.log_debug(fl_ctx, f"already granted permission to client {client_name}")
                return ReturnCode.OK

            # how many are busy now?
            busy = sum(1 for status in self.trainer_statuses.values() if status.busy)
            if busy >= self.max_concurrent_submissions:
                self.log_debug(
                    fl_ctx,
                    f"asked client {client_name} to wait: busy clients {busy} >= {self.max_concurrent_submissions}",
                )
                return ReturnCode.SERVICE_UNAVAILABLE

            # we can accept
            client_status.busy = True
            self.log_debug(fl_ctx, f"OK to accept submission from {client_name}")
            return ReturnCode.OK

    def _do_gather(self, client_name: str, result: Shareable, fl_ctx: FLContext) -> Shareable:
        """Validate a submitted result and pass it to the aggregator.

        Results for a future round are rejected; late results (earlier round) are
        logged but still handed to the aggregator.

        Returns:
            An OK reply once the aggregator has been consulted (whether it accepted or
            rejected the contribution), otherwise an EXECUTION_EXCEPTION reply.
        """
        result_round = result.get_header(AppConstants.CURRENT_ROUND)
        ts = self.trainer_statuses.get(client_name)
        if not ts:
            self.log_error(
                fl_ctx, f"received result from {client_name} for round {result_round}, but it is not a trainer"
            )
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        if result_round > self.for_round:
            # this should never happen!
            # otherwise it means that the client is sending me result for a round that I couldn't possibly schedule!
            self.log_error(
                fl_ctx,
                f"logic error: received result from {client_name} for round {result_round}, "
                f"which is > gatherer's current round {self.for_round}",
            )
            self.executor.update_status(action="gather", error=ReturnCode.EXECUTION_EXCEPTION)
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        if result_round < self.for_round:
            # this is a late result for a round that I scheduled in the past.
            # Note: we still accept it!
            self.log_warning(
                fl_ctx,
                f"received late result from {client_name} for round {result_round}, "
                f"which is < gatherer's current round {self.for_round}",
            )

        if result_round == self.for_round:
            # this is the result that I'm waiting for.
            now = time.time()
            ts.reply_time = now
            if not self.min_resps_received_time:
                # see how many responses I have received
                num_resps_received = sum(1 for status in self.trainer_statuses.values() if status.reply_time)
                if num_resps_received >= self.min_responses_required:
                    self.min_resps_received_time = now

        rc = result.get_return_code(ReturnCode.OK)
        if rc != ReturnCode.OK:
            self.log_error(fl_ctx, f"Bad result from {client_name} for round {result_round}: {rc}.")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self.for_round, private=True, sticky=True)
        fl_ctx.set_prop(AppConstants.TRAINING_RESULT, result, private=True, sticky=False)
        self.fire_event(AppEventType.BEFORE_CONTRIBUTION_ACCEPT, fl_ctx)

        accepted = self.aggregator.accept(result, fl_ctx)
        accepted_msg = "ACCEPTED" if accepted else "REJECTED"
        self.log_info(
            fl_ctx, f"Contribution from {client_name} {accepted_msg} by the aggregator at round {result_round}."
        )

        fl_ctx.set_prop(AppConstants.AGGREGATION_ACCEPTED, accepted, private=True, sticky=False)
        self.fire_event(AppEventType.AFTER_CONTRIBUTION_ACCEPT, fl_ctx)
        return make_reply(ReturnCode.OK)

    def aggregate(self):
        """Run the aggregator and stamp the best-result headers on its output.

        The executor's own best metric replaces the previously known global best when
        there was no previous best, or when the metric comparator ranks it higher.

        Returns:
            The aggregation result with ROUND, METRIC and CLIENT headers set to the
            global best known after this round.
        """
        fl_ctx = self.fl_ctx
        self.log_info(fl_ctx, f"Start aggregation for round {self.for_round}")
        self.fire_event(AppEventType.BEFORE_AGGREGATION, fl_ctx)
        aggr_result = self.aggregator.aggregate(fl_ctx)
        self.fire_event_with_data(AppEventType.AFTER_AGGREGATION, fl_ctx, AppConstants.AGGREGATION_RESULT, aggr_result)
        self.log_info(fl_ctx, f"Finished aggregation for round {self.for_round}")

        mine_is_better = False
        if self.current_best_global_metric is not None:
            if (
                self.executor.best_metric is not None
                and self.metric_comparator.compare(self.executor.best_metric, self.current_best_global_metric) > 0
            ):
                mine_is_better = True
        elif self.executor.best_metric is not None:
            # no previous global best: any local metric wins
            mine_is_better = True

        if mine_is_better:
            self.log_info(
                fl_ctx, f"I got better metric {self.executor.best_metric} at round {self.executor.best_round}"
            )
            best_round = self.executor.best_round
            best_metric = self.executor.best_metric
            best_client = self.executor.me
        else:
            best_round = self.current_best_round
            best_metric = self.current_best_global_metric
            best_client = self.current_best_client

        self.log_info(fl_ctx, f"global best metric is {best_metric} from client {best_client} at round {best_round}")
        aggr_result.set_header(Constant.ROUND, best_round)
        aggr_result.set_header(Constant.METRIC, best_metric)
        aggr_result.set_header(Constant.CLIENT, best_client)
        return aggr_result

    def is_done(self):
        """Tell whether gathering for this round can stop.

        Returns:
            True if every trainer has replied, if the overall timeout has elapsed, or if
            the minimum number of responses was received and the extra wait time has
            passed; False otherwise.
        """
        if all(status.reply_time for status in self.trainer_statuses.values()):
            return True

        # timeout?
        now = time.time()
        if self.timeout and now - self.start_time > self.timeout:
            self.log_warning(self.fl_ctx, f"gatherer for round {self.for_round} timed out after {self.timeout} seconds")
            return True

        if (
            self.min_resps_received_time
            and now - self.min_resps_received_time > self.wait_time_after_min_resps_received
        ):
            # received min responses required and waited for long time
            self.log_info(
                self.fl_ctx,
                f"gatherer for round {self.for_round} exit after {self.wait_time_after_min_resps_received} seconds "
                f"since received minimum responses",
            )
            return True

        # Explicit negative answer instead of an implicit None.
        return False
[docs]
class SwarmClientController(ClientSideController):
def __init__(
    self,
    task_name_prefix=Constant.TN_PREFIX_SWARM,
    learn_task_name=AppConstants.TASK_TRAIN,
    persistor_id=AppConstants.DEFAULT_PERSISTOR_ID,
    shareable_generator_id=AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID,
    aggregator_id=AppConstants.DEFAULT_AGGREGATOR_ID,
    metric_comparator_id=None,
    learn_task_check_interval=Constant.LEARN_TASK_CHECK_INTERVAL,
    learn_task_abort_timeout=Constant.LEARN_TASK_ABORT_TIMEOUT,
    learn_task_ack_timeout=Constant.LEARN_TASK_ACK_TIMEOUT,
    learn_task_timeout=None,
    final_result_ack_timeout=Constant.FINAL_RESULT_ACK_TIMEOUT,
    min_responses_required: int = 1,
    wait_time_after_min_resps_received: float = 10.0,
    request_to_submit_result_max_wait=None,
    request_to_submit_result_msg_timeout=5.0,
    request_to_submit_result_interval: float = 1.0,
    max_concurrent_submissions: int = 1,
    memory_gc_rounds: int = 1,
    cuda_empty_cache: bool = False,
):
    """Constructor of a SwarmClientController object.

    Args:
        task_name_prefix: prefix of task names. All CCWF task names are prefixed with this.
        learn_task_name: name for the Learning Task (LT)
        persistor_id: ID of the persistor component
        shareable_generator_id: ID of the shareable generator component
        aggregator_id: ID of the aggregator
        metric_comparator_id: ID of metric comparator to be used for determining best model.
            If not specified, the default NumberMetricComparator is used.
        learn_task_check_interval: interval for checking incoming Learning Task (LT)
        learn_task_abort_timeout: time to wait for the LT to become stopped after aborting it
        learn_task_ack_timeout: timeout for sending the LT to other client(s)
        learn_task_timeout: max time allowed for a training task
        final_result_ack_timeout: timeout for sending final result to participating clients
        min_responses_required: minimum number of responses required for the aggregation
        wait_time_after_min_resps_received: how long to wait after min responses (but not all
            responses) are received.
        request_to_submit_result_max_wait: max amount of time to wait for the permission from the
            aggregation client. If the permission is not received within this period of time, the
            training result will not be submitted. If this value is not specified (None), then the
            training client will keep trying forever.
        request_to_submit_result_msg_timeout: the timeout for "submission request" message.
            Since submission req is a tiny message, this timeout value should be small.
        request_to_submit_result_interval: interval between requests to submit result.
        max_concurrent_submissions: max number of concurrent submissions allowed on the
            aggregation client.
        memory_gc_rounds: run gc.collect() + malloc_trim on the aggregator every N FL rounds.
            Defaults to 1 (every round) to match legacy behavior where gc.collect() was called
            unconditionally after each trainer submission. Set to 0 to disable.
        cuda_empty_cache: also call torch.cuda.empty_cache() during aggregator-side cleanup.
            In swarm learning the aggregator runs on the same client as the trainer, so GPU
            memory may be relevant. Defaults to False.

    Note that if the max_concurrent_submissions is set to 1, it practically means that all
    training results will be submitted to the aggregation client sequentially. This lowers the
    resource pressure on the aggr client, but makes the overall training process longer. The
    value of request_to_submit_result_max_wait, if specified, should be long enough to allow the
    aggr client sufficient time to process training results.
    """
    # --- argument validation (same order as the checks were always performed) ---
    required_strings = (
        ("learn_task_name", learn_task_name),
        ("persistor_id", persistor_id),
        ("shareable_generator_id", shareable_generator_id),
        ("aggregator_id", aggregator_id),
    )
    for arg_name, arg_value in required_strings:
        check_non_empty_str(arg_name, arg_value)

    required_positive_numbers = (
        ("request_to_submit_result_msg_timeout", request_to_submit_result_msg_timeout),
        ("request_to_submit_result_interval", request_to_submit_result_interval),
    )
    for arg_name, arg_value in required_positive_numbers:
        check_positive_number(arg_name, arg_value)

    check_positive_int("max_concurrent_submissions", max_concurrent_submissions)

    # Optional settings are only validated when a truthy value was supplied.
    if request_to_submit_result_max_wait:
        check_positive_number("request_to_submit_result_max_wait", request_to_submit_result_max_wait)
    if metric_comparator_id:
        check_non_empty_str("metric_comparator_id", metric_comparator_id)
    if learn_task_timeout:
        check_positive_number("learn_task_timeout", learn_task_timeout)

    check_positive_int("min_responses_required", min_responses_required)
    check_positive_number("wait_time_after_min_resps_received", wait_time_after_min_resps_received)

    super().__init__(
        task_name_prefix=task_name_prefix,
        learn_task_name=learn_task_name,
        persistor_id=persistor_id,
        shareable_generator_id=shareable_generator_id,
        learn_task_check_interval=learn_task_check_interval,
        learn_task_ack_timeout=learn_task_ack_timeout,
        learn_task_abort_timeout=learn_task_abort_timeout,
        final_result_ack_timeout=final_result_ack_timeout,
        allow_busy_task=True,
    )

    # --- task names derived from the prefix ---
    self.report_learn_result_task_name = make_task_name(task_name_prefix, Constant.BASENAME_REPORT_LEARN_RESULT)
    self.request_to_submit_learn_result_task_name = make_task_name(
        task_name_prefix, Constant.BASENAME_REQUEST_TO_SUBMIT_LEARN_RESULT
    )

    # --- configuration captured as-is ---
    self.metric_comparator_id = metric_comparator_id
    self.aggregator_id = aggregator_id
    self.max_concurrent_submissions = max_concurrent_submissions
    self.request_to_submit_result_max_wait = request_to_submit_result_max_wait
    self.request_to_submit_result_msg_timeout = request_to_submit_result_msg_timeout
    self.request_to_submit_result_interval = request_to_submit_result_interval
    self.learn_task_timeout = learn_task_timeout
    self.min_responses_required = min_responses_required
    self.wait_time_after_min_resps_received = wait_time_after_min_resps_received
    self.memory_gc_rounds = memory_gc_rounds
    self.cuda_empty_cache = cuda_empty_cache

    # --- runtime state, filled in later ---
    self.metric_comparator = None
    self.aggregator = None
    self.gatherer = None
    self.gatherer_waiter = threading.Event()
    self.trainers = None
    self.aggrs = None
    self.is_trainer = False
    self.is_aggr = False
    self.last_aggr_round_done = -1
    self._aggr_round_count = 0
[docs]
def process_config(self, fl_ctx: FLContext):
    """Work out this client's roles from the workflow config and register aux handlers.

    When the trainer list or the aggregator list is missing or empty in the config,
    the full client list is used instead.

    Args:
        fl_ctx: FL context (not used directly here).
    """
    everyone = self.get_config_prop(Constant.CLIENTS)

    self.trainers = self.get_config_prop(Constant.TRAIN_CLIENTS) or everyone
    self.is_trainer = self.me in self.trainers

    self.aggrs = self.get_config_prop(Constant.AGGR_CLIENTS) or everyone
    self.is_aggr = self.me in self.aggrs

    # Aux topics this client answers: "share your result" and "may I submit my result".
    handlers = (
        (self.topic_for_my_workflow(Constant.TOPIC_SHARE_RESULT), self._process_share_result),
        (self.request_to_submit_learn_result_task_name, self._process_submission_request),
    )
    for topic, handler in handlers:
        self.engine.register_aux_message_handler(topic=topic, message_handle_func=handler)
[docs]
def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
    """Dispatch a task: learn-result reports are handled here, everything else by the base class.

    Args:
        task_name: name of the task to execute.
        shareable: task data.
        fl_ctx: FL context of the task.
        abort_signal: signal used to abort the task.

    Returns:
        The reply produced by the selected handler.
    """
    if task_name != self.report_learn_result_task_name:
        return super().execute(task_name, shareable, fl_ctx, abort_signal)
    return self._process_learn_result(shareable, fl_ctx, abort_signal)
[docs]
def start_run(self, fl_ctx: FLContext):
    """Resolve the aggregator and metric comparator components, then start the gather monitor.

    A system panic is raised (and the monitor thread is not started) when a configured
    component is not of the expected type.

    Args:
        fl_ctx: FL context of the run.
    """
    super().start_run(fl_ctx)

    aggregator = self.engine.get_component(self.aggregator_id)
    self.aggregator = aggregator
    if not isinstance(aggregator, Aggregator):
        self.system_panic(
            f"aggregator {self.aggregator_id} must be an Aggregator but got {type(aggregator)}",
            fl_ctx,
        )
        return

    if not self.metric_comparator_id:
        # use default comparator
        self.metric_comparator = NumberMetricComparator()
    else:
        comparator = self.engine.get_component(self.metric_comparator_id)
        self.metric_comparator = comparator
        if not isinstance(comparator, MetricComparator):
            self.system_panic(
                f"metric comparator {self.metric_comparator_id} must be a MetricComparator "
                f"but got {type(comparator)}",
                fl_ctx,
            )
            return

    # Background daemon thread that watches the active gatherer.
    threading.Thread(target=self._monitor_gather, daemon=True).start()
    self.log_debug(fl_ctx, "started aggregator thread")
[docs]
def handle_event(self, event_type: str, fl_ctx: FLContext):
    """Record the local best model when GLOBAL_BEST_MODEL_AVAILABLE is fired locally.

    All other events are forwarded to the base class.

    Args:
        event_type: type of the fired event.
        fl_ctx: FL context carrying the event's properties.
    """
    if event_type != AppEventType.GLOBAL_BEST_MODEL_AVAILABLE:
        super().handle_event(event_type, fl_ctx)
        return

    origin = fl_ctx.get_prop(Constant.CLIENT)
    if origin and origin != self.me:
        # The best model comes from another client: this event was fired because that
        # client shared its best model with me at the end of the workflow. Nothing to record.
        return

    # Fired by the best model selector: it found the "local best global".
    self.best_metric = fl_ctx.get_prop(AppConstants.VALIDATION_RESULT)
    self.best_result = copy.deepcopy(fl_ctx.get_prop(AppConstants.GLOBAL_MODEL))
    self.log_info(fl_ctx, f"got GLOBAL_BEST_MODEL_AVAILABLE: best metric={self.best_metric}")

    round_num = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
    self.best_round = round_num
    self.update_status(last_round=round_num, action="better_aggregation")
[docs]
def start_workflow(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
    """Kick off the swarm workflow by scattering the first round's learn task.

    Args:
        shareable: initial task data to scatter.
        fl_ctx: FL context of the request.
        abort_signal: abort signal (not consulted here).

    Returns:
        An OK reply when the first round was scattered, EXECUTION_EXCEPTION otherwise.
    """
    all_clients = self.get_config_prop(Constant.CLIENTS)
    aggr_candidates = self.get_config_prop(Constant.AGGR_CLIENTS, [])
    trainers = self.get_config_prop(Constant.TRAIN_CLIENTS, [])
    self.log_info(
        fl_ctx,
        f"Starting Swarm Workflow on clients {all_clients}, "
        f"aggregator candidates {aggr_candidates}, trainers {trainers}",
    )

    first_round = self.get_config_prop(Constant.START_ROUND, 0)
    scattered = self._scatter(task_data=shareable, for_round=first_round, fl_ctx=fl_ctx)
    if scattered:
        self.log_info(fl_ctx, "Started Swarm Workflow")
        return make_reply(ReturnCode.OK)
    return make_reply(ReturnCode.EXECUTION_EXCEPTION)
def _scatter(self, task_data: Shareable, for_round: int, fl_ctx: FLContext) -> bool:
    """Pick an aggregation client for ``for_round`` and hand the learn task to all targets.

    The trainer and aggregator-candidate lists fall back to the full client list when
    they are missing or empty in the config, the same way process_config() resolves
    them. The local client (if targeted) gets the task queued locally; all other
    targets receive it via send_learn_task().

    Args:
        task_data: task data of the round; round / aggregator headers are stamped on it.
        for_round: the round number being started.
        fl_ctx: FL context used for logging and sending.

    Returns:
        True if the task was queued/sent to every target, False otherwise.
    """
    all_clients = self.get_config_prop(Constant.CLIENTS)
    clients = self.get_config_prop(Constant.TRAIN_CLIENTS) or all_clients
    aggregator_candidates = self.get_config_prop(Constant.AGGR_CLIENTS) or all_clients
    if not clients or not aggregator_candidates:
        # Without this guard random.choice() / the target list handling below would raise.
        self.log_error(
            fl_ctx,
            f"cannot scatter learn task for round {for_round}: "
            f"no train clients or aggregator candidates are configured",
        )
        return False

    # determine aggr client
    aggr = random.choice(aggregator_candidates)

    task_data.set_header(AppConstants.CURRENT_ROUND, for_round)
    task_data.add_cookie(AppConstants.CONTRIBUTION_ROUND, for_round)
    task_data.set_header(Constant.AGGREGATOR, aggr)

    # Stamp MSG_ROOT_TTL on the task data so the sender's ArrayDownloadable
    # download transaction stays alive for learn_task_timeout seconds — long
    # enough for the receiving subprocess to pull the global model. Without
    # this, via_downloader._create_downloader() falls back to
    # _MIN_DOWNLOAD_TIMEOUT (300 s inactivity floor), which is sufficient
    # for most GC pauses. task_controller.broadcast_and_wait() preserves a
    # pre-set MSG_ROOT_TTL instead of overwriting it with the short ACK timeout.
    if self.learn_task_timeout:
        task_data.set_header(ReservedHeaderKey.MSG_ROOT_TTL, float(self.learn_task_timeout))

    # Work on a private list so the configured client list is never mutated.
    targets = list(clients)
    if aggr not in targets:
        targets.append(aggr)

    # Handle self locally to avoid synchronous self-message deadlock.
    # When sending to self via broadcast_and_wait, the message is processed synchronously
    # on the same thread (via _send_direct_message in core_cell.py). If TensorStreamer
    # is enabled, this causes deadlock because wait_for_tensors() blocks the thread
    # waiting for streaming data that can't arrive on the blocked thread.
    # Instead, queue the task locally via set_learn_task (non-blocking, processed by _do_learn thread).
    should_queue_locally = self.me in targets
    remote_targets = [t for t in targets if t != self.me]

    # Queue locally FIRST with a deep copy. Deep copy is needed because:
    # 1. set_learn_task stores a reference, _do_learn processes it later on another thread
    # 2. send_learn_task may modify task_data in-place (e.g., TensorStreamer replacing tensors with REF IDs)
    # 3. Without deep copy, there's a race condition between modification and processing
    if should_queue_locally:
        self.log_info(fl_ctx, f"queuing learn task locally for round {for_round}")
        local_task_data = copy.deepcopy(task_data)
        if self._has_lazy_refs(local_task_data):
            local_task_data = self._resolve_lazy_refs(local_task_data, fl_ctx)
        if not self.set_learn_task(task_data=local_task_data, fl_ctx=fl_ctx):
            self.log_error(fl_ctx, f"failed to queue learn task locally for round {for_round}")
            return False

    # Then send to remote targets
    if remote_targets:
        self.log_info(
            fl_ctx,
            f"broadcasting learn task of round {for_round} to {remote_targets}; aggregation happens on {aggr}",
        )
        if not self.send_learn_task(targets=remote_targets, request=task_data, fl_ctx=fl_ctx):
            return False

    return True
def _monitor_gather(self):
    """Thread body: poll the active gatherer and finish the round once it is done.

    Runs until ``asked_to_stop`` is set. When the current gatherer reports it is
    done, the gatherer is detached (so new submissions for that round are treated
    as late), the waiter event is cleared, and _end_gather() is invoked. Errors
    from _end_gather() are logged and reported via update_status() without
    terminating the loop.
    """
    while not self.asked_to_stop:
        gatherer = self.gatherer
        if gatherer:
            assert isinstance(gatherer, Gatherer)
            if gatherer.is_done():
                self.last_aggr_round_done = gatherer.for_round
                self.gatherer = None
                self.gatherer_waiter.clear()
                try:
                    self._end_gather(gatherer)
                except Exception:
                    # Catch Exception (not a bare except) so interpreter-exit signals
                    # such as SystemExit are not swallowed by the monitor loop.
                    self.logger.error(f"exception ending gatherer: {secure_format_traceback()}")
                    self.update_status(action="aggregate", error=ReturnCode.EXECUTION_EXCEPTION)
        time.sleep(0.2)
def _end_gather(self, gatherer: Gatherer):
    """Aggregate the gathered results, then either finish the workflow or start the next round.

    Args:
        gatherer: the gatherer whose round has just completed.
    """
    fl_ctx = gatherer.fl_ctx
    try:
        aggr_result = gatherer.aggregate()
    except Exception:
        self.log_error(fl_ctx, f"exception in aggregation: {secure_format_traceback()}")
        self.update_status(action="aggregate", error=ReturnCode.EXECUTION_EXCEPTION)
        return

    if self._has_lazy_refs(aggr_result):
        self.system_panic(
            "LazyDownloadRef objects reached _end_gather() — "
            "_resolve_lazy_refs() was not called on the local-aggr path. "
            "This is a code bug.",
            fl_ctx,
        )
        return

    # aggr_result could be just weight diffs, not full weights!
    # need to call shareable_to_learnable to get full weights.
    self.log_debug(fl_ctx, f"aggr result: {aggr_result}")
    global_weights = self.shareable_generator.shareable_to_learnable(aggr_result, fl_ctx)
    self.record_last_result(fl_ctx, gatherer.for_round, global_weights)

    # are we done with training?
    num_rounds_done = gatherer.for_round - self.get_config_prop(Constant.START_ROUND, 0) + 1
    if num_rounds_done >= self.get_config_prop(AppConstants.NUM_ROUNDS):
        self.log_info(fl_ctx, f"Swarm Learning Done: number of rounds completed {num_rounds_done}")

        # determine the best global result
        self._distribute_final_results(aggr_result, fl_ctx)
    else:
        # continue next round
        next_round_data = self.shareable_generator.learnable_to_shareable(global_weights, fl_ctx)
        assert isinstance(next_round_data, Shareable)

        # carry the best-so-far information over to the next round
        best_round = aggr_result.get_header(Constant.ROUND)
        best_metric = aggr_result.get_header(Constant.METRIC)
        best_client = aggr_result.get_header(Constant.CLIENT)
        if best_client:
            next_round_data.set_header(Constant.ROUND, best_round)
            next_round_data.set_header(Constant.CLIENT, best_client)
            next_round_data.set_header(Constant.METRIC, best_metric)

        next_round = gatherer.for_round + 1
        if not self._scatter(next_round_data, next_round, fl_ctx):
            # Do not let a failed hand-over to the next round pass silently.
            self.log_error(fl_ctx, f"failed to scatter learn task for next round {next_round}")

    if self.memory_gc_rounds > 0:
        self._aggr_round_count += 1
        if self._aggr_round_count % self.memory_gc_rounds == 0:
            # imported lazily, only when a cleanup is actually due
            from nvflare.fuel.utils.memory_utils import cleanup_memory

            cleanup_memory(cuda_empty_cache=self.cuda_empty_cache)
            self.log_info(fl_ctx, f"Swarm aggregator memory cleanup at round {self._aggr_round_count}")
def _ask_to_share_best_result(self, client: str, metric, fl_ctx: FLContext):
    """Ask ``client``, which holds the best model, to distribute its result.

    Failures are only logged; nothing is returned.

    Args:
        client: name of the client with the best metric.
        metric: the best metric (used for logging only).
        fl_ctx: FL context used for sending the aux request.
    """
    self.log_info(fl_ctx, f"client {client} has the best metric {metric} - ask it to share result")

    share_topic = self.topic_for_my_workflow(Constant.TOPIC_SHARE_RESULT)
    resp = self.engine.send_aux_request(
        targets=[client],
        topic=share_topic,
        request=Shareable(),
        timeout=self.final_result_ack_timeout,
        fl_ctx=fl_ctx,
        secure=False,
    )
    assert isinstance(resp, dict)

    reply = resp.get(client)
    if not reply:
        self.log_error(fl_ctx, f"failed to ask client {client} to share final result")
    elif not isinstance(reply, Shareable):
        self.log_error(fl_ctx, f"client {client} failed to respond to share final result request")
    else:
        rc = reply.get_return_code()
        if rc != ReturnCode.OK:
            self.log_error(fl_ctx, f"client {client} failed to respond to share final result request: {rc}")
def _distribute_final_results(self, aggr_result: Shareable, fl_ctx: FLContext):
    """Distribute the best result (if any) and then the last result of the workflow.

    The best result is broadcast by this client when it owns it; otherwise the
    owning client is asked to share it. The last result is always broadcast.

    Args:
        aggr_result: final aggregation result; its CLIENT / METRIC headers identify the best result.
        fl_ctx: FL context used for logging and broadcasting.
    """
    best_client = aggr_result.get_header(Constant.CLIENT)
    best_metric = aggr_result.get_header(Constant.METRIC)

    if not best_client:
        self.log_info(fl_ctx, "No global best result!")
    elif best_client == self.me:
        # I have the best model
        self.log_info(fl_ctx, f"I have global best metric {best_metric}")
        self.broadcast_final_result(fl_ctx, ResultType.BEST, self.best_result, self.best_metric, self.best_round)
    else:
        try:
            self._ask_to_share_best_result(best_client, best_metric, fl_ctx)
        except Exception:
            # Best effort: a failure here must not prevent distributing the last result.
            self.log_error(
                fl_ctx, f"error asking client {best_client} to share best result {secure_format_traceback()}"
            )

    self.log_info(fl_ctx, "distributing last result")
    self.broadcast_final_result(fl_ctx, ResultType.LAST, self.last_result, round_num=self.last_round)
def _process_submission_request(self, topic: str, request: Shareable, fl_ctx: FLContext):
    """Aux handler: answer a trainer's request for permission to submit its result.

    Args:
        topic: topic of the aux message.
        request: the request; its CURRENT_ROUND header tells which round it is for.
        fl_ctx: FL context; the requester is identified from its peer context.

    Returns:
        A reply whose return code is the gatherer's decision, MODEL_UNRECOGNIZED for a
        late request, or SERVICE_UNAVAILABLE when no gatherer is set up yet.
    """
    peer_ctx = fl_ctx.get_peer_context()
    assert isinstance(peer_ctx, FLContext)
    requester = peer_ctx.get_identity_name()
    req_round = request.get_header(AppConstants.CURRENT_ROUND)
    self.log_debug(fl_ctx, f"got result submission request {topic} from {requester} for round {req_round}")

    gatherer = self.gatherer
    if gatherer:
        assert isinstance(gatherer, Gatherer)
        if gatherer.for_round != req_round:
            self.log_warning(
                fl_ctx,
                f"Got submission request from {requester} for round {req_round}, "
                f"but I'm waiting for round {gatherer.for_round}",
            )

        # check whether this client is permitted to submit result
        rc = gatherer.can_accept_submission(requester, request, fl_ctx)
        self.log_debug(fl_ctx, f"got permission from gatherer: {rc}")
        return make_reply(rc)

    # No gatherer: either a fast client asking before the gatherer exists,
    # or a late client asking after gathering has already finished.
    if req_round <= self.last_aggr_round_done:
        # late client case - drop the result
        self.log_debug(fl_ctx, f"reject from late {requester} for round {req_round}")
        return make_reply(ReturnCode.MODEL_UNRECOGNIZED)

    # case of fast client - ask to try again
    self.log_debug(
        fl_ctx, f"got submission result from {requester} for round {req_round} before gatherer setup"
    )
    return make_reply(ReturnCode.SERVICE_UNAVAILABLE)
@staticmethod
def _has_lazy_refs(obj) -> bool:
    """Tell whether ``obj`` is, or contains at any nesting depth, a LazyDownloadRef.

    Only dicts (their values), lists and tuples are searched recursively; any other
    object is treated as a leaf.
    """
    from nvflare.fuel.utils.fobs.decomposers.via_downloader import LazyDownloadRef

    if isinstance(obj, LazyDownloadRef):
        return True

    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, (list, tuple)):
        children = obj
    else:
        return False

    return any(SwarmClientController._has_lazy_refs(child) for child in children)
def _resolve_lazy_refs(self, result: Shareable, fl_ctx: FLContext) -> Shareable:
    """Replace LazyDownloadRef placeholders in ``result`` with the downloaded data.

    Background: when the subprocess sends its result via CellPipe with
    pass_through_on_send=True, Adapter.call() decodes the message with
    PASS_THROUGH=True and creates LazyDownloadRef objects (one per large tensor)
    instead of downloading the tensors. These placeholders carry the subprocess's
    fqcn and ref_id so that a downstream hop can download from the subprocess
    DownloadService on demand.

    For the remote aggregator path this download is triggered automatically by the
    FOBS encode/decode inside broadcast_and_wait(). For the local aggregation path
    (aggr == self.me) there is no encode/decode, so the download must be triggered
    explicitly here before the result reaches the gatherer.

    This is done with an FOBS round-trip:
        encode: LazyDownloadRefDecomposer.decompose() re-emits the original subprocess
            datum (fqcn + ref_id) as a TEXT datum — no CELL needed in the encode ctx.
        decode: process_datum() with PASS_THROUGH=False calls _download_from_remote_cell()
            which downloads real numpy arrays from the subprocess DownloadService.
            cell.get_fobs_context() supplies the CELL so the download can route to
            the subprocess via the cell network.

    Args:
        result: the shareable that may contain LazyDownloadRef objects.
        fl_ctx: FL context used to reach the engine (and through it the cell).

    Returns:
        The round-tripped shareable, or ``result`` unchanged when no engine or cell
        is available.
    """
    import nvflare.fuel.utils.fobs as fobs

    engine = fl_ctx.get_engine()

    # Not all engine implementations expose get_cell() (e.g. test stubs).
    # If there is no engine, the method is absent, or it returns None, skip the
    # download — there is no cell network available to route the download through.
    get_cell = getattr(engine, "get_cell", None) if engine else None
    cell = get_cell() if get_cell else None
    if not cell:
        return result

    payload = fobs.dumps(result)
    download_ctx = cell.get_fobs_context(props={fobs.FOBSContextKey.PASS_THROUGH: False})
    return fobs.loads(payload, fobs_ctx=download_ctx)
def _process_learn_result(self, request: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
    """Handle a training result reported by a trainer and pass it to the gatherer.

    Args:
        request: the reported training result; its CURRENT_ROUND header tells the round.
        fl_ctx: FL context; the submitting client is identified from its peer context.
        abort_signal: checked after waiting for the gatherer to be set up.

    Returns:
        The gatherer's reply; OK for a dropped late result; TASK_ABORTED if aborted
        while waiting; EXECUTION_EXCEPTION if no gatherer became available or an
        exception occurred.
    """
    try:
        peer_ctx = fl_ctx.get_peer_context()
        assert isinstance(peer_ctx, FLContext)
        client_name = peer_ctx.get_identity_name()
        current_round = request.get_header(AppConstants.CURRENT_ROUND)
        self.log_info(fl_ctx, f"got training result from {client_name} for round {current_round}")

        # Remote trainer results arrive on CellChannel.AUX_COMMUNICATION, which is
        # never in decode_pass_through_channels, so Adapter.call() uses PASS_THROUGH=False
        # and tensors are downloaded inline — request already contains real tensors here.
        # This check is a defensive guard: if a caller path omits the pre-resolve step,
        # lazy refs are still resolved before the gatherer rather than crashing downstream.
        if self._has_lazy_refs(request):
            request = self._resolve_lazy_refs(request, fl_ctx)

        # to be compatible with some widgets that rely on peer_ctx to get result
        peer_ctx.set_prop(FLContextKey.SHAREABLE, request, private=True, sticky=True)

        gatherer = self.gatherer
        if not gatherer:
            # this could be from a fast client before I even create the waiter;
            # or from a late client after I already finished gathering.
            if current_round <= self.last_aggr_round_done:
                # late client case - drop the result
                self.log_info(fl_ctx, f"dropped result from late {client_name} for round {current_round}")
                return make_reply(ReturnCode.OK)

            # case of fast client
            # wait until the gatherer is set up.
            self.log_info(fl_ctx, f"got result from {client_name} for round {current_round} before gatherer setup")
            self.gatherer_waiter.wait(self.learn_task_abort_timeout)

        if abort_signal.triggered:
            return make_reply(ReturnCode.TASK_ABORTED)

        # re-read: the gatherer may have been set up while we were waiting
        gatherer = self.gatherer
        if not gatherer:
            self.log_error(fl_ctx, f"Still no gatherer after {self.learn_task_abort_timeout} seconds")
            self.log_error(fl_ctx, f"Ignored result from {client_name} for round {current_round} since no gatherer")
            self.update_status(action="wait_for_gatherer", error=ReturnCode.EXECUTION_EXCEPTION)
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        assert isinstance(gatherer, Gatherer)
        if gatherer.for_round != current_round:
            # only a warning: the gatherer itself decides how to treat other rounds
            self.log_warning(
                fl_ctx,
                f"Got result from {client_name} for round {current_round}, "
                f"but I'm waiting for round {gatherer.for_round}",
            )
        return gatherer.gather(client_name, request, fl_ctx)
    except Exception:
        # Catch Exception (not a bare except) so interpreter-exit signals propagate.
        self.log_exception(fl_ctx, f"exception processing learn result: {secure_format_traceback()}")
        self.update_status(action="process_learn_result", error=ReturnCode.EXECUTION_EXCEPTION)
        return make_reply(ReturnCode.EXECUTION_EXCEPTION)
[docs]
def do_learn_task(self, name: str, task_data: Shareable, fl_ctx: FLContext, abort_signal: Signal):
    """Run one learning round on this client and hand the result to the aggregation client.

    The steps, in order:

    1. Read the current round and the aggregation client name from the headers of
       ``task_data``; a missing aggregation client is reported as an error and ends the call.
    2. Store ``name`` in the ``TASK_NAME`` header, make sure a ``GLOBAL_MODEL`` prop exists,
       convert ``task_data`` to a learnable via ``self.shareable_generator``, publish it as
       ``GLOBAL_MODEL`` together with ``CURRENT_ROUND``, and fire ``ROUND_STARTED``.
    3. If this client is the aggregation client (``self.me == aggr``), create the ``Gatherer``
       for the round and set ``self.gatherer_waiter``.
    4. If this client is a trainer (``self.is_trainer``): execute the learn task, repeatedly ask
       the aggregation client for permission to submit, then submit the result - locally when
       the aggregation client is this client, otherwise via ``broadcast_and_wait`` - and
       validate the reply.

    Args:
        name: task name; written to the ``TASK_NAME`` header of ``task_data``.
        task_data: the learn task payload; its headers carry the current round and the
            aggregation client name.
        fl_ctx: FL context used for props, events, logging and messaging.
        abort_signal: checked at the top of every permission-request iteration; also passed to
            ``execute_learn_task`` and, on the local path, to ``_process_learn_result``.

    Returns:
        None on every path. Outcomes are reported through ``self.update_status`` and the log.
    """
    # set status report of starting task
    current_round = task_data.get_header(AppConstants.CURRENT_ROUND)
    self.update_status(last_round=current_round, action="start_learn_task")

    aggr = task_data.get_header(Constant.AGGREGATOR)
    if not aggr:
        self.log_error(fl_ctx, f"missing aggregation client for round {current_round}")
        self.update_status(action="do_learn_task", error=ReturnCode.EXECUTION_EXCEPTION)
        return

    self.log_info(fl_ctx, f"Round {current_round} started.")
    task_data.set_header(FLContextKey.TASK_NAME, name)

    # Some shareable generators assume the base model (GLOBAL_MODEL) is always available, which is true for
    # server-controlled fed-avg. But this is not true for swarm learning.
    # To make these generators happy, we create an empty global model here if not present.
    base_model = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL)
    if not base_model:
        base_model = Learnable()
        fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, base_model, private=True, sticky=True)

    # If task_data contains LazyDownloadRef (receiver-side decode_pass_through
    # was active), resolve before shareable_to_learnable so GLOBAL_MODEL gets
    # real tensors — required by the WEIGHT_DIFF branch of _end_gather().
    # task_data itself keeps its refs intact for execute_learn_task() below so
    # the subprocess can download directly from the source DownloadService.
    task_data_for_model = (
        self._resolve_lazy_refs(task_data, fl_ctx) if self._has_lazy_refs(task_data) else task_data
    )
    global_weights = self.shareable_generator.shareable_to_learnable(task_data_for_model, fl_ctx)
    self.log_debug(fl_ctx, f"current global model: {global_weights}")

    # publish the round's model and round number before announcing the round start
    fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, global_weights, private=True, sticky=True)
    fl_ctx.set_prop(AppConstants.CURRENT_ROUND, current_round, private=True, sticky=True)
    self.fire_event(AppEventType.ROUND_STARTED, fl_ctx)

    if self.me == aggr:
        # set up the aggr waiter
        gatherer = self.gatherer
        if gatherer:
            # already waiting for aggregation - should never happen
            self.log_error(
                fl_ctx,
                f"logic error: got task for round {current_round} while gathering for round {gatherer.for_round}",
            )
            self.update_status(action="do_learn_task", error=ReturnCode.EXECUTION_EXCEPTION)
            return

        self.log_info(fl_ctx, f"setting up the gatherer for round {current_round}")
        self.gatherer = Gatherer(
            fl_ctx=fl_ctx,
            all_clients=self.get_config_prop(Constant.CLIENTS),
            metric_comparator=self.metric_comparator,
            trainers=self.trainers,
            for_round=current_round,
            timeout=self.learn_task_timeout,
            min_responses_required=self.min_responses_required,
            wait_time_after_min_resps_received=self.wait_time_after_min_resps_received,
            aggregator=self.aggregator,
            executor=self,
            task_data=task_data,
            max_concurrent_submissions=self.max_concurrent_submissions,
        )
        # signal the waiter only after self.gatherer has been assigned
        self.gatherer_waiter.set()

    # execute the task
    if self.is_trainer:
        # run the learn task; task_data is passed as received (lazy refs, if any, untouched)
        result = self.execute_learn_task(task_data, fl_ctx, abort_signal)
        rc = result.get_return_code(ReturnCode.OK)
        if rc != ReturnCode.OK:
            self.log_error(fl_ctx, f"learn executor failed: {rc}")
            self.update_status(action="learner_execution", error=rc)
            return

        # ask permission to submit result to the aggr client repeatedly until permitted
        self.log_info(fl_ctx, f"asking permission to submit result to the aggregation client {aggr}")
        submission_req = Shareable()
        submission_req.set_header(AppConstants.CURRENT_ROUND, current_round)
        req_start_time = time.time()
        engine = fl_ctx.get_engine()
        max_wait = self.request_to_submit_result_max_wait
        while True:
            if abort_signal.triggered:
                self.log_info(
                    fl_ctx, f"giving up result submission to {aggr} for round {current_round}: job aborted"
                )
                return

            # a falsy max_wait (None or 0) disables this deadline check
            if max_wait and time.time() - req_start_time > max_wait:
                self.log_error(
                    fl_ctx, f"giving up result submission to {aggr} for round {current_round} after {max_wait} secs"
                )
                return

            resp = engine.send_aux_request(
                targets=[aggr],
                topic=self.request_to_submit_learn_result_task_name,
                request=submission_req,
                timeout=self.request_to_submit_result_msg_timeout,
                fl_ctx=fl_ctx,
                secure=False,
            )
            self.log_debug(fl_ctx, f"got request_to_submit response from {aggr}: {resp}")

            # a missing or non-Shareable reply ends the attempt; it is not retried
            reply = resp.get(aggr)
            if not reply:
                self.log_error(
                    fl_ctx, f"failed to receive reply for submission request from aggregation client: {aggr}"
                )
                self.update_status(action="receive_permission_reply", error=ReturnCode.EXECUTION_EXCEPTION)
                return

            if not isinstance(reply, Shareable):
                self.log_error(
                    fl_ctx, f"bad reply from aggregation client {aggr}: expect Shareable but got {type(reply)}"
                )
                self.update_status(action="receive_permission_reply", error=ReturnCode.EXECUTION_EXCEPTION)
                return

            rc = reply.get_return_code(ReturnCode.OK)
            if rc == ReturnCode.OK:
                # permission granted
                time_taken = time.time() - req_start_time
                self.log_info(
                    fl_ctx,
                    f"got permission from {aggr} to submit round {current_round} result in {time_taken} secs",
                )
                break
            elif rc == ReturnCode.MODEL_UNRECOGNIZED:
                # aggr client doesn't want me to submit the result!
                self.log_info(fl_ctx, f"{aggr} does not want me to submit learn result!")
                self.update_status(action="receive_learn_result_reply", error=rc)
                return
            elif rc != ReturnCode.SERVICE_UNAVAILABLE:
                # any other code is logged as unexpected, then retried like SERVICE_UNAVAILABLE
                self.log_warning(fl_ctx, f"got unexpected RC {rc} for submission request from {aggr}")

            # aggr client is not ready - need try again
            time.sleep(self.request_to_submit_result_interval)

        # send the result to the aggr
        if aggr == self.me:
            # Avoid synchronous self-message path through CoreCell._send_direct_message.
            self.log_info(fl_ctx, "submitting training result locally (aggregation client is self)")
            # The subprocess result arrives at CJ as LazyDownloadRef (subprocess-side
            # CellPipe has pass_through_on_send=True). Resolve before local aggregation.
            # The remote path goes through AUX_COMMUNICATION (not the pipe channel), so
            # Adapter.call() downloads tensors inline and _process_learn_result() receives
            # real tensors; its own _has_lazy_refs guard is purely defensive.
            result = self._resolve_lazy_refs(result, fl_ctx)
            engine = fl_ctx.get_engine()
            # NOTE(review): the clone gets a fresh peer context, presumably because
            # _process_learn_result() reads the sender from the peer context - confirm.
            local_fl_ctx = fl_ctx.clone()
            local_fl_ctx.set_peer_context(engine.new_context())
            reply = self._process_learn_result(result, local_fl_ctx, abort_signal)
        else:
            self.log_info(fl_ctx, f"sending training result to aggregation client {aggr}")
            task = Task(
                name=self.report_learn_result_task_name,
                data=result,
                timeout=int(self.learn_task_ack_timeout),
                secure=self.is_task_secure(fl_ctx),
            )
            resp = self.broadcast_and_wait(
                task=task,
                targets=[aggr],
                min_responses=1,
                fl_ctx=fl_ctx,
            )
            reply = resp.get(aggr)

        # the same reply validation applies to both the local and the remote path
        if not reply:
            self.log_error(fl_ctx, f"failed to receive reply from aggregation client: {aggr}")
            self.update_status(action="receive_learn_result_reply", error=ReturnCode.EXECUTION_EXCEPTION)
            return

        if not isinstance(reply, Shareable):
            self.log_error(
                fl_ctx, f"bad reply from aggregation client {aggr}: expect Shareable but got {type(reply)}"
            )
            self.update_status(action="receive_learn_result_reply", error=ReturnCode.EXECUTION_EXCEPTION)
            return

        rc = reply.get_return_code(ReturnCode.OK)
        if rc != ReturnCode.OK:
            self.log_error(fl_ctx, f"bad return code from aggregation client {aggr}: {rc}")
            self.update_status(action="receive_learn_result_reply", error=ReturnCode.EXECUTION_EXCEPTION)
            return

        self.log_info(fl_ctx, f"Finished round {current_round}")
        # update status
        self.update_status(last_round=current_round, action="finished_learn_task")
def _process_share_result(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
    """Handle a peer's request to share this client's best result.

    When a best result is held, the status is updated and the result is broadcast as
    ``ResultType.BEST`` along with its metric and round number. Otherwise the request
    is rejected.

    Args:
        topic: topic of the incoming request (not used here).
        request: the incoming request (not used here).
        fl_ctx: FL context; its peer context identifies the requesting client.

    Returns:
        A reply with ``ReturnCode.OK`` after broadcasting, or
        ``ReturnCode.BAD_REQUEST_DATA`` when no best result is available.
    """
    requester_ctx = fl_ctx.get_peer_context()
    assert isinstance(requester_ctx, FLContext)
    client_name = requester_ctx.get_identity_name()

    if self.best_result:
        # report progress first, then share the best result with its metric and round
        self.update_status(action="start_share_result_request_process")
        self.broadcast_final_result(
            fl_ctx,
            ResultType.BEST,
            self.best_result,
            metric=self.best_metric,
            round_num=self.best_round,
        )
        return make_reply(ReturnCode.OK)

    # nothing to share: tell the requester the request cannot be served
    self.log_error(
        fl_ctx, f"got request from {client_name} to share my best result, but I don't have best result"
    )
    return make_reply(ReturnCode.BAD_REQUEST_DATA)