Source code for nvflare.app_common.streamers.object_retriever

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
import uuid
from abc import ABC, abstractmethod
from typing import Any

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.streaming import StreamContext
from nvflare.security.logging import secure_format_exception

# Key under which the retrieval transaction id travels: it is placed in the
# request Shareable by the requesting side and in the stream context by the
# sending side, so the stream-done callback can find the matching waiter.
RETRIEVER_TX_ID = "_rtr_tx_id_"

# Seconds to block on the waiter per polling iteration in ObjectRetriever.retrieve,
# so that the timeout and abort conditions can be re-checked between waits.
_SHORT_WAIT = 0.1


class _Waiter(threading.Event):
    """An Event that also carries the outcome of a retrieval.

    ``result`` stays ``None`` until an outcome is stored; once stored it is a
    ``(return_code, data)`` tuple.
    """

    def __init__(self):
        threading.Event.__init__(self)
        self.result = None

    def set_result(self, rc: str, data: Any):
        """Store the outcome as a ``(rc, data)`` tuple.

        This only records the value; it does not signal the event.
        """
        outcome = (rc, data)
        self.result = outcome

class ObjectRetriever(FLComponent, ABC):
    """This is the base class for object retrieval with streaming.

    The retrieval works as follows:

    - The requesting site initiates the process by sending a data request to the site that has the data;
    - The requesting site then waits for the data to be completely received;
    - Once the data request is received, the data owner site streams the data to the requesting site;
    - During the streaming process, the requesting site keeps checking for the completion of the streaming
      until either the data is completely received, or timed out, or aborted.
    """

    def __init__(
        self,
        topic: str = None,
    ):
        """Create the retriever.

        Args:
            topic: topic used for both the aux request and the stream. When not
                specified (or empty), the concrete class name is used.
        """
        FLComponent.__init__(self)
        class_name = self.__class__.__name__
        if not topic:
            topic = class_name
        self.topic = topic
        self.stream_channel = class_name

        # tx_id => _Waiter, for retrievals that are currently in progress on this (requesting) side
        self.tx_table = {}

    @abstractmethod
    def register_stream_processing(
        self,
        channel: str,
        topic: str,
        fl_ctx: FLContext,
        stream_done_cb,
        **cb_kwargs,
    ):
        """Object requester side, which will receive data stream.
        This is called to register the status_cb for received stream.

        Args:
            channel: stream channel
            topic: stream topic
            fl_ctx: FLContext object
            stream_done_cb: the stream_done callback to be registered
            **cb_kwargs: kwargs to be passed to the CB

        Returns:

        """
        pass

    @abstractmethod
    def validate_request(self, request: Shareable, fl_ctx: FLContext) -> (str, Any):
        """Object sending side. Called to validate the received retrieval request.

        Args:
            request: the request to be validated
            fl_ctx: FLContext object

        Returns: tuple of (ReturnCode, Validation Data)

        This method should do as much as possible so that the do_stream method won't be called if any error
        is detected (the do_stream method is called in a separate thread).
        The validation data produced by this method will be passed to the do_stream method.

        """
        pass

    @abstractmethod
    def do_stream(
        self,
        target: str,
        request: Shareable,
        fl_ctx: FLContext,
        stream_ctx: StreamContext,
        validation_data: Any,
    ) -> Any:
        """Object sending side. Called to stream data to the requesting side.

        Args:
            target: the requesting site to stream to
            request: the object retrieval request
            fl_ctx: a FLContext object
            stream_ctx: stream context data
            validation_data: the validation data produced by the validate_request method.

        Returns: Any object

        """
        pass

    @abstractmethod
    def get_result(self, stream_ctx: StreamContext) -> (str, Any):
        """Object requesting side, which is also the stream receiving side.
        Called to get the result of the streaming.

        Args:
            stream_ctx: StreamContext object

        Returns: tuple of (ReturnCode, Result Object)

        """
        pass

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """On START_RUN, register the stream-done callback and the aux request handler.

        Args:
            event_type: the type of the event fired
            fl_ctx: FLContext object
        """
        if event_type == EventType.START_RUN:
            engine = fl_ctx.get_engine()

            # requesting side: be told when an incoming stream is done
            self.register_stream_processing(
                fl_ctx=fl_ctx,
                channel=self.stream_channel,
                topic=self.topic,
                stream_done_cb=self._handle_stream_done,
            )

            # sending side: accept retrieval requests from peers
            engine.register_aux_message_handler(topic=self.topic, message_handle_func=self._handle_request)

    def retrieve(self, from_site: str, fl_ctx: FLContext, timeout: float, **obj_attrs) -> (str, Any):
        """Retrieve an object from a specified site.

        The same ``timeout`` is used twice: as the timeout of the aux request that asks the
        site to start streaming, and as the max time to wait for the stream to complete
        once the request has been accepted.

        Args:
            from_site: the site to retrieve the object from
            fl_ctx: a FLContext object
            timeout: max number of seconds to wait for the data
            **obj_attrs: attributes of the object to be retrieved

        Returns: tuple of (ReturnCode, Retrieved Data)

        """
        engine = fl_ctx.get_engine()
        waiter = _Waiter()
        tx_id = str(uuid.uuid4())
        self.tx_table[tx_id] = waiter
        self.log_debug(fl_ctx, f"set waiter for Rtr {tx_id}")

        try:
            request = Shareable({RETRIEVER_TX_ID: tx_id})
            if obj_attrs:
                request.update(obj_attrs)

            # ask the site to start streaming
            replies = engine.send_aux_request(
                targets=[from_site], request=request, topic=self.topic, fl_ctx=fl_ctx, timeout=timeout
            )

            # the 'replies' is a dict keyed with site names!
            reply = replies.get(from_site)

            # now the reply is a Shareable object
            if not isinstance(reply, Shareable):
                self.log_error(fl_ctx, f"bad reply from site {from_site}: expect Shareable but got {type(reply)}")
                return ReturnCode.EXECUTION_EXCEPTION, None

            rc = reply.get_return_code()
            if rc != ReturnCode.OK:
                self.log_error(fl_ctx, f"retrieval request rejected by site {from_site}: {rc}")
                return rc, None

            # wait for result until either the result is received, or timed out, or aborted
            rc = ReturnCode.OK
            abort_signal = fl_ctx.get_run_abort_signal()
            start_time = time.time()
            while True:
                # wait a short time so that we can check other conditions
                if not waiter.wait(_SHORT_WAIT):
                    # not done yet: the time limit is measured from the start of the wait
                    if time.time() - start_time > timeout:
                        self.log_error(fl_ctx, f"stream data not completed in {timeout} seconds")
                        rc = ReturnCode.TIMEOUT
                        break

                    if abort_signal and abort_signal.triggered:
                        rc = ReturnCode.TASK_ABORTED
                        break
                else:
                    # result available!
                    break
        except Exception as ex:
            self.log_error(fl_ctx, f"exception occurred during retrieval: {secure_format_exception(ex)}")
            rc = ReturnCode.EXECUTION_EXCEPTION
        finally:
            # Always drop the waiter, including on the early returns above; otherwise every
            # rejected/bad request would leave a stale entry in tx_table.
            self.tx_table.pop(tx_id, None)
            self.log_debug(fl_ctx, f"popped waiter for RTR {tx_id}")

        if waiter.result:
            # If the waiter already got result, we return it.
            # Note that due to racing condition, it is possible that the waiter still got the result
            # even after we determined the streaming is timed out!
            return waiter.result
        else:
            return rc, None

    def _handle_stream_done(self, stream_ctx: StreamContext, fl_ctx: FLContext):
        """Stream-done callback: hand the stream result to the waiting retrieve() call.

        Runs on the stream receiving side, which is also the requesting side.

        Args:
            stream_ctx: stream context of the finished stream; expected to carry the tx id
            fl_ctx: FLContext object
        """
        tx_id = stream_ctx.get(RETRIEVER_TX_ID)
        waiter = self.tx_table.get(tx_id)
        if not waiter:
            # retrieve() has already given up on (and removed) this transaction
            self.log_error(fl_ctx, f"late stream completion {tx_id=} after timed out")
            return

        try:
            result = self.get_result(stream_ctx)
        except Exception as ex:
            self.log_error(fl_ctx, f"Exception when get_result: {secure_format_exception(ex)}")
            result = (ReturnCode.EXECUTION_EXCEPTION, None)

        # store the result before signalling, so the woken waiter always sees it
        waiter.result = result
        waiter.set()
        self.log_info(fl_ctx, f"got result for RTR {tx_id}: {type(waiter.result)}")

    def _handle_request(self, topic, request: Shareable, fl_ctx: FLContext) -> Shareable:
        """Aux message handler: validate a retrieval request and start streaming.

        Runs on the request receiving side, which is also the stream sending side.

        Args:
            topic: topic of the received aux message
            request: the retrieval request
            fl_ctx: FLContext object

        Returns: a reply Shareable; OK means streaming has been started in a separate thread.

        """
        tx_id = request.get(RETRIEVER_TX_ID)
        if not tx_id:
            self.log_error(fl_ctx, f"bad request '{topic}': missing {RETRIEVER_TX_ID}")
            return make_reply(ReturnCode.BAD_REQUEST_DATA)

        peer_ctx = fl_ctx.get_peer_context()
        if not peer_ctx:
            self.log_error(fl_ctx, f"bad request '{topic}': missing peer context")
            return make_reply(ReturnCode.BAD_REQUEST_DATA)

        if not isinstance(peer_ctx, FLContext):
            self.log_error(fl_ctx, f"bad request '{topic}': bad peer context ({type(peer_ctx)})")
            return make_reply(ReturnCode.BAD_REQUEST_DATA)

        peer = peer_ctx.get_identity_name()
        if not peer:
            self.log_error(fl_ctx, f"bad request '{topic}': missing peer name")
            return make_reply(ReturnCode.BAD_REQUEST_DATA)

        # validate the request before starting stream
        try:
            rc, validated_data = self.validate_request(request, fl_ctx)
            if rc and rc != ReturnCode.OK:
                self.log_error(fl_ctx, f"bad request '{topic}': failed validation ({rc})")
                return make_reply(rc)
        except Exception as ex:
            self.log_error(fl_ctx, f"exception validating request: {secure_format_exception(ex)}")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        # start the streaming in a separate thread so that we can respond to the requestor.
        self.log_debug(fl_ctx, "About to start streaming ...")
        t = threading.Thread(target=self._do_stream, args=(request, fl_ctx, validated_data), daemon=True)
        t.start()
        return make_reply(ReturnCode.OK)

    def _do_stream(self, request: Shareable, fl_ctx: FLContext, validated_data: Any):
        """Thread body: stream the requested object to the requesting peer.

        Runs on the request receiving side, which is also the stream sending side.
        Exceptions raised by do_stream are logged, not propagated.

        Args:
            request: the retrieval request (already checked by _handle_request)
            fl_ctx: FLContext object of the request
            validated_data: the validation data produced by validate_request
        """
        tx_id = request.get(RETRIEVER_TX_ID)
        self.log_debug(fl_ctx, f"Started streaming for RTR Request {tx_id}")

        # the tx id rides in the stream context so the receiver can match the stream to its waiter
        stream_ctx = {RETRIEVER_TX_ID: tx_id}
        peer_ctx = fl_ctx.get_peer_context()
        peer = peer_ctx.get_identity_name()

        try:
            # start streaming object to the peer
            result = self.do_stream(peer, request, fl_ctx, stream_ctx, validated_data)
            self.log_info(fl_ctx, f"finished streaming for RTR {tx_id}: {result=}")
        except Exception as ex:
            self.log_error(fl_ctx, f"streaming exception occurred: {secure_format_exception(ex)}")