Source code for nvflare.widgets.fed_event

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import time

from nvflare.apis.client_engine_spec import ClientEngineSpec
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import EventScope, FedEventHeader, FLContextKey, ReservedKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.server_engine_spec import ServerEngineSpec
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.widgets.widget import Widget

FED_EVENT_TOPIC = "fed.event"


class FedEventRunner(Widget):
    """Widget that relays federation-scoped events between sites.

    Outgoing: any non-lifecycle event seen by ``handle_event`` whose scope is
    ``EventScope.FEDERATION`` is stamped with headers and handed to
    ``fire_and_forget_request`` (implemented by subclasses).

    Incoming: ``_receive`` is registered as the aux message handler for
    ``self.topic``; it queues accepted events in ``self.in_events`` and a
    background thread (``_post``) fires them on the local engine.
    """

    def __init__(self, topic=FED_EVENT_TOPIC):
        """Init FedEventRunner.

        The FedEventRunner handles posting and receiving of fed events.
        The system will do its best to fire off all events in the queue before shutdown
        using the ABOUT_TO_END_RUN event and a grace period during END_RUN.

        Args:
            topic: the fed event topic to be handled. Defaults to 'fed.event'
        """
        Widget.__init__(self)
        self.topic = topic
        self.abort_signal = None  # set on START_RUN from the FLContext
        self.asked_to_stop = False  # set on END_RUN
        self.asked_to_flush = False  # set on ABOUT_TO_END_RUN
        self.regular_interval = 0.001  # base polling interval (seconds) of the poster thread
        self.grace_period = 2  # seconds the poster keeps waiting on an empty queue after END_RUN
        self.flush_wait = 2  # seconds ABOUT_TO_END_RUN blocks so the poster can drain the queue
        self.engine = None  # set on START_RUN from the FLContext
        self.last_timestamps = {}  # client name => last_timestamp
        self.in_events = []  # received events waiting to be fired locally (guarded by in_lock)
        self.in_lock = threading.Lock()
        self.poster = threading.Thread(target=self._post, args=())

    def _start_poster(self):
        """Start the poster thread, replacing it first if it has already run.

        A ``threading.Thread`` object can only be started once, so a thread
        that has finished is replaced by a fresh one; a thread that is still
        running is left alone.
        """
        if self.poster.is_alive():
            return
        if self.poster.ident is not None:
            # ident is only assigned once a thread has been started: this one
            # has already run to completion and cannot be restarted.
            self.poster = threading.Thread(target=self._post, args=())
        self.poster.start()

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Handle run lifecycle events and forward outgoing fed events.

        Args:
            event_type: type of the event being fired.
            fl_ctx: FLContext of the event; for non-lifecycle events it is
                expected to carry EVENT_SCOPE and EVENT_DATA props.
        """
        if event_type == EventType.START_RUN:
            self.engine = fl_ctx.get_engine()
            self.engine.register_aux_message_handler(topic=self.topic, message_handle_func=self._receive)
            self.abort_signal = fl_ctx.get_run_abort_signal()
            self.asked_to_stop = False
            self.asked_to_flush = False
            self._start_poster()
        elif event_type == EventType.ABOUT_TO_END_RUN:
            self.asked_to_flush = True
            # delay self.flush_wait seconds so
            # _post can empty the queue before
            # END_RUN is fired
            time.sleep(self.flush_wait)
        elif event_type == EventType.END_RUN:
            self.asked_to_stop = True
            if self.poster.is_alive():
                self.poster.join()
        else:
            # handle outgoing fed events
            event_scope = fl_ctx.get_prop(key=FLContextKey.EVENT_SCOPE, default=EventScope.LOCAL)
            if event_scope != EventScope.FEDERATION:
                return

            event_data = fl_ctx.get_prop(FLContextKey.EVENT_DATA, None)
            if not isinstance(event_data, Shareable):
                self.log_error(fl_ctx, "bad fed event: expect data to be Shareable but got {}".format(type(event_data)))
                return

            direction = event_data.get_header(FedEventHeader.DIRECTION, "out")
            if direction != "out":
                # ignore incoming events
                # (these were marked "in" by _receive and fired locally by _post;
                # sending them out again would bounce them back to the sender)
                return

            event_data.set_header(FedEventHeader.EVENT_TYPE, event_type)
            event_data.set_header(FedEventHeader.ORIGIN, fl_ctx.get_identity_name())
            event_data.set_header(FedEventHeader.TIMESTAMP, time.time())

            targets = event_data.get_header(FedEventHeader.TARGETS, None)
            self.fire_and_forget_request(request=event_data, fl_ctx=fl_ctx, targets=targets)

    def fire_and_forget_request(self, request: Shareable, fl_ctx: FLContext, targets=None):
        """Send an outgoing fed event. No-op here; subclasses override it.

        Args:
            request: the event data, already stamped with fed event headers.
            fl_ctx: FLContext of the event being forwarded.
            targets: value of the TARGETS header of the event, or None.
        """
        pass

    def _receive(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
        """Aux message handler: validate an incoming fed event and queue it.

        Args:
            topic: topic of the aux message (not used by this handler).
            request: the incoming fed event.
            fl_ctx: FLContext of the request, used for logging.

        Returns:
            A reply with MISSING_PEER_CONTEXT if the sender identity is
            missing, BAD_REQUEST_DATA if the timestamp or event type header is
            missing, and OK otherwise (including when the event is dropped
            because it is not newer than the last one kept for that peer).
        """
        peer_name = request.get_peer_prop(ReservedKey.IDENTITY_NAME, None)
        if not peer_name:
            self.log_error(fl_ctx, "missing identity name of the data sender")
            return make_reply(ReturnCode.MISSING_PEER_CONTEXT)

        timestamp = request.get_header(FedEventHeader.TIMESTAMP, None)
        if timestamp is None:
            self.log_error(fl_ctx, "missing timestamp in incoming fed event")
            return make_reply(ReturnCode.BAD_REQUEST_DATA)

        event_type = request.get_header(FedEventHeader.EVENT_TYPE, None)
        if event_type is None:
            self.log_error(fl_ctx, "missing event_type in incoming fed event")
            return make_reply(ReturnCode.BAD_REQUEST_DATA)

        with self.in_lock:
            last_timestamp = self.last_timestamps.get(peer_name, None)
            if last_timestamp is None or timestamp > last_timestamp:
                # we only keep new items, in case the peer somehow sent old items
                request.set_header(FedEventHeader.DIRECTION, "in")
                self.in_events.append(request)
                self.last_timestamps[peer_name] = timestamp

        # NOTE: we do not fire event here since event process could take time.
        # Instead, we simply add the package to the queue and return quickly.
        # The posting of events will be handled in the poster thread
        return make_reply(ReturnCode.OK)

    def _post(self):
        """Poster thread body: fire queued incoming events on the local engine.

        Events are taken from ``in_events`` one at a time. While the queue is
        empty the polling interval backs off exponentially, capped at 1 second.

        After ABOUT_TO_END_RUN (``asked_to_flush``), queued events are fired
        without delay between them.

        After END_RUN (``asked_to_stop``), the thread keeps polling an empty
        queue for about ``self.grace_period`` seconds so that late items can
        still be processed, then exits. Since the handlers of a fired event
        are not guaranteed to still be functional at that point, errors raised
        while firing are logged as warnings instead of being propagated.

        The loop also exits as soon as the run abort signal is triggered.
        """
        sleep_time = self.regular_interval
        countdown = self.grace_period
        while True:
            time.sleep(sleep_time)
            if self.abort_signal.triggered:
                break

            # Read the length and take the item in one critical section so the
            # two always agree with each other.
            with self.in_lock:
                n = len(self.in_events)
                event_to_post = self.in_events.pop(0) if n > 0 else None

            if n > 0:
                sleep_time = 0 if self.asked_to_flush else self.regular_interval
            elif self.asked_to_stop:
                # the queue is empty, and we are asked to stop.
                # wait self.grace_period seconds, then exit.
                if countdown <= 0:
                    break
                countdown = countdown - 1
                time.sleep(1)
                continue
            else:
                # Back off while idle. Flushing sets sleep_time to 0; restart
                # from regular_interval in that case, otherwise doubling 0
                # would stay 0 and this loop would spin without sleeping.
                sleep_time = min(max(sleep_time, self.regular_interval) * 2, 1)
                continue

            with self.engine.new_context() as fl_ctx:
                if self.asked_to_stop:
                    self.log_warning(fl_ctx, f"{n} items remained in in_events. Will stop when it reaches 0.")

                fl_ctx.set_prop(key=FLContextKey.EVENT_DATA, value=event_to_post, private=True, sticky=False)
                fl_ctx.set_prop(key=FLContextKey.EVENT_SCOPE, value=EventScope.FEDERATION, private=True, sticky=False)

                event_type = event_to_post.get_header(FedEventHeader.EVENT_TYPE)
                try:
                    self.engine.fire_event(event_type=event_type, fl_ctx=fl_ctx)
                except Exception:
                    if self.asked_to_stop:
                        self.log_warning(fl_ctx, f"event {event_to_post} fired unsuccessfully during END_RUN")
                    else:
                        # bare raise keeps the original traceback intact
                        raise
class ServerFedEventRunner(FedEventRunner):
    """FedEventRunner variant that sends outgoing fed events through a ServerEngineSpec engine."""

    def __init__(self, topic=FED_EVENT_TOPIC):
        """Init ServerFedEventRunner.

        Args:
            topic: the fed event topic to be handled; passed to FedEventRunner.
        """
        FedEventRunner.__init__(self, topic)

    def fire_and_forget_request(self, request: Shareable, fl_ctx: FLContext, targets=None):
        """Send the fed event to the given targets as a fire-and-forget aux request.

        Args:
            request: the event data to send.
            fl_ctx: FLContext of the event being forwarded.
            targets: passed through unchanged to the engine.

        Raises:
            TypeError: if self.engine is not a ServerEngineSpec.
        """
        engine = self.engine
        if isinstance(engine, ServerEngineSpec):
            engine.fire_and_forget_aux_request(
                topic=self.topic,
                targets=targets,
                request=request,
                fl_ctx=fl_ctx,
            )
            return

        raise TypeError(f"self.engine must be ServerEngineSpec but got {type(engine)}")
class ClientFedEventRunner(FedEventRunner):
    """FedEventRunner variant that sends outgoing fed events through a ClientEngineSpec engine.

    Events are only sent once START_RUN has been handled; earlier ones are
    skipped with a warning.
    """

    def __init__(self, topic=FED_EVENT_TOPIC):
        """Init ClientFedEventRunner.

        Args:
            topic: the fed event topic to be handled; passed to FedEventRunner.
        """
        FedEventRunner.__init__(self, topic)
        self.ready = False  # flipped to True after START_RUN has been handled

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Run the base class handling, then mark the runner ready on START_RUN.

        Args:
            event_type: type of the event being fired.
            fl_ctx: FLContext of the event.
        """
        super().handle_event(event_type, fl_ctx)
        if event_type != EventType.START_RUN:
            return
        self.ready = True

    def fire_and_forget_request(self, request: Shareable, fl_ctx: FLContext, targets=None):
        """Send the fed event as a fire-and-forget aux request.

        Args:
            request: the event data to send.
            fl_ctx: FLContext of the event being forwarded.
            targets: accepted for signature compatibility; not passed to the engine.

        Raises:
            TypeError: if the runner is ready but self.engine is not a ClientEngineSpec.
        """
        if self.ready:
            engine = self.engine
            if not isinstance(engine, ClientEngineSpec):
                raise TypeError(f"self.engine must be ClientEngineSpec but got {type(engine)}")
            engine.fire_and_forget_aux_request(topic=self.topic, request=request, fl_ctx=fl_ctx)
        else:
            self.log_warning(fl_ctx, "Engine in not ready, skip the fed event firing.")