Source code for nvflare.app_common.streamers.container_retriever

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable
from nvflare.apis.streaming import StreamContext

from .container_streamer import ContainerStreamer
from .object_retriever import ObjectRetriever


[docs] class ContainerRetriever(ObjectRetriever): def __init__( self, topic: str = None, stream_msg_optional=False, stream_msg_secure=False, entry_timeout=None, ): ObjectRetriever.__init__(self, topic) self.stream_msg_optional = stream_msg_optional self.stream_msg_secure = stream_msg_secure self.entry_timeout = entry_timeout self.containers = {}
[docs] def add_container(self, name: str, container: Any): """Add a container to the retriever. This must be called on the sending side Args: name: name for the container. container: The container to be streamed """ self.containers[name] = container
[docs] def register_stream_processing( self, channel: str, topic: str, fl_ctx: FLContext, stream_done_cb, **cb_kwargs, ): """Called on the stream sending side. Args: channel: topic: fl_ctx: stream_done_cb: **cb_kwargs: Returns: """ ContainerStreamer.register_stream_processing( channel=channel, topic=topic, fl_ctx=fl_ctx, stream_done_cb=stream_done_cb, **cb_kwargs, )
[docs] def validate_request(self, request: Shareable, fl_ctx: FLContext) -> (str, Any): name = request.get("name") if not name: self.log_error(fl_ctx, "bad request: missing container name") return ReturnCode.BAD_REQUEST_DATA, None container = self.containers.get(name, None) if not container: self.log_error(fl_ctx, f"bad request: requested container {name} doesn't exist") return ReturnCode.BAD_REQUEST_DATA, None return ReturnCode.OK, container
[docs] def retrieve_container(self, from_site: str, fl_ctx: FLContext, timeout: float, name: str) -> (str, Any): """Retrieve a container from the specified site. This method is to be called by the app. Args: from_site: the site that has the container to be retrieved fl_ctx: FLContext object timeout: how long to wait for the file name: name of the container Returns: a tuple of (ReturnCode, container) """ return self.retrieve(from_site=from_site, fl_ctx=fl_ctx, timeout=timeout, name=name)
[docs] def do_stream( self, target: str, request: Shareable, fl_ctx: FLContext, stream_ctx: StreamContext, validated_data: Any ): """Stream the container to the peer. Called on the stream sending side. Args: target: the receiving site request: data to be sent fl_ctx: FLContext object stream_ctx: the stream context validated_data: the file full path returned from the validate_request method Returns: """ ContainerStreamer.stream_container( targets=[target], stream_ctx=stream_ctx, channel=self.stream_channel, topic=self.topic, container=validated_data, fl_ctx=fl_ctx, optional=self.stream_msg_optional, secure=self.stream_msg_secure, )
[docs] def get_result(self, stream_ctx: StreamContext) -> (str, Any): """Called on the stream receiving side. Get the final result of the streaming. The result is the location of the received file. Args: stream_ctx: the StreamContext Returns: """ return ContainerStreamer.get_rc(stream_ctx), ContainerStreamer.get_result(stream_ctx)