# Source code for nvflare.fuel.f3.streaming.tools.receiver

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time

from nvflare.fuel.f3.cellnet.core_cell import CoreCell
from nvflare.fuel.f3.stream_cell import StreamCell
from nvflare.fuel.f3.streaming.stream_types import StreamFuture
from nvflare.fuel.f3.streaming.tools.utils import BUF_SIZE, RX_CELL, TEST_CHANNEL, TEST_TOPIC, make_buffer, setup_log


class Receiver:
    """Test BLOB receiving.

    Creates and starts a ``CoreCell`` named ``RX_CELL``, wraps it in a
    ``StreamCell`` and registers ``blob_cb`` for ``TEST_CHANNEL`` /
    ``TEST_TOPIC``. Every future handed to the callback is stored in
    ``self.futures`` keyed by its stream id.
    """

    def __init__(self, listening_url: str):
        """Start the cell and register the BLOB callback.

        Args:
            listening_url: URL forwarded to ``CoreCell`` as its second
                positional argument (presumably the address to listen on;
                verify against ``CoreCell``).
        """
        # Create the registry first: blob_cb writes into it, and the callback
        # could fire as soon as it is registered on the already-started cell.
        self.futures = {}

        cell = CoreCell(RX_CELL, listening_url, secure=False, credentials={})
        cell.start()
        self.stream_cell = StreamCell(cell)
        self.stream_cell.register_blob_cb(TEST_CHANNEL, TEST_TOPIC, self.blob_cb)

    def get_futures(self) -> dict:
        """Return the live mapping of stream id to the future received for it."""
        return self.futures

    def blob_cb(self, stream_future: StreamFuture, *args, **kwargs):
        """Record an incoming stream's future under its stream id.

        Args:
            stream_future: Future supplied by the stream cell for the new stream.
            *args: Extra positional arguments from the caller; ignored.
            **kwargs: Extra keyword arguments from the caller; ignored.
        """
        sid = stream_future.get_stream_id()
        print(f"Stream {sid} received")
        self.futures[sid] = stream_future
def wait_for_result(receiver, poll_interval=1.0):
    """Poll the receiver's futures until one is done and return its result.

    While no future is done, the progress of each pending future is printed
    together with the change since that stream's previous poll.

    Args:
        receiver: Object exposing ``get_futures()`` that returns a mapping of
            stream id to a future with ``done()``, ``result()`` and
            ``get_progress()``.
        poll_interval: Seconds to sleep between polling rounds.

    Returns:
        Whatever ``result()`` of the first completed future returns.
    """
    # Track progress per stream so the delta is meaningful with several streams.
    last_progress = {}

    while True:
        # Iterate over a snapshot: the receive callback may insert new entries
        # while we are looping (presumably from another thread), and mutating
        # a dict during iteration raises RuntimeError.
        for sid, fut in list(receiver.get_futures().items()):
            if fut.done():
                # Return immediately; an empty payload is still a finished one.
                return fut.result()

            progress = fut.get_progress()
            previous = last_progress.get(sid, 0)
            print(f"{sid} Progress: {progress} Delta:{progress - previous}")
            last_progress[sid] = progress

        time.sleep(poll_interval)


if __name__ == "__main__":
    setup_log(logging.INFO)
    url = "tcp://localhost:1234"
    receiver = Receiver(url)
    # Pause briefly after start-up before polling begins.
    time.sleep(2)

    result = wait_for_result(receiver)

    # Rebuild the reference buffer locally and compare it with what arrived.
    print("Recreating buffer ...")
    start = time.time()
    buffer = make_buffer(BUF_SIZE)
    print(f"Buffer done, took {time.time()-start} seconds")
    if buffer == result:
        print("Result is correct")
    else:
        print("Result is wrong")