# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from nvflare.apis.fl_constant import FLContextKey, ServerCommandKey
from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.core_cell import MessageHeaderKey, ReturnCode, make_reply
from nvflare.fuel.f3.message import Message as CellMessage
from nvflare.private.defs import CellChannel, CellMessageHeaderKeys, new_cell_message
from .server_commands import ServerCommands
class ServerCommandAgent(object):
    """Handles server-command and aux-communication requests arriving on a cell.

    ``start()`` registers ``execute_command`` on the ``SERVER_COMMAND`` channel and
    ``aux_communicate`` on the ``AUX_COMMUNICATION`` channel, both for every topic.
    """

    def __init__(self, engine, cell: Cell) -> None:
        """To init the CommandAgent.

        Args:
            engine: the server engine. It is used to create FL contexts
                (``engine.new_context()``) and to reach ``engine.server`` (client
                manager, ``authentication_check``, ``server_state``).
            cell: the cell on which the request callbacks are registered by ``start()``.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        # Set to True by shutdown(); not read anywhere in this class.
        self.asked_to_stop = False
        self.engine = engine
        self.cell = cell

    def start(self):
        """Register the request callbacks on the cell.

        ``execute_command`` handles all topics of ``CellChannel.SERVER_COMMAND``.
        ``aux_communicate`` handles all topics of ``CellChannel.AUX_COMMUNICATION``.
        """
        self.cell.register_request_cb(
            channel=CellChannel.SERVER_COMMAND,
            topic="*",
            cb=self.execute_command,
        )
        self.cell.register_request_cb(
            channel=CellChannel.AUX_COMMUNICATION,
            topic="*",
            cb=self.aux_communicate,
        )
        self.logger.info(f"ServerCommandAgent cell register_request_cb: {self.cell.get_fqcn()}")

    def execute_command(self, request: CellMessage) -> CellMessage:
        """Execute the server command named by the request's TOPIC header.

        If the request carries a TOKEN header that maps to a known client, that
        client is attached to the payload under ``ServerCommandKey.FL_CLIENT``.
        Commands listed in ``ServerCommands.client_request_commands_names`` require
        such a client and must also pass the server's authentication check.

        Args:
            request: the incoming cell message. Its payload is passed to the command.

        Returns:
            A cell message wrapping the command's reply with return code OK.
            Otherwise, an error reply: INVALID_REQUEST (unknown command),
            AUTHENTICATION_ERROR (missing client or failed check), or
            PROCESS_EXCEPTION (the command returned None).

        Raises:
            RuntimeError: if ``request`` is not a CellMessage.
        """
        if not isinstance(request, CellMessage):
            raise RuntimeError("request must be CellMessage but got {}".format(type(request)))

        command_name = request.get_header(MessageHeaderKey.TOPIC)
        data = request.payload

        token = request.get_header(CellMessageHeaderKeys.TOKEN, None)
        client = self._get_client(token) if token else None
        if client:
            # Expose the requesting client to the command through the payload headers.
            data.set_header(ServerCommandKey.FL_CLIENT, client)

        command = ServerCommands.get_command(command_name)
        if not command:
            return make_reply(ReturnCode.INVALID_REQUEST, "No server command found", None)

        is_client_request = command_name in ServerCommands.client_request_commands_names
        if is_client_request and not client:
            return make_reply(
                ReturnCode.AUTHENTICATION_ERROR,
                "Request from client: missing client token",
                None,
            )

        with self.engine.new_context() as new_fl_ctx:
            if is_client_request:
                state_check = command.get_state_check(new_fl_ctx)
                error = self.engine.server.authentication_check(request, state_check)
                if error:
                    return make_reply(ReturnCode.AUTHENTICATION_ERROR, error, None)

            reply = command.process(data=data, fl_ctx=new_fl_ctx)
            if reply is not None:
                return_message = new_cell_message({}, reply)
                return_message.set_header(MessageHeaderKey.RETURN_CODE, ReturnCode.OK)
            else:
                return_message = make_reply(ReturnCode.PROCESS_EXCEPTION, "No process results", None)
            return return_message

    def _get_client(self, token):
        """Look up a client by token in the server's client manager.

        Returns the entry in ``engine.server.client_manager.clients`` for ``token``,
        or None if there is none.
        """
        clients = self.engine.server.client_manager.clients
        return clients.get(token)

    def aux_communicate(self, request: CellMessage) -> CellMessage:
        """Dispatch an aux-communication request to the engine.

        The request must first pass the server's authentication check against the
        current server state. The engine's reply is then tagged with a peer FL
        context (``FLContextKey.PEER_CONTEXT``) and wrapped in a cell message.

        Args:
            request: the incoming cell message. Its TOPIC header selects the handler,
                and its payload is the dispatched request.

        Returns:
            An AUTHENTICATION_ERROR reply if the check fails.
            A cell message wrapping the reply with return code OK if dispatch
            returned a reply.
            A cell message with a None payload if dispatch returned None.

        Raises:
            RuntimeError: if ``request`` is not a CellMessage.
        """
        # Raise instead of assert so the check is not stripped under ``python -O``;
        # this matches execute_command.
        if not isinstance(request, CellMessage):
            raise RuntimeError("request must be CellMessage but got {}".format(type(request)))

        data = request.payload
        topic = request.get_header(MessageHeaderKey.TOPIC)
        with self.engine.new_context() as fl_ctx:
            server_state = self.engine.server.server_state
            state_check = server_state.aux_communicate(fl_ctx)
            error = self.engine.server.authentication_check(request, state_check)
            if error:
                # Reject the request. The error reply must be returned; otherwise
                # the unauthenticated request would still be dispatched.
                return make_reply(ReturnCode.AUTHENTICATION_ERROR, error, None)

            engine = fl_ctx.get_engine()
            reply = engine.dispatch(topic=topic, request=data, fl_ctx=fl_ctx)
            if reply is None:
                return new_cell_message({}, None)

            # Only a non-None reply can carry the peer-context header.
            self.logger.debug("Before gen_new_peer_ctx")
            shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
            self.logger.debug("After gen_new_peer_ctx")
            reply.set_header(key=FLContextKey.PEER_CONTEXT, value=shared_fl_ctx)

            return_message = new_cell_message({}, reply)
            return_message.set_header(MessageHeaderKey.RETURN_CODE, ReturnCode.OK)
            return return_message

    def shutdown(self):
        """Mark the agent as asked to stop (sets ``asked_to_stop``)."""
        self.asked_to_stop = True