# Source code for nvflare.fuel.data_event.data_bus

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List

from nvflare.fuel.data_event.pub_sub import EventPubSub


class DataBus(EventPubSub):
    """
    Singleton class for a simple data bus implementation.

    This class allows components to subscribe to topics, publish data to topics,
    and store/retrieve data associated with specific keys.
    """

    _instance = None
    # Class-level, non-reentrant lock guarding singleton creation, the subscriber
    # registry (``subscribers``) and the key/value store (``data_store``).
    _lock = threading.Lock()

    def __new__(cls) -> "DataBus":
        """
        Create (on the first call) or return the single DataBus instance.

        This method ensures that only one instance of the class is created (singleton
        pattern). The shared ``subscribers`` and ``data_store`` dicts are initialised
        only when the instance is first created.

        Returns:
            DataBus: The singleton instance.
        """
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(DataBus, cls).__new__(cls)
                cls._instance.subscribers = {}
                cls._instance.data_store = {}
        return cls._instance

    def subscribe(self, topics: List[str], callback: Callable[[str, Any, "DataBus"], None]) -> None:
        """
        Subscribe a callback function to one or more topics.

        All topics are validated before anything is registered, so an invalid topic
        leaves the subscriber registry unchanged.

        Args:
            topics (List[str]): A list of topics to subscribe to.
            callback (Callable): The callback function to be called when data is published
                to the subscribed topics. It is invoked as ``callback(topic, datum, data_bus)``.

        Raises:
            ValueError: If ``topics`` is empty, or if any topic is empty or whitespace-only.
        """
        if not topics:
            raise ValueError("topics must be non-empty")
        for topic in topics:
            # "".isspace() is False, so the empty string must be rejected explicitly.
            if not topic or topic.isspace():
                raise ValueError(f"topics {topics} contains empty or white space topic")
        with self._lock:
            for topic in topics:
                self.subscribers.setdefault(topic, []).append(callback)

    def publish(self, topics: List[str], datum: Any) -> None:
        """
        Publish a datum to one or more topics, notifying all subscribed callbacks.

        Topics are processed in the given order. For each topic, its callbacks run
        concurrently on a thread pool, and this method waits for all of them to finish
        before moving on to the next topic.

        The topic's subscriber list is copied under the lock, and the lock is released
        before any callback runs. Callbacks may therefore call back into this DataBus
        (e.g. ``put_data``, ``subscribe`` or ``publish``) without deadlocking on the
        non-reentrant lock.

        Note:
            Exceptions raised by callbacks are not propagated to the publisher; the
            futures returned by the executor are not inspected.

        Args:
            topics (List[str]): A list of topics to publish the data to. ``None`` or an
                empty list is a no-op.
            datum (Any): The data to be published to the specified topics.
        """
        if not topics:
            return
        for topic in topics:
            with self._lock:
                callbacks = list(self.subscribers.get(topic, ()))
            if not callbacks:
                continue
            # Leaving the ``with`` block calls shutdown(wait=True), i.e. blocks until
            # every submitted callback for this topic has completed.
            with ThreadPoolExecutor(max_workers=len(callbacks)) as executor:
                for callback in callbacks:
                    executor.submit(callback, topic, datum, self)

    def put_data(self, key: Any, datum: Any) -> None:
        """
        Store a datum under a key, replacing any value previously stored for that key.

        Args:
            key (Any): The key to associate with the stored datum.
            datum (Any): The datum to be stored.
        """
        with self._lock:
            self.data_store[key] = datum

    def get_data(self, key: Any) -> Any:
        """
        Retrieve the datum stored under a key.

        Args:
            key (Any): The key associated with the stored datum.

        Returns:
            Any: The stored datum if found, or None if not found.
        """
        return self.data_store.get(key)