Source code for nvflare.app_common.streamers.container_streamer

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Tuple

from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.streaming import ConsumerFactory, ObjectConsumer, ObjectProducer, StreamableEngine, StreamContext
from nvflare.app_common.streamers.streamer_base import StreamerBase
from nvflare.fuel.utils.class_loader import get_class_name, load_class
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.fuel.utils.validation_utils import check_positive_number

_PREFIX = "ContainerStreamer."

# Keys for StreamCtx
_CTX_TYPE = _PREFIX + "type"
_CTX_SIZE = _PREFIX + "size"
_CTX_RESULT = _PREFIX + "result"

# Keys for Shareable
_KEY_ENTRY = _PREFIX + "entry"
_KEY_LAST = _PREFIX + "last"


class _EntryConsumer(ObjectConsumer):
    def __init__(self, stream_ctx: StreamContext):
        self.logger = get_obj_logger(self)
        container_type = stream_ctx.get(_CTX_TYPE)
        container_class = load_class(container_type)
        self.container = container_class()
        self.size = stream_ctx.get(_CTX_SIZE)

    def consume(
        self,
        shareable: Shareable,
        stream_ctx: StreamContext,
        fl_ctx: FLContext,
    ) -> Tuple[bool, Shareable]:

        entry = shareable.get(_KEY_ENTRY)
        try:
            if isinstance(self.container, dict):
                key, value = entry
                self.container[key] = value
            elif isinstance(self.container, set):
                self.container.add(entry)
            else:
                self.container.append(entry)
        except Exception as ex:
            error = f"Unable to add entry ({type(entry)} to container ({type(self.container)}"
            self.logger.error(error)
            raise ValueError(error)

        last = shareable.get(_KEY_LAST)
        if last:
            # Check if all entries are added
            if self.size != len(self.container):
                err = f"Container size {len(self.container)} does not match expected size {self.size}"
                self.logger.error(err)
                raise ValueError(err)
            else:
                stream_ctx[_CTX_RESULT] = self.container
                return False, make_reply(ReturnCode.OK)
        else:
            # continue streaming
            return True, make_reply(ReturnCode.OK)

    def finalize(self, stream_ctx: StreamContext, fl_ctx: FLContext):
        self.logger.debug(f"Container streaming is done for container type {type(self.container)}")


class _EntryConsumerFactory(ConsumerFactory):
    def get_consumer(self, stream_ctx: StreamContext, fl_ctx: FLContext) -> ObjectConsumer:
        return _EntryConsumer(stream_ctx)


class _EntryProducer(ObjectProducer):
    def __init__(self, container, entry_timeout):
        self.logger = get_obj_logger(self)
        if not container:
            error = "Can't stream empty container"
            self.logger.error(error)
            raise ValueError(error)

        self.container = container
        if isinstance(container, dict):
            self.iterator = iter(container.items())
        else:
            self.iterator = iter(container)
        self.size = len(container)
        self.count = 0
        self.last = False
        self.entry_timeout = entry_timeout

    def produce(
        self,
        stream_ctx: StreamContext,
        fl_ctx: FLContext,
    ) -> Tuple[Shareable, float]:

        try:
            entry = next(self.iterator)
            self.count += 1
            self.last = self.count >= self.size
        except StopIteration:
            self.logger.error(f"Producer called too many times {self.count}/{self.size}")
            self.last = True
            return None, 0.0

        result = Shareable()
        result[_KEY_ENTRY] = entry
        result[_KEY_LAST] = self.last
        return result, self.entry_timeout

    def process_replies(
        self,
        replies: Dict[str, Shareable],
        stream_ctx: StreamContext,
        fl_ctx: FLContext,
    ) -> Any:
        has_error = False
        for target, reply in replies.items():
            rc = reply.get_return_code(ReturnCode.OK)
            if rc != ReturnCode.OK:
                self.logger.error(f"error from target {target}: {rc}")
                has_error = True

        if has_error:
            # done - failed
            return False
        elif self.last:
            # done - succeeded
            return True
        else:
            # not done yet - continue streaming
            return None


class ContainerStreamer(StreamerBase):
    @staticmethod
    def register_stream_processing(
        fl_ctx: FLContext,
        channel: str,
        topic: str,
        stream_done_cb=None,
        **cb_kwargs,
    ):
        """Register for stream processing on the receiving side.

        Args:
            fl_ctx: the FLContext object
            channel: the app channel
            topic: the app topic
            stream_done_cb: if specified, the callback to be called when the container is completely received
            **cb_kwargs: the kwargs for the stream_done_cb

        Returns: None

        Notes: the stream_done_cb must follow stream_done_cb_signature as defined in apis.streaming.

        """
        engine = fl_ctx.get_engine()
        if not isinstance(engine, StreamableEngine):
            raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}")

        engine.register_stream_processing(
            channel=channel,
            topic=topic,
            factory=_EntryConsumerFactory(),
            stream_done_cb=stream_done_cb,
            **cb_kwargs,
        )

    @staticmethod
    def stream_container(
        channel: str,
        topic: str,
        stream_ctx: StreamContext,
        targets: List[str],
        container: Any,
        fl_ctx: FLContext,
        entry_timeout=None,
        optional=False,
        secure=False,
    ) -> bool:
        """Stream a container to one or more targets.

        Args:
            channel: the app channel
            topic: the app topic
            stream_ctx: context data of the stream
            targets: targets that the container will be sent to
            container: container to be streamed
            fl_ctx: a FLContext object
            entry_timeout: timeout for each entry sent to targets
            optional: whether the container is optional
            secure: whether P2P security is required

        Returns: whether the streaming completed successfully

        Notes: this is a blocking call - only returns after the streaming is done.

        """
        if not entry_timeout:
            entry_timeout = 60.0
        check_positive_number("entry_timeout", entry_timeout)

        producer = _EntryProducer(container, entry_timeout)
        engine = fl_ctx.get_engine()

        if not isinstance(engine, StreamableEngine):
            raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}")

        if not stream_ctx:
            stream_ctx = {}

        stream_ctx[_CTX_TYPE] = get_class_name(type(container))
        stream_ctx[_CTX_SIZE] = len(container)

        return engine.stream_objects(
            channel=channel,
            topic=topic,
            stream_ctx=stream_ctx,
            targets=targets,
            producer=producer,
            fl_ctx=fl_ctx,
            optional=optional,
            secure=secure,
        )

    @staticmethod
    def get_result(stream_ctx: StreamContext) -> Any:
        """Get the received container.

        This method is intended to be used by the stream_done_cb() function of the receiving side.

        Args:
            stream_ctx: the stream context

        Returns: the received container

        """
        return stream_ctx.get(_CTX_RESULT)
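
Usage sketch (not part of the module above): one way to wire the two sides of ContainerStreamer together. The channel, topic, and target names below are illustrative assumptions, and the stream_done_cb is assumed to follow stream_done_cb_signature from nvflare.apis.streaming.

from nvflare.apis.fl_context import FLContext
from nvflare.apis.streaming import StreamContext
from nvflare.app_common.streamers.container_streamer import ContainerStreamer


def container_received(stream_ctx: StreamContext, fl_ctx: FLContext, **kwargs):
    # Receiving side: called once all entries have arrived.
    container = ContainerStreamer.get_result(stream_ctx)
    print(f"received {type(container)} with {len(container)} entries")


def setup_receiver(fl_ctx: FLContext):
    # Receiving side: register the consumer for a hypothetical channel/topic pair.
    ContainerStreamer.register_stream_processing(
        fl_ctx=fl_ctx,
        channel="demo_channel",
        topic="demo_container",
        stream_done_cb=container_received,
    )


def send_demo_container(fl_ctx: FLContext) -> bool:
    # Sending side: blocking call; dict entries are streamed as (key, value) pairs.
    return ContainerStreamer.stream_container(
        channel="demo_channel",
        topic="demo_container",
        stream_ctx={},
        targets=["site-1"],
        container={"weights": [0.1, 0.2, 0.3], "round": 3},
        fl_ctx=fl_ctx,
        entry_timeout=30.0,
    )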