Source code for nvflare.private.fed.client.command_agent

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx
from nvflare.fuel.f3.cellnet.core_cell import Message as CellMessage
from nvflare.fuel.f3.cellnet.core_cell import MessageHeaderKey, ReturnCode
from nvflare.fuel.f3.cellnet.core_cell import make_reply as make_cellnet_reply
from nvflare.private.defs import CellChannel, new_cell_message

from .admin_commands import AdminCommands


class CommandAgent(object):
    """Serves admin commands and aux messages that arrive on the client's cell.

    ``start`` registers request callbacks on ``federated_client.cell`` for the
    ``CellChannel.CLIENT_COMMAND`` and ``CellChannel.AUX_COMMUNICATION``
    channels. Command requests are routed to the matching ``AdminCommands``
    entry; aux requests are routed to ``engine.dispatch``.
    """

    def __init__(self, federated_client) -> None:
        """To init the CommandAgent.

        Args:
            federated_client: FL client object
        """
        self.federated_client = federated_client
        self.thread = None
        self.asked_to_stop = False
        self.commands = AdminCommands.commands
        # Assigned in start(); the request callbacks use it to open new contexts.
        self.engine = None
        self.logger = logging.getLogger(self.__class__.__name__)

    def start(self, fl_ctx: FLContext):
        """Capture the engine from ``fl_ctx`` and register the cell callbacks.

        Args:
            fl_ctx: context whose engine (``fl_ctx.get_engine()``) serves requests.
        """
        self.engine = fl_ctx.get_engine()
        self.register_cell_cb()

    def register_cell_cb(self):
        """Register request callbacks for every topic on both agent channels."""
        self.federated_client.cell.register_request_cb(
            channel=CellChannel.CLIENT_COMMAND,
            topic="*",
            cb=self.execute_command,
        )
        self.federated_client.cell.register_request_cb(
            channel=CellChannel.AUX_COMMUNICATION,
            topic="*",
            cb=self.aux_communication,
        )

    @staticmethod
    def _to_cell_message(reply) -> CellMessage:
        """Wrap ``reply`` in a cell message.

        The RETURN_CODE header is set to OK only when ``reply`` is not None;
        a None reply yields a message with no payload and no return code.
        """
        if reply is None:
            return new_cell_message({}, None)
        return_message = new_cell_message({}, reply)
        return_message.set_header(MessageHeaderKey.RETURN_CODE, ReturnCode.OK)
        return return_message

    def execute_command(self, request: CellMessage) -> CellMessage:
        """Run the admin command named by the request's TOPIC header.

        Args:
            request: incoming cell message; its payload is passed as ``data``
                to the command's ``process``.

        Returns:
            The wrapped command reply, or an INVALID_REQUEST reply when no
            command is registered under the topic.
        """
        assert isinstance(request, CellMessage), "request must be CellMessage but got {}".format(type(request))

        command_name = request.get_header(MessageHeaderKey.TOPIC)
        data = request.payload

        command = AdminCommands.get_command(command_name)
        if command:
            with self.engine.new_context() as new_fl_ctx:
                reply = command.process(data=data, fl_ctx=new_fl_ctx)
                return self._to_cell_message(reply)
        return make_cellnet_reply(ReturnCode.INVALID_REQUEST, "", None)

    def aux_communication(self, request: CellMessage) -> CellMessage:
        """Dispatch an aux request to the engine and wrap its reply.

        Args:
            request: incoming cell message; its payload is the request passed to
                ``engine.dispatch`` and its TOPIC header selects the handler.

        Returns:
            The wrapped reply. A non-None reply carries this side's peer context
            under ``FLContextKey.PEER_CONTEXT``.
        """
        assert isinstance(request, CellMessage), "request must be CellMessage but got {}".format(type(request))
        shareable = request.payload

        with self.engine.new_context() as fl_ctx:
            topic = request.get_header(MessageHeaderKey.TOPIC)
            reply = self.engine.dispatch(topic=topic, request=shareable, fl_ctx=fl_ctx)

            if reply is not None:
                # Previously set_header ran unconditionally, so a None reply raised
                # AttributeError before the empty-reply branch could be reached.
                shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
                reply.set_header(key=FLContextKey.PEER_CONTEXT, value=shared_fl_ctx)

            return self._to_cell_message(reply)

    def shutdown(self):
        """Flag the agent as asked to stop."""
        self.asked_to_stop = True