Source code for nvflare.private.fed.cmi

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply
from nvflare.fuel.f3.cellnet.core_cell import FQCN, CoreCell, Message, MessageHeaderKey
from nvflare.fuel.f3.cellnet.core_cell import ReturnCode as CellReturnCode
from nvflare.private.defs import CellMessageHeaderKeys


class CellMessageInterface(FLComponent, ABC):
    """Base class that carries FL context information on cell messages.

    On construction it registers wildcard (channel="*", topic="*") filters on the
    engine's cell:

    - outgoing requests and replies get the public props of the FLContext attached
      to the message (plus ssid / project name / client name / client token when
      present in that context) copied into message headers;
    - incoming requests and replies get a peer FLContext rebuilt from the
      peer-props header; incoming requests additionally get a fresh local
      FLContext attached to the message.

    Subclasses must implement ``send_to_cell``.
    """

    # Names of the message headers written / read by the filters.
    HEADER_KEY_PEER_PROPS = "cmi.peer_props"
    HEADER_JOB_ID = "cmi.job_id"
    HEADER_PROJECT_NAME = "cmi.project"
    HEADER_SSID = "cmi.ssid"
    HEADER_CLIENT_TOKEN = "cmi.client_token"
    HEADER_CLIENT_NAME = "cmi.client_name"

    # Names of the (local, non-header) message props used by the filters.
    PROP_KEY_CLIENT = "cmi.client"
    PROP_KEY_FL_CTX = "cmi.fl_ctx"
    PROP_KEY_PEER_CTX = "cmi.peer_ctx"

    # Lookup used by _convert_return_code; codes not listed here fall back to ReturnCode.ERROR.
    RC_TABLE = {
        CellReturnCode.TIMEOUT: ReturnCode.COMMUNICATION_ERROR,
        CellReturnCode.COMM_ERROR: ReturnCode.COMMUNICATION_ERROR,
        CellReturnCode.PROCESS_EXCEPTION: ReturnCode.EXECUTION_EXCEPTION,
        CellReturnCode.ABORT_RUN: CellReturnCode.ABORT_RUN,
        CellReturnCode.INVALID_REQUEST: CellReturnCode.INVALID_REQUEST,
        CellReturnCode.INVALID_SESSION: CellReturnCode.INVALID_SESSION,
        CellReturnCode.AUTHENTICATION_ERROR: CellReturnCode.UNAUTHENTICATED,
        CellReturnCode.SERVICE_UNAVAILABLE: CellReturnCode.SERVICE_UNAVAILABLE,
    }

    def __init__(
        self,
        engine,
    ):
        """Store the engine and its cell, and install the message filters on the cell.

        Args:
            engine: object providing ``get_cell()`` and ``new_context()``.
        """
        FLComponent.__init__(self)
        self.engine = engine
        self.cell = engine.get_cell()
        self.ready = False

        # (registration method, callback) pairs, applied in this order for all channels/topics.
        registrations = (
            (self.cell.add_incoming_request_filter, self._filter_incoming_request),
            (self.cell.add_outgoing_reply_filter, self._filter_outgoing_message),
            (self.cell.add_outgoing_request_filter, self._filter_outgoing_message),
            (self.cell.add_incoming_reply_filter, self._filter_incoming_message),
        )
        for register, callback in registrations:
            register(channel="*", topic="*", cb=callback)

    def new_cmi_message(self, fl_ctx: FLContext, headers=None, payload=None):
        """Create a cell Message that has ``fl_ctx`` attached as a message prop.

        Args:
            fl_ctx: FL context to attach under ``PROP_KEY_FL_CTX``.
            headers: headers passed to the Message constructor.
            payload: payload passed to the Message constructor.

        Returns:
            The newly created Message.
        """
        cmi_msg = Message(headers, payload)
        cmi_msg.set_prop(self.PROP_KEY_FL_CTX, fl_ctx)
        return cmi_msg

    def _filter_incoming_message(self, message: Message):
        """Rebuild the peer context from the peer-props header of an incoming message."""
        peer_props = message.get_header(self.HEADER_KEY_PEER_PROPS)
        if peer_props:
            message.set_prop(self.PROP_KEY_PEER_CTX, self._make_peer_ctx(peer_props))

        # A Shareable payload also receives the peer props directly.
        payload = message.payload
        if isinstance(payload, Shareable) and peer_props:
            payload.set_peer_props(peer_props)

    def _filter_incoming_request(self, message: Message):
        """Attach a new local FLContext (linked to the peer context, if any) to an incoming request."""
        self._filter_incoming_message(message)

        local_ctx = self.engine.new_context()
        remote_ctx = message.get_prop(self.PROP_KEY_PEER_CTX)
        assert isinstance(local_ctx, FLContext)
        if remote_ctx:
            local_ctx.set_peer_context(remote_ctx)
        message.set_prop(self.PROP_KEY_FL_CTX, local_ctx)

    def _filter_outgoing_message(self, message: Message):
        """Copy public props and identity props of the attached FLContext into message headers."""
        fl_ctx = message.get_prop(self.PROP_KEY_FL_CTX)
        if not fl_ctx:
            # Message was not created with an FL context: nothing to copy.
            return

        assert isinstance(fl_ctx, FLContext)
        message.set_header(self.HEADER_KEY_PEER_PROPS, fl_ctx.get_all_public_props())

        # (header name, context prop key) pairs; a header is set only when the prop is truthy.
        optional_headers = (
            (self.HEADER_SSID, CellMessageHeaderKeys.SSID),
            (self.HEADER_PROJECT_NAME, CellMessageHeaderKeys.PROJECT_NAME),
            (self.HEADER_CLIENT_NAME, FLContextKey.CLIENT_NAME),
            (self.HEADER_CLIENT_TOKEN, CellMessageHeaderKeys.TOKEN),
        )
        for header_name, prop_key in optional_headers:
            value = fl_ctx.get_prop(prop_key)
            if value:
                message.set_header(header_name, value)

    @staticmethod
    def _make_peer_ctx(props: dict) -> FLContext:
        """Return a new FLContext whose public props are ``props``."""
        peer_ctx = FLContext()
        peer_ctx.set_public_props(props)
        return peer_ctx

    @staticmethod
    def _convert_return_code(rc: CellReturnCode):
        """Translate a cell return code through RC_TABLE, defaulting to ReturnCode.ERROR."""
        table = CellMessageInterface.RC_TABLE
        return table.get(rc, ReturnCode.ERROR)

    @abstractmethod
    def send_to_cell(
        self,
        targets: [],
        channel: str,
        topic: str,
        request: Shareable,
        timeout: float,
        fl_ctx: FLContext,
        bulk_send=False,
    ) -> dict:
        """Send a request to the cells of the target sites; must be implemented by subclasses.

        Args:
            targets: the sites the request is sent to
            channel: channel of the request
            topic: topic of the request
            request: the request to send
            timeout: how long to wait for replies
            fl_ctx: the FL context
            bulk_send: whether to bulk send the request

        Returns:
            A dict of replies.
        """
class JobCellMessenger(CellMessageInterface):
    """Cell messenger bound to a single job.

    In addition to the filters installed by ``CellMessageInterface``, it registers
    wildcard filters that stamp every outgoing request/reply with this job's id and
    log an error when an incoming request/reply carries a different job id.
    """

    def __init__(self, engine, job_id: str):
        """Install the job-id filters on the engine's cell.

        Args:
            engine: engine passed through to ``CellMessageInterface``.
            job_id: id of the job this messenger belongs to.
        """
        super().__init__(engine)
        self.job_id = job_id
        self.cell.add_incoming_request_filter(channel="*", topic="*", cb=self._filter_incoming)
        self.cell.add_incoming_reply_filter(channel="*", topic="*", cb=self._filter_incoming)
        self.cell.add_outgoing_request_filter(channel="*", topic="*", cb=self._filter_outgoing)
        self.cell.add_outgoing_reply_filter(channel="*", topic="*", cb=self._filter_outgoing)

    def _filter_incoming(self, message: Message):
        """Log an error when an incoming message names a job id other than ours.

        A message without a job-id header is accepted silently; this filter only logs.
        """
        job_id = message.get_header(self.HEADER_JOB_ID)
        if job_id and job_id != self.job_id:
            self.logger.error(f"received job id {job_id} != my job id {self.job_id}")

    def _filter_outgoing(self, message: Message):
        """Stamp an outgoing message with this messenger's job id."""
        message.set_header(self.HEADER_JOB_ID, self.job_id)

    def send_to_cell(
        self,
        targets: [],
        channel: str,
        topic: str,
        request: Shareable,
        timeout: float,
        fl_ctx: FLContext,
        bulk_send=False,
        optional=False,
    ) -> dict:
        """Send request to the job cells of other target sites.

        The target cell of each site is addressed as ``FQCN.join([site_name, job_id])``
        where ``job_id`` is taken from ``fl_ctx``. Duplicate target names are sent to once.

        Args:
            targets (list): list of client names that the request will be sent to
            channel (str): channel of the request
            topic (str): topic of the request
            request (Shareable): request; its TOPIC header is set to ``topic``
            timeout (float): how long to wait for result. Must be a float; 0.0 means fire-and-forget
            fl_ctx (FLContext): the FL context
            bulk_send: whether to bulk send this request (only applies in the fire-and-forget situation)
            optional: whether the request is optional

        Returns:
            A dict of Shareables keyed by client name (the root of the replying cell's FQCN).
            Empty for fire-and-forget, or when no replies were received.

        Raises:
            ValueError: if request is not a Shareable, targets is empty, a target is not a str,
                topic is empty, or timeout is negative.
            TypeError: if targets is not a list, topic is not a str, timeout is not a float,
                or fl_ctx is not an FLContext.
        """
        if not isinstance(request, Shareable):
            raise ValueError(f"invalid request type: expect Shareable but got {type(request)}")

        if not targets:
            raise ValueError("targets must be specified")

        if not isinstance(targets, list):
            raise TypeError(f"targets must be a list of str, but got {type(targets)}")

        if not isinstance(topic, str):
            raise TypeError(f"invalid topic '{topic}': expects str but got {type(topic)}")

        if not topic:
            raise ValueError("invalid topic: must not be empty")

        if not isinstance(timeout, float):
            raise TypeError(f"invalid timeout: expects float but got {type(timeout)}")

        if timeout < 0:
            raise ValueError(f"invalid timeout value {timeout}: must >= 0.0")

        if not isinstance(fl_ctx, FLContext):
            raise TypeError(f"invalid fl_ctx: expects FLContext but got {type(fl_ctx)}")

        request.set_header(ReservedHeaderKey.TOPIC, topic)
        job_id = fl_ctx.get_job_id()
        cell = self.engine.get_cell()
        assert isinstance(cell, CoreCell)

        for t in targets:
            if not isinstance(t, str):
                raise ValueError(f"invalid target name {t}: expect str but got {type(t)}")

        # dict.fromkeys drops duplicate names while keeping first-seen order.
        target_fqcns = [FQCN.join([name, job_id]) for name in dict.fromkeys(targets)]

        cell_msg = self.new_cmi_message(fl_ctx, payload=request)

        if timeout > 0:
            cell_replies = cell.broadcast_request(
                channel=channel, topic=topic, request=cell_msg, targets=target_fqcns, timeout=timeout, optional=optional
            )

            replies = {}
            if cell_replies:
                for fqcn, reply_msg in cell_replies.items():
                    assert isinstance(reply_msg, Message)
                    # The default must come from the same (cell) namespace it is compared
                    # against, so that a reply without a return-code header counts as OK.
                    rc = reply_msg.get_header(MessageHeaderKey.RETURN_CODE, CellReturnCode.OK)
                    client_name = FQCN.get_root(fqcn)
                    if rc == CellReturnCode.OK:
                        result = reply_msg.payload
                        if not isinstance(result, Shareable):
                            self.logger.error(f"reply of {channel}:{topic} must be Shareable but got {type(result)}")
                            result = make_reply(ReturnCode.ERROR)
                        replies[client_name] = result
                    else:
                        # Translate the cell-level failure into an FL-level error reply.
                        replies[client_name] = make_reply(self._convert_return_code(rc))
            return replies

        # timeout == 0.0: fire-and-forget, no replies are collected.
        if bulk_send:
            cell.queue_message(channel=channel, topic=topic, message=cell_msg, targets=target_fqcns)
        else:
            cell.fire_and_forget(
                channel=channel, topic=topic, message=cell_msg, targets=target_fqcns, optional=optional
            )
        return {}