# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import logging
import threading
import time
import uuid
from nvflare.apis.fl_constant import ConfigVarName, SystemConfigs
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReservedHeaderKey, ReturnCode, Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.apis.utils.fl_context_utils import generate_log_message
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.fuel.utils.validation_utils import check_positive_number
from nvflare.security.logging import secure_format_exception, secure_format_traceback
# Operation Types: value of the HEADER_OP header on every reliable message
OP_REQUEST = "req"  # the request itself (re-sent unchanged on each retry)
OP_QUERY = "query"  # requester asks the peer for the status/result of a transaction
OP_REPLY = "reply"  # marks a message as the final result of a transaction
# Reliable Message headers
HEADER_OP = "rm.op"  # one of the OP_* values above
HEADER_TOPIC = "rm.topic"  # application topic; used to look up the registered request handler
HEADER_TX_ID = "rm.tx_id"  # transaction id that ties request, queries and reply together
HEADER_PER_MSG_TIMEOUT = "rm.per_msg_timeout"  # per-message timeout, sent along with the request
HEADER_TX_TIMEOUT = "rm.tx_timeout"  # whole-transaction timeout, sent along with the request
HEADER_STATUS = "rm.status"  # one of the STATUS_* values below; present on status replies
# Status: values of HEADER_STATUS reported by the receiving side
STATUS_IN_PROCESS = "in_process"  # request received; handler has not produced a result yet
STATUS_IN_REPLY = "in_reply"  # result is ready and is currently being sent to the requester
STATUS_NOT_RECEIVED = "not_received"  # a query arrived for a transaction the receiver does not know
STATUS_REPLIED = "replied"  # result was already sent back successfully
STATUS_ABORTED = "aborted"  # processing exceeded the transaction timeout
# Topics for Reliable Message: aux-message topics used as the transport
TOPIC_RELIABLE_REQUEST = "RM.RELIABLE_REQUEST"  # carries both OP_REQUEST and OP_QUERY messages
TOPIC_RELIABLE_REPLY = "RM.RELIABLE_REPLY"  # carries results pushed from the receiver to the requester
# FLContext property keys; read back when building log message prefixes
PROP_KEY_TX_ID = "RM.TX_ID"
PROP_KEY_TOPIC = "RM.TOPIC"
PROP_KEY_OP = "RM.OP"
def _extract_result(reply: dict, target: str):
    """Pick the reply of one target out of a multi-target aux reply.

    Args:
        reply: mapping of target name => reply, as returned by send_aux_request
        target: name of the target whose reply is wanted

    Returns:
        A tuple (result, return_code). If ``reply`` is not a dict, or holds no
        (truthy) entry for ``target``, a freshly made COMMUNICATION_ERROR reply
        and that return code are returned instead.
    """
    failure_rc = ReturnCode.COMMUNICATION_ERROR

    # a non-dict reply is treated exactly like a dict without an entry for the target
    target_reply = reply.get(target) if isinstance(reply, dict) else None
    if target_reply:
        return target_reply, target_reply.get_return_code()

    return make_reply(failure_rc), failure_rc
def _status_reply(status: str):
    """Build an OK reply that reports a transaction status.

    Args:
        status: one of the STATUS_* values; placed in the HEADER_STATUS header

    Returns:
        the reply made by make_reply with return code OK
    """
    status_headers = {HEADER_STATUS: status}
    return make_reply(rc=ReturnCode.OK, headers=status_headers)
def _error_reply(rc: str, error: str):
    """Build a reply with the given return code and an error description.

    Args:
        rc: return code of the reply
        error: error text; placed in the ReservedHeaderKey.ERROR header

    Returns:
        the reply made by make_reply
    """
    error_headers = {ReservedHeaderKey.ERROR: error}
    return make_reply(rc, headers=error_headers)
class _RequestReceiver:
    """This class handles reliable message request on the receiving end"""

    def __init__(self, topic, request_handler_f, executor, per_msg_timeout, tx_timeout):
        """The constructor

        Args:
            topic: The topic of the reliable message
            request_handler_f: The callback function to handle the request in the form of
                request_handler_f(topic: str, request: Shareable, fl_ctx:FLContext)
            executor: A ThreadPoolExecutor
            per_msg_timeout: timeout used when sending the result back to the requester
            tx_timeout: max time allowed for the whole transaction
        """
        self.topic = topic
        self.request_handler_f = request_handler_f
        self.executor = executor
        self.per_msg_timeout = per_msg_timeout
        self.tx_timeout = tx_timeout
        self.rcv_time = None  # time the first request of this transaction was received
        self.result = None  # set once the request handler has finished
        self.source = None  # identity name of the requester
        self.tx_id = None
        self.reply_time = None  # time the result was successfully sent back
        # process() reads this flag as soon as self.result is set, which can happen
        # before _try_reply() runs - so it must exist from the start.
        self.replying = False

    def process(self, request: Shareable, fl_ctx: FLContext) -> Shareable:
        """Handle a request or a query message of this transaction.

        Args:
            request: the received message; its HEADER_OP header selects the behavior
            fl_ctx: FL context; its peer context identifies the requester

        Returns:
            For OP_REQUEST: an "in_process" status reply (the first request also starts
            processing on the executor), or the result itself if it is already available.
            For OP_QUERY: the result, or a status reply describing the current state.
            None for any other op.
        """
        self.tx_id = request.get_header(HEADER_TX_ID)
        op = request.get_header(HEADER_OP)
        peer_ctx = fl_ctx.get_peer_context()
        assert isinstance(peer_ctx, FLContext)
        self.source = peer_ctx.get_identity_name()
        if op == OP_REQUEST:
            # it is possible that a new request for the same tx is received while we are processing the previous one
            if not self.rcv_time:
                self.rcv_time = time.time()
                self.per_msg_timeout = request.get_header(HEADER_PER_MSG_TIMEOUT)
                self.tx_timeout = request.get_header(HEADER_TX_TIMEOUT)

                # start processing
                ReliableMessage.info(fl_ctx, f"started processing request of topic {self.topic}")
                self.executor.submit(self._do_request, request, fl_ctx)
                return _status_reply(STATUS_IN_PROCESS)  # ack
            elif self.result:
                # we already finished processing - send the result back
                ReliableMessage.info(fl_ctx, "resend result back to requester")
                return self.result
            else:
                # we are still processing
                ReliableMessage.info(fl_ctx, "got request - the request is being processed")
                return _status_reply(STATUS_IN_PROCESS)
        elif op == OP_QUERY:
            if self.result:
                if self.reply_time:
                    # result already sent back successfully
                    ReliableMessage.info(fl_ctx, "got query: we already replied successfully")
                    return _status_reply(STATUS_REPLIED)
                elif self.replying:
                    # result is being sent
                    ReliableMessage.info(fl_ctx, "got query: reply is being sent")
                    return _status_reply(STATUS_IN_REPLY)
                else:
                    # try to send the result again
                    ReliableMessage.info(fl_ctx, "got query: sending reply again")
                    return self.result
            else:
                # still in process
                if time.time() - self.rcv_time > self.tx_timeout:
                    # the process is taking too much time
                    ReliableMessage.error(fl_ctx, f"aborting processing since exceeded max tx time {self.tx_timeout}")
                    return _status_reply(STATUS_ABORTED)
                else:
                    ReliableMessage.info(fl_ctx, "got query: request is in-process")
                    return _status_reply(STATUS_IN_PROCESS)

    def _try_reply(self, fl_ctx: FLContext):
        """Try once to push the result to the requester.

        On success ``reply_time`` is recorded. On failure nothing is retried here:
        the requester is expected to query for the result.
        """
        engine = fl_ctx.get_engine()
        self.replying = True
        start_time = time.time()
        ReliableMessage.info(fl_ctx, f"try to send reply back to {self.source}: {self.per_msg_timeout=}")
        try:
            ack = engine.send_aux_request(
                targets=[self.source],
                topic=TOPIC_RELIABLE_REPLY,
                request=self.result,
                timeout=self.per_msg_timeout,
                fl_ctx=fl_ctx,
            )
        finally:
            # always clear the flag, even if sending raised; otherwise every later query
            # would be answered with "in_reply" instead of the result
            self.replying = False
        time_spent = time.time() - start_time
        _, rc = _extract_result(ack, self.source)
        if rc == ReturnCode.OK:
            # reply sent successfully!
            self.reply_time = time.time()
            ReliableMessage.info(fl_ctx, f"sent reply successfully in {time_spent} secs")
        else:
            ReliableMessage.error(
                fl_ctx, f"failed to send reply in {time_spent} secs: {rc=}; will wait for requester to query"
            )

    def _do_request(self, request: Shareable, fl_ctx: FLContext):
        """Run the request handler (on an executor thread), record its result and try to send it back."""
        start_time = time.time()
        ReliableMessage.info(fl_ctx, "invoking request handler")
        try:
            result = self.request_handler_f(self.topic, request, fl_ctx)
        except Exception as e:
            ReliableMessage.error(fl_ctx, f"exception processing request: {secure_format_traceback()}")
            result = _error_reply(ReturnCode.EXECUTION_EXCEPTION, secure_format_exception(e))

        if not isinstance(result, Shareable):
            # this runs on an executor thread: failing on result.set_header() below would be
            # swallowed silently and the requester would only find out at tx timeout
            ReliableMessage.error(
                fl_ctx, f"bad result from request handler: expect Shareable but got {type(result)}"
            )
            result = _error_reply(
                ReturnCode.EXECUTION_EXCEPTION, f"request handler returned {type(result)} instead of Shareable"
            )

        # send back
        result.set_header(HEADER_TX_ID, self.tx_id)
        result.set_header(HEADER_OP, OP_REPLY)
        result.set_header(HEADER_TOPIC, self.topic)
        self.result = result
        ReliableMessage.info(fl_ctx, f"finished request handler in {time.time()-start_time} secs")
        self._try_reply(fl_ctx)
class _ReplyReceiver:
    """Requester-side holder of the pending result of one reliable-message transaction."""

    def __init__(self, tx_id: str, per_msg_timeout: float, tx_timeout: float):
        """The constructor

        Args:
            tx_id: id of the transaction
            per_msg_timeout: timeout of each message sent for this transaction
            tx_timeout: timeout of the whole transaction
        """
        # the transaction clock starts when this object is created
        self.tx_start_time = time.time()
        self.tx_id = tx_id
        self.per_msg_timeout = per_msg_timeout
        self.tx_timeout = tx_timeout
        # result is published first, then result_ready is set to wake up the waiter
        self.result = None
        self.result_ready = threading.Event()

    def process(self, reply: Shareable) -> Shareable:
        """Store the received reply, signal that the result is ready, and return an OK ack."""
        self.result = reply
        self.result_ready.set()
        ok_ack = make_reply(ReturnCode.OK)
        return ok_ack
class ReliableMessage:
    """Reliable request/reply messaging built on top of aux messages.

    All state is kept at class level and all methods are classmethods.
    A request is re-sent until the peer acknowledges it; after that the requester waits
    for the result to be pushed back and, in the meantime, periodically queries the
    peer for it - until the transaction timeout is reached.
    """

    _topic_to_handle = {}  # topic => request handler
    _req_receivers = {}  # tx id => receiver
    _enabled = False
    _executor = None  # created by enable()
    _query_interval = 1.0
    _max_retries = 5
    _reply_receivers = {}  # tx id => receiver
    _tx_lock = threading.Lock()  # guards creation/removal of entries in _req_receivers
    _shutdown_asked = False
    _logger = logging.getLogger("ReliableMessage")

    @classmethod
    def register_request_handler(cls, topic: str, handler_f):
        """Register a handler for the reliable message with this topic

        Args:
            topic: The topic of the reliable message
            handler_f: The callback function to handle the request in the form of
                handler_f(topic, request, fl_ctx)

        Raises:
            RuntimeError: if ReliableMessage has not been enabled
            TypeError: if handler_f is not callable
        """
        if not cls._enabled:
            raise RuntimeError("ReliableMessage is not enabled. Please call ReliableMessage.enable() to enable it")
        if not callable(handler_f):
            raise TypeError(f"handler_f must be callable but {type(handler_f)}")
        cls._topic_to_handle[topic] = handler_f

    @classmethod
    def _get_or_create_receiver(cls, topic: str, request: Shareable, handler_f) -> _RequestReceiver:
        """Return the request receiver of the request's transaction, creating it on first use.

        Raises:
            RuntimeError: if the request has no tx id, or - when a new receiver must be
                created - no per-message timeout or no transaction timeout
        """
        tx_id = request.get_header(HEADER_TX_ID)
        if not tx_id:
            raise RuntimeError("missing tx_id in request")

        with cls._tx_lock:
            receiver = cls._req_receivers.get(tx_id)
            if not receiver:
                per_msg_timeout = request.get_header(HEADER_PER_MSG_TIMEOUT)
                if not per_msg_timeout:
                    raise RuntimeError("missing per_msg_timeout in request")
                tx_timeout = request.get_header(HEADER_TX_TIMEOUT)
                if not tx_timeout:
                    raise RuntimeError("missing tx_timeout in request")
                receiver = _RequestReceiver(topic, handler_f, cls._executor, per_msg_timeout, tx_timeout)
                cls._req_receivers[tx_id] = receiver
            return receiver

    @classmethod
    def _receive_request(cls, topic: str, request: Shareable, fl_ctx: FLContext):
        """Aux message handler of TOPIC_RELIABLE_REQUEST: dispatch a request or a query to its receiver."""
        tx_id = request.get_header(HEADER_TX_ID)
        op = request.get_header(HEADER_OP)
        rm_topic = request.get_header(HEADER_TOPIC)
        # these props are only used to decorate log messages
        fl_ctx.set_prop(key=PROP_KEY_TX_ID, value=tx_id, sticky=False, private=True)
        fl_ctx.set_prop(key=PROP_KEY_OP, value=op, sticky=False, private=True)
        fl_ctx.set_prop(key=PROP_KEY_TOPIC, value=rm_topic, sticky=False, private=True)
        cls.debug(fl_ctx, f"received aux msg ({topic=}) for RM request")

        if op == OP_REQUEST:
            handler_f = cls._topic_to_handle.get(rm_topic)
            if not handler_f:
                # no handler registered for this topic!
                cls.error(fl_ctx, f"no handler registered for request {rm_topic=}")
                return make_reply(ReturnCode.TOPIC_UNKNOWN)
            receiver = cls._get_or_create_receiver(rm_topic, request, handler_f)
            cls.info(fl_ctx, f"received request {rm_topic=}")
            return receiver.process(request, fl_ctx)
        elif op == OP_QUERY:
            receiver = cls._req_receivers.get(tx_id)
            if not receiver:
                cls.error(fl_ctx, f"received query but the request ({rm_topic=}) is not received!")
                return _status_reply(STATUS_NOT_RECEIVED)  # meaning the request wasn't received
            else:
                return receiver.process(request, fl_ctx)
        else:
            cls.error(fl_ctx, f"received invalid op {op} for the request ({rm_topic=})")
            return make_reply(rc=ReturnCode.BAD_REQUEST_DATA)

    @classmethod
    def _receive_reply(cls, topic: str, request: Shareable, fl_ctx: FLContext):
        """Aux message handler of TOPIC_RELIABLE_REPLY: hand a pushed result to the waiting requester.

        An OK reply is returned even when nobody is waiting for the result anymore.
        """
        tx_id = request.get_header(HEADER_TX_ID)
        fl_ctx.set_prop(key=PROP_KEY_TX_ID, value=tx_id, private=True, sticky=False)
        cls.debug(fl_ctx, f"received aux msg ({topic=}) for RM reply")
        receiver = cls._reply_receivers.get(tx_id)
        if not receiver:
            cls.error(fl_ctx, "received reply but we are no longer waiting for it")
        else:
            assert isinstance(receiver, _ReplyReceiver)
            cls.info(fl_ctx, f"received reply in {time.time()-receiver.tx_start_time} secs - set waiter")
            receiver.process(request)
        return make_reply(ReturnCode.OK)

    @classmethod
    def enable(cls, fl_ctx: FLContext):
        """Enable ReliableMessage. This method can be called multiple times, but only the 1st call has effect.

        Reads the max number of request workers and the query interval from the application
        config, creates the thread pool, registers the aux message handlers for reliable
        requests and replies, and starts the daemon thread that removes expired request receivers.

        Args:
            fl_ctx: FL Context
        """
        if cls._enabled:
            return

        cls._enabled = True
        max_request_workers = ConfigService.get_int_var(
            name=ConfigVarName.RM_MAX_REQUEST_WORKERS, conf=SystemConfigs.APPLICATION_CONF, default=20
        )
        query_interval = ConfigService.get_float_var(
            name=ConfigVarName.RM_QUERY_INTERVAL, conf=SystemConfigs.APPLICATION_CONF, default=2.0
        )

        cls._query_interval = query_interval
        cls._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_request_workers)
        engine = fl_ctx.get_engine()
        engine.register_aux_message_handler(
            topic=TOPIC_RELIABLE_REQUEST,
            message_handle_func=cls._receive_request,
        )
        engine.register_aux_message_handler(
            topic=TOPIC_RELIABLE_REPLY,
            message_handle_func=cls._receive_reply,
        )
        t = threading.Thread(target=cls._monitor_req_receivers, daemon=True)
        t.start()
        cls._logger.info(f"enabled reliable message: {max_request_workers=} {query_interval=}")

    @classmethod
    def _monitor_req_receivers(cls):
        """Body of the monitor thread: every 2 seconds drop request receivers that have expired.

        A receiver is expired when more than 4 times its tx_timeout has passed since it
        received its request. Runs until shutdown is asked.
        """
        while not cls._shutdown_asked:
            # detect and remove in one critical section
            with cls._tx_lock:
                now = time.time()
                expired_receivers = []
                for tx_id, receiver in cls._req_receivers.items():
                    assert isinstance(receiver, _RequestReceiver)
                    if receiver.rcv_time and now - receiver.rcv_time > 4 * receiver.tx_timeout:
                        cls._logger.info(f"detected expired request receiver {tx_id}")
                        expired_receivers.append(tx_id)

                # removal is done after the iteration: the dict must not change size while iterated
                for tx_id in expired_receivers:
                    cls._req_receivers.pop(tx_id, None)

            time.sleep(2.0)
        cls._logger.info("shutdown reliable message monitor")

    @classmethod
    def shutdown(cls):
        """Shutdown ReliableMessage.

        Only the 1st call has effect. It is safe to call this method even if
        ReliableMessage was never enabled.
        """
        if not cls._shutdown_asked:
            cls._shutdown_asked = True
            # the executor only exists after enable() has been called
            if cls._executor is not None:
                cls._executor.shutdown(wait=False)
            cls._logger.info("ReliableMessage is shutdown")

    @classmethod
    def _log_msg(cls, fl_ctx: FLContext, msg: str):
        """Prefix the message with the tx id, op and topic found in fl_ctx (if any) and format it for logging."""
        props = []
        tx_id = fl_ctx.get_prop(PROP_KEY_TX_ID)
        if tx_id:
            props.append(f"rm_tx={tx_id}")

        op = fl_ctx.get_prop(PROP_KEY_OP)
        if op:
            props.append(f"rm_op={op}")

        topic = fl_ctx.get_prop(PROP_KEY_TOPIC)
        if topic:
            props.append(f"rm_topic={topic}")

        if props:
            rm_ctx = " ".join(props)
            msg = f"[{rm_ctx}] {msg}"
        return generate_log_message(fl_ctx, msg)

    @classmethod
    def info(cls, fl_ctx: FLContext, msg: str):
        """Log the message at INFO level, prefixed with the reliable message context."""
        cls._logger.info(cls._log_msg(fl_ctx, msg))

    @classmethod
    def error(cls, fl_ctx: FLContext, msg: str):
        """Log the message at ERROR level, prefixed with the reliable message context."""
        cls._logger.error(cls._log_msg(fl_ctx, msg))

    @classmethod
    def debug(cls, fl_ctx: FLContext, msg: str):
        """Log the message at DEBUG level, prefixed with the reliable message context."""
        cls._logger.debug(cls._log_msg(fl_ctx, msg))

    @classmethod
    def send_request(
        cls,
        target: str,
        topic: str,
        request: Shareable,
        per_msg_timeout: float,
        tx_timeout: float,
        abort_signal: Signal,
        fl_ctx: FLContext,
    ) -> Shareable:
        """Send a reliable request.

        Args:
            target: the target cell of this request
            topic: topic of the request;
            request: the request to be sent
            per_msg_timeout: timeout when sending a message
            tx_timeout: the timeout of the whole transaction
            abort_signal: abort signal
            fl_ctx: the FL context

        Returns: reply from the peer.

        If tx_timeout is not specified or <= per_msg_timeout, the request will be sent only once without retrying.
        Exceptions raised while sending the reliable message are logged and turned into an error reply.
        """
        check_positive_number("per_msg_timeout", per_msg_timeout)
        if tx_timeout:
            check_positive_number("tx_timeout", tx_timeout)

        if not tx_timeout or tx_timeout <= per_msg_timeout:
            # simple aux message
            cls.info(fl_ctx, f"send request with simple Aux Msg: {per_msg_timeout=} {tx_timeout=}")
            engine = fl_ctx.get_engine()
            reply = engine.send_aux_request(
                targets=[target],
                topic=topic,
                request=request,
                timeout=per_msg_timeout,
                fl_ctx=fl_ctx,
            )
            result, _ = _extract_result(reply, target)
            return result

        tx_id = str(uuid.uuid4())
        fl_ctx.set_prop(key=PROP_KEY_TX_ID, value=tx_id, private=True, sticky=False)
        cls.info(fl_ctx, f"send request with Reliable Msg {per_msg_timeout=} {tx_timeout=}")
        receiver = _ReplyReceiver(tx_id, per_msg_timeout, tx_timeout)
        # register before sending, so a result pushed back quickly finds its receiver
        cls._reply_receivers[tx_id] = receiver
        request.set_header(HEADER_TX_ID, tx_id)
        request.set_header(HEADER_OP, OP_REQUEST)
        request.set_header(HEADER_TOPIC, topic)
        request.set_header(HEADER_PER_MSG_TIMEOUT, per_msg_timeout)
        request.set_header(HEADER_TX_TIMEOUT, tx_timeout)
        try:
            result = cls._send_request(target, request, abort_signal, fl_ctx, receiver)
        except Exception as e:
            cls.error(fl_ctx, f"exception sending reliable message: {secure_format_traceback()}")
            result = _error_reply(ReturnCode.ERROR, secure_format_exception(e))
        finally:
            # never leave the receiver behind, whatever the way we get out of here
            cls._reply_receivers.pop(tx_id, None)
        return result

    @classmethod
    def _send_request(
        cls,
        target: str,
        request: Shareable,
        abort_signal: Signal,
        fl_ctx: FLContext,
        receiver: _ReplyReceiver,
    ) -> Shareable:
        """Request phase: send the request until the peer acknowledges it, then enter the query phase.

        Returns:
            the result, or a TASK_ABORTED reply if the abort signal is triggered, or a
            COMMUNICATION_ERROR reply if the transaction timeout is (or is about to be) exceeded
        """
        engine = fl_ctx.get_engine()

        # keep sending the request until a positive ack or result is received
        tx_timeout = receiver.tx_timeout
        per_msg_timeout = receiver.per_msg_timeout
        num_tries = 0
        while True:
            if abort_signal and abort_signal.triggered:
                cls.info(fl_ctx, "send_request abort triggered")
                return make_reply(ReturnCode.TASK_ABORTED)

            if time.time() - receiver.tx_start_time >= tx_timeout:
                cls.error(fl_ctx, f"aborting send_request since exceeded {tx_timeout=}")
                return make_reply(ReturnCode.COMMUNICATION_ERROR)

            if num_tries > 0:
                cls.info(fl_ctx, f"retry #{num_tries} sending request: {per_msg_timeout=}")

            ack = engine.send_aux_request(
                targets=[target],
                topic=TOPIC_RELIABLE_REQUEST,
                request=request,
                timeout=per_msg_timeout,
                fl_ctx=fl_ctx,
            )
            ack, rc = _extract_result(ack, target)
            if ack and rc != ReturnCode.COMMUNICATION_ERROR:
                # is this result?
                op = ack.get_header(HEADER_OP)
                if op == OP_REPLY:
                    # the reply is already the result - we are done!
                    # this could happen when we didn't get positive ack for our first request, and the result was
                    # already produced when we did the 2nd request (this request).
                    cls.info(fl_ctx, f"C1: received result in {time.time()-receiver.tx_start_time} seconds; {rc=}")
                    return ack

                # the ack is a status report - check status
                status = ack.get_header(HEADER_STATUS)
                if status and status != STATUS_NOT_RECEIVED:
                    # status should never be STATUS_NOT_RECEIVED, unless there is a bug in the receiving logic
                    # STATUS_NOT_RECEIVED is only possible during "query" phase.
                    cls.info(fl_ctx, f"received status ack: {rc=} {status=}")
                    break

            # no point waiting for the next try if the transaction would be timed out by then
            if time.time() + cls._query_interval - receiver.tx_start_time >= tx_timeout:
                cls.error(fl_ctx, f"aborting send_request since it will exceed {tx_timeout=}")
                return make_reply(ReturnCode.COMMUNICATION_ERROR)

            # we didn't get a positive ack - wait a short time and re-send the request.
            cls.info(fl_ctx, f"unsure the request was received ({rc=}): will retry in {cls._query_interval} secs")
            num_tries += 1
            start = time.time()
            # wait in small steps so that an abort is noticed quickly
            while time.time() - start < cls._query_interval:
                if abort_signal and abort_signal.triggered:
                    cls.info(fl_ctx, "abort send_request triggered by signal")
                    return make_reply(ReturnCode.TASK_ABORTED)
                time.sleep(0.1)

        cls.info(fl_ctx, "request was received by the peer - will query for result")
        return cls._query_result(target, abort_signal, fl_ctx, receiver)

    @classmethod
    def _query_result(
        cls,
        target: str,
        abort_signal: Signal,
        fl_ctx: FLContext,
        receiver: _ReplyReceiver,
    ) -> Shareable:
        """Query phase: wait for the result to be pushed back, querying the peer every _query_interval.

        Returns:
            the result, or a TASK_ABORTED reply if the abort signal is triggered, or an error
            reply if the transaction timeout is exceeded, the peer lost the request, or the
            peer aborted processing
        """
        tx_timeout = receiver.tx_timeout
        per_msg_timeout = receiver.per_msg_timeout

        # Querying phase - try to get result
        engine = fl_ctx.get_engine()
        query = Shareable()
        query.set_header(HEADER_TX_ID, receiver.tx_id)
        query.set_header(HEADER_OP, OP_QUERY)

        num_tries = 0
        last_query_time = 0  # makes the first query go out right away
        short_wait = 0.1
        while True:
            if time.time() - receiver.tx_start_time > tx_timeout:
                cls.error(fl_ctx, f"aborted query since exceeded {tx_timeout=}")
                return _error_reply(ReturnCode.COMMUNICATION_ERROR, f"max tx timeout ({tx_timeout}) reached")

            if receiver.result_ready.wait(short_wait):
                # we already received result sent by the target.
                # Note that we don't wait forever here - we only wait for short_wait, so we could
                # check other condition and/or send query to ask for result.
                cls.info(fl_ctx, f"C2: received result in {time.time()-receiver.tx_start_time} seconds")
                return receiver.result

            if abort_signal and abort_signal.triggered:
                cls.info(fl_ctx, "aborted query triggered by abort signal")
                return make_reply(ReturnCode.TASK_ABORTED)

            if time.time() - last_query_time < cls._query_interval:
                # don't query too quickly
                continue

            # send a query. The ack of the query could be the result itself, or a status report.
            # Note: the ack could be the result because we failed to receive the result sent by the target earlier.
            num_tries += 1
            cls.info(fl_ctx, f"query #{num_tries}: try to get result from {target}: {per_msg_timeout=}")
            ack = engine.send_aux_request(
                targets=[target],
                topic=TOPIC_RELIABLE_REQUEST,
                request=query,
                timeout=per_msg_timeout,
                fl_ctx=fl_ctx,
            )
            last_query_time = time.time()
            ack, rc = _extract_result(ack, target)
            if ack and rc != ReturnCode.COMMUNICATION_ERROR:
                op = ack.get_header(HEADER_OP)
                if op == OP_REPLY:
                    # the ack is result itself!
                    cls.info(fl_ctx, f"C3: received result in {time.time()-receiver.tx_start_time} seconds")
                    return ack

                status = ack.get_header(HEADER_STATUS)
                if status == STATUS_NOT_RECEIVED:
                    # the receiver side lost context!
                    cls.error(fl_ctx, f"peer {target} lost request!")
                    return _error_reply(ReturnCode.EXECUTION_EXCEPTION, "STATUS_NOT_RECEIVED")
                elif status == STATUS_ABORTED:
                    cls.error(fl_ctx, f"peer {target} aborted processing!")
                    return _error_reply(ReturnCode.EXECUTION_EXCEPTION, "Aborted")

                cls.info(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=} {status=} {op=}")
            else:
                cls.info(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=}")