Source code for nvflare.apis.streaming

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from builtins import dict as StreamContext
from typing import Any, Dict, List, Tuple

from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable


[docs] class StreamContextKey: CHANNEL = "__channel__" TOPIC = "__topic__" RC = "__RC__"
[docs] class ObjectProducer(ABC):
[docs] @abstractmethod def produce( self, stream_ctx: StreamContext, fl_ctx: FLContext, ) -> Tuple[Shareable, float]: """Called to produce the next Shareable object to be sent. If this method needs to take long time, it should check the abort_signal in the fl_ctx frequently. If aborted it should return immediately. You can get the abort_signal by calling fl_ctx.get_run_abort_signal(). Args: stream_ctx: stream context data fl_ctx: The FLContext object Returns: a tuple of (Shareable object to be sent, timeout for sending this object) """ pass
[docs] @abstractmethod def process_replies( self, replies: Dict[str, Shareable], stream_ctx: StreamContext, fl_ctx: FLContext, ) -> Any: """Called to process replies from receivers of the last Shareable object sent to them. Args: replies: replies from receivers. It's dict of site_name => reply stream_ctx: stream context data fl_ctx: the FLContext object Returns: Any object or None If None is returned, the streaming will continue; otherwise the streaming stops and the returned object is returned as the final result of the streaming. """ pass
[docs] class ObjectConsumer(ABC):
[docs] @abstractmethod def consume( self, shareable: Shareable, stream_ctx: StreamContext, fl_ctx: FLContext, ) -> Tuple[bool, Shareable]: """Consume the received Shareable object in the stream. Args: stream_ctx: the stream context data. shareable: the Shareable object to be processed fl_ctx: the FLContext object Returns: a tuple of (whether to continue streaming, reply message) Note: the channel and topic here are defined by the app. They are not the regular message headers (CHANNEL and TOPIC) defined in MessageHeaderKey. """ pass
[docs] def finalize( self, stream_ctx: StreamContext, fl_ctx: FLContext, ): """Called to finalize the generator. Args: stream_ctx: stream context fl_ctx: the FLContext object Returns: None This method is guaranteed to be called at the end of streaming. """ pass
[docs] class ConsumerFactory(ABC):
[docs] @abstractmethod def get_consumer( self, stream_ctx: StreamContext, fl_ctx: FLContext, ) -> ObjectConsumer: """Called to get an ObjectConsumer to process a new stream on the receiving side. This is called only when the 1st streaming object is received for each stream. Args: stream_ctx: the context of the stream fl_ctx: FLContext object Returns: an ObjectConsumer """ pass
[docs] def return_consumer( self, consumer: ObjectConsumer, stream_ctx: StreamContext, fl_ctx: FLContext, ): """Return the consumer back to the factory after a stream is finished on the receiving side. Args: consumer: the consumer to be returned stream_ctx: context of the stream fl_ctx: FLContext object Returns: None """ pass
[docs] def stream_done_cb_signature(stream_ctx: StreamContext, fl_ctx: FLContext, **kwargs): """This is the signature of stream_done_cb. Args: stream_ctx: context of the stream fl_ctx: FLContext object **kwargs: the kwargs specified when registering the stream_done_cb. Returns: None """ pass
[docs] class StreamableEngine(ABC): """This class defines requirements for streaming capable engines."""
[docs] @abstractmethod def stream_objects( self, channel: str, topic: str, stream_ctx: StreamContext, targets: List[str], producer: ObjectProducer, fl_ctx: FLContext, optional=False, secure=False, ): """Send a stream of Shareable objects to receivers. Args: channel: the channel for this stream topic: topic of the stream stream_ctx: context of the stream targets: receiving sites producer: the ObjectProducer that can produces the stream of Shareable objects fl_ctx: the FLContext object optional: whether the stream is optional secure: whether to use P2P security Returns: result from the generator's reply processing """ pass
[docs] @abstractmethod def register_stream_processing( self, channel: str, topic: str, factory: ConsumerFactory, stream_done_cb=None, **cb_kwargs, ): """Register a ConsumerFactory for specified app channel and topic. Once a new streaming request is received for the channel/topic, the registered factory will be used to create an ObjectConsumer object to handle the new stream. Note: the factory should generate a new ObjectConsumer every time get_consumer() is called. This is because multiple streaming sessions could be going on at the same time. Each streaming session should have its own ObjectConsumer. Args: channel: app channel topic: app topic factory: the factory to be registered stream_done_cb: the callback to be called when streaming is done on receiving side Returns: None """ pass
[docs] @abstractmethod def shutdown_streamer(self): """Shutdown the engine's streamer. Returns: None """ pass