Source code for nvflare.app_opt.p2p.controllers.dist_opt_controller

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.apis.controller_spec import Task
from nvflare.apis.dxo import DXO, DataKind
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.signal import Signal
from nvflare.app_opt.p2p.types import Config


class DistOptController(Controller):
    """Controller for running a peer-to-peer distributed optimization algorithm on a network.

    This controller manages the execution of a distributed optimization algorithm by
    configuring each client with its neighbors and initiating the algorithm execution
    across the network.

    Args:
        config (Config): The P2P network configuration containing node and neighbor information.
        *args: Positional arguments forwarded unchanged to ``Controller.__init__``.
        **kwargs: Keyword arguments forwarded unchanged to ``Controller.__init__``.
    """

    def __init__(
        self,
        config: Config,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.config = config

    @staticmethod
    def _is_aborted(abort_signal) -> bool:
        """Return True when an abort signal was supplied and has been triggered."""
        return abort_signal is not None and abort_signal.triggered

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        """Configure every node with its neighbors, then start the algorithm on all nodes.

        The flow has two phases:

        1. For each node in ``self.config.network.nodes``, a ``"config"`` task carrying
           the attribute dict of each of its neighbors is sent to that node only, and
           the controller waits for it to finish before moving to the next node.
        2. A single ``"run_algorithm"`` task carrying a copy of ``self.config.extra``
           is broadcast to all nodes.

        Args:
            abort_signal (Signal): Signal used to stop the flow early. It is checked
                before every task is scheduled and forwarded to the blocking
                send/broadcast calls so they can stop waiting once it is triggered.
            fl_ctx (FLContext): The FL context passed through to the task calls.
        """
        nodes = self.config.network.nodes

        # Send network config (aka neighbors info) to each client
        for node in nodes:
            # Do not schedule further config tasks once the run has been aborted.
            if self._is_aborted(abort_signal):
                return

            task = Task(
                name="config",
                data=DXO(
                    data_kind=DataKind.APP_DEFINED,
                    data={"neighbors": [vars(n) for n in node.neighbors]},
                ).to_shareable(),
            )
            self.send_and_wait(
                task=task,
                targets=[node.id],
                fl_ctx=fl_ctx,
                abort_signal=abort_signal,
            )

        # An abort during the last config task must not start the algorithm.
        if self._is_aborted(abort_signal):
            return

        # Run algorithm (with extra params if any passed as data)
        targets = [node.id for node in nodes]
        self.broadcast_and_wait(
            task=Task(
                name="run_algorithm",
                data=DXO(
                    data_kind=DataKind.APP_DEFINED,
                    # Shallow copy so the task payload is a separate dict from the config.
                    data=dict(self.config.extra.items()),
                ).to_shareable(),
            ),
            targets=targets,
            # NOTE(review): in NVFlare's controller API a value of 0 is documented as
            # "wait for responses from every target" - confirm against the installed version.
            min_responses=0,
            fl_ctx=fl_ctx,
            abort_signal=abort_signal,
        )

    def start_controller(self, fl_ctx: FLContext):
        """Log that the controller has started.

        Args:
            fl_ctx (FLContext): The FL context used for logging.
        """
        # Use the real class name so the message stays correct for this class and subclasses.
        self.log_info(fl_ctx, f"{type(self).__name__} started")

    def stop_controller(self, fl_ctx: FLContext):
        """Log that the controller has stopped.

        Args:
            fl_ctx (FLContext): The FL context used for logging.
        """
        self.log_info(fl_ctx, f"{type(self).__name__} stopped")