# Source code for nvflare.private.fed.client.communicator

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from typing import List, Optional

from nvflare.apis.event_type import EventType
from nvflare.apis.filter import Filter
from nvflare.apis.fl_constant import FLContextKey, FLMetaKey, ReservedKey
from nvflare.apis.fl_constant import ReturnCode as ShareableRC
from nvflare.apis.fl_constant import SecureTrainConst, ServerCommandKey, ServerCommandNames
from nvflare.apis.fl_context import FLContext
from nvflare.apis.fl_exception import FLCommunicationError
from nvflare.apis.shareable import Shareable, make_copy
from nvflare.apis.signal import Signal
from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx
from nvflare.fuel.data_event.utils import get_scope_property, set_scope_property
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.f3.cellnet.utils import format_size
from nvflare.fuel.f3.message import Message as CellMessage
from nvflare.fuel.sec.authn import set_add_auth_headers_filters
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.private.defs import (
    CellChannel,
    CellChannelTopic,
    CellMessageHeaderKeys,
    ClientType,
    SpecialTaskName,
    new_cell_message,
)
from nvflare.private.fed.authenticator import Authenticator
from nvflare.private.fed.client.client_engine_internal_spec import ClientEngineInternalSpec
from nvflare.security.logging import secure_format_exception

from .utils import determine_parent_fqcn


class Communicator:
    """Client-side messaging endpoint between this client and its parent (the FL server or a relay).

    It registers/authenticates the client, pulls tasks, submits results, sends heartbeats and the
    final quit message. Once a cell is attached (``set_cell``), it also serves GET_TASK and
    SUBMIT_UPDATE requests coming from child cells by handing out the task this client is
    currently holding.
    """

    def __init__(
        self,
        client_config=None,
        secure_train=False,
        client_state_processors: Optional[List[Filter]] = None,
        compression=None,
        cell: Cell = None,
        client_register_interval=2,
        timeout=5.0,
        maint_msg_timeout=5.0,
    ):
        """To init the Communicator.

        Args:
            client_config: client configuration data; it is indexed with "client_name" here,
                so it must be a mapping that contains that key
            secure_train: True/False to indicate if secure train
            client_state_processors: Client state processor filters
            compression: communicate compression algorithm
            cell: the Cell used for messaging; may be None and attached later via ``set_cell``
            client_register_interval: retry interval passed to the Authenticator during registration
            timeout: default timeout for get_task / submit_update requests
            maint_msg_timeout: timeout for maintenance messages (registration, heartbeat, quit)
        """
        self.cell = cell
        self.client_config = client_config
        self.secure_train = secure_train

        self.verbose = False
        self.heartbeat_done = False
        self.client_state_processors = client_state_processors
        self.compression = compression
        self.client_register_interval = client_register_interval
        self.timeout = timeout
        self.maint_msg_timeout = maint_msg_timeout

        # token and token_signature are issued by the Server after the client is authenticated
        # they are added to every message going to the server as proof of authentication
        self.token = None
        self.token_signature = None
        self.ssid = None
        self.client_name = None
        self.token_verifier = None
        self.abort_signal = Signal()

        self.engine = None
        self.last_task_id = None  # ID of the last task received
        self.pending_task = None  # the task currently being processed
        self.logger = get_obj_logger(self)
        self._state_lock = threading.Lock()

        # Peer context attached to tasks handed to child cells. It is replaced by one derived
        # from the job's FLContext when pull_task first sees an engine.
        tmp_ctx = FLContext()
        tmp_ctx.set_prop(
            key=ReservedKey.IDENTITY_NAME,
            value=client_config["client_name"],
            private=False,
            sticky=True,
        )
        self._peer_ctx = tmp_ctx

    # To call set_add_auth_headers_filters, both cell and token must be available.
    # set_cell is called when the cell becomes available; set_auth is called when the token becomes available.
    # In CP, set_cell happens before set_auth, hence we call set_add_auth_headers_filters in set_auth for CP.
    # In CJ, set_auth happens before set_cell, hence we call set_add_auth_headers_filters in set_cell for CJ.

    def set_auth(self, client_name, token, token_signature, ssid):
        """Record the credentials issued by the server and publish them.

        Args:
            client_name: name of this client
            token: auth token issued by the server
            token_signature: server signature of the token
            ssid: service session ID
        """
        self.ssid = ssid
        self.token_signature = token_signature
        self.token = token
        self.client_name = client_name

        if self.cell:
            # for CP
            set_add_auth_headers_filters(self.cell, client_name, token, token_signature, ssid)

        # put auth properties in data bus so that they can be used elsewhere
        set_scope_property(scope_name=client_name, key=FLMetaKey.AUTH_TOKEN, value=token)
        set_scope_property(scope_name=client_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, value=token_signature)

    def set_cell(self, cell):
        """Attach the messaging cell and register the handlers for requests from child cells.

        Args:
            cell: the Cell to use for all messaging
        """
        self.cell = cell

        if self.token:
            # for CJ
            set_add_auth_headers_filters(self.cell, self.client_name, self.token, self.token_signature, self.ssid)

        # set CB to receive task messages from children
        cell.register_request_cb(
            channel=CellChannel.SERVER_COMMAND,
            topic=ServerCommandNames.GET_TASK,
            cb=self._process_get_task,
        )
        cell.register_request_cb(
            channel=CellChannel.SERVER_COMMAND,
            topic=ServerCommandNames.SUBMIT_UPDATE,
            cb=self._process_submit_result,
        )

    @staticmethod
    def _make_try_again():
        """Build a placeholder task (empty task ID, TRY_AGAIN name) telling the requester to ask again."""
        shareable = Shareable()
        shareable.set_header(key=FLContextKey.TASK_ID, value="")
        shareable.set_header(key=ServerCommandKey.TASK_NAME, value=SpecialTaskName.TRY_AGAIN)
        return shareable

    def _process_get_task(self, request: CellMessage):
        """Handle a GET_TASK request from a child cell.

        The pending task is handed out only if it is marked ready and differs from the last
        task the requester reported. Otherwise the reply is a TRY_AGAIN task.

        Args:
            request: cell message whose payload is expected to be the child's request Shareable

        Returns:
            A cell message carrying a RETURN_CODE header and a task Shareable.
        """
        req = request.payload
        origin = request.get_header(MessageHeaderKey.ORIGIN)
        if not isinstance(req, Shareable):
            # Reply right away: the code below calls Shareable methods on req.
            self.logger.error(f"Bad get_task request from {origin}")
            return new_cell_message({MessageHeaderKey.RETURN_CODE: ReturnCode.INVALID_REQUEST}, Shareable())

        # note: the self.pending_task is unset by "submit_update", which could happen at any time.
        # we first assign self.pending_task to a different var (pending_task) and use this var in our processing.
        pending_task = self.pending_task
        pending_task_id = None

        if not self.engine or not pending_task:
            task = self._make_try_again()
        else:
            assert isinstance(pending_task, Shareable)
            last_task_id = req.get_header(ServerCommandKey.LAST_TASK_ID)
            task_id = pending_task.get_header(FLContextKey.TASK_ID)
            if last_task_id == task_id:
                # the requester already has this task
                self.logger.debug(f"same task request from {origin=}: {last_task_id=} - ask it to try again")
                task = self._make_try_again()
            elif not pending_task.get_header(ReservedKey.TASK_IS_READY):
                self.logger.debug(f"task {task_id} not ready - ask it to try again")
                task = self._make_try_again()
            else:
                # we'll send the pending task to the child.
                # make a copy of the task - only headers are copied!
                task = make_copy(pending_task, exclude_headers=[ReservedKey.TASK_IS_READY])
                pending_task_id = task_id

        if self.engine and pending_task_id:
            # fire event to notify others that the pending task is sent to a child client
            with self.engine.new_context() as fl_ctx:
                requesting_client_ctx = req.get_peer_context()
                fl_ctx.set_peer_context(requesting_client_ctx)
                fl_ctx.set_prop(FLContextKey.TASK_ID, pending_task_id, private=True, sticky=False)
                self.engine.fire_event(EventType.TASK_ASSIGNMENT_SENT, fl_ctx)
                is_processed = fl_ctx.get_prop(FLContextKey.EVENT_PROCESSED)
                if not is_processed:
                    # no one listened or processed this event
                    self.logger.warning(
                        f"event {EventType.TASK_ASSIGNMENT_SENT} for task {pending_task_id} is not processed"
                    )

        task.set_peer_context(self._peer_ctx)
        return new_cell_message({MessageHeaderKey.RETURN_CODE: ReturnCode.OK}, task)

    def _process_submit_result(self, request: CellMessage):
        """Handle a SUBMIT_UPDATE request from a child cell by firing TASK_RESULT_RECEIVED.

        Args:
            request: cell message whose payload is expected to be the child's result Shareable

        Returns:
            A cell message with an empty Shareable payload. Its RETURN_CODE is OK on success and
            INVALID_REQUEST for a malformed payload. It has no RETURN_CODE when there is no engine.
        """
        if not self.engine:
            # this could happen only when we crashed after task was pulled and restarted
            # since we don't have CJ restart capability this is impossible currently.
            self.logger.error("received submit_result while no engine")
            return new_cell_message({}, Shareable())

        result = request.payload
        if not isinstance(result, Shareable):
            # explicit check instead of assert: asserts vanish under -O and would otherwise crash the callback
            origin = request.get_header(MessageHeaderKey.ORIGIN)
            self.logger.error(f"Bad submit_result request from {origin}: expect Shareable but got {type(result)}")
            return new_cell_message({MessageHeaderKey.RETURN_CODE: ReturnCode.INVALID_REQUEST}, Shareable())

        with self.engine.new_context() as fl_ctx:
            assert isinstance(fl_ctx, FLContext)
            peer_ctx = result.get_peer_context()
            if peer_ctx:
                fl_ctx.set_peer_context(peer_ctx)
                # we also need to set peer_props since some app code expects it.
                result.set_peer_props(peer_ctx.get_all_public_props())

            fl_ctx.set_prop(
                key=FLContextKey.TASK_RESULT,
                value=result,
                private=True,
                sticky=False,
            )
            self.engine.fire_event(EventType.TASK_RESULT_RECEIVED, fl_ctx)
            is_processed = fl_ctx.get_prop(FLContextKey.EVENT_PROCESSED)
            if not is_processed:
                # no one listened or processed this event
                task_id = result.get_header(ReservedKey.TASK_ID)
                self.logger.warning(f"event {EventType.TASK_RESULT_RECEIVED} for task {task_id} is not processed")

        return new_cell_message({MessageHeaderKey.RETURN_CODE: ReturnCode.OK}, Shareable())

    def client_registration(self, client_name, project_name, fl_ctx: FLContext):
        """Register the client with the FLARE Server.

        Note that the client no longer needs to be directly connected with the Server!

        Since the client may be connected with the Server indirectly (e.g. via bridge nodes or proxy), in the secure
        mode, the client authentication cannot be based on the connection's TLS cert. Instead, the server and the
        client will explicitly authenticate each other using their provisioned PKI credentials, as follows:

        1. Make sure that the Server is authentic. The client sends a Challenge request with a random nonce.
        The server is expected to return the following in its reply:
            - its cert and common name (Server_CN)
            - signature on the received client nonce + Server_CN
            - a random Server Nonce. This will be used for the server to validate the client's identity in the
            Registration request.

        The client then validates to make sure:
            - the Server_CN is the same as presented in the server cert
            - the Server_CN is the same as configured in the client's config (fed_client.json)
            - the signature is valid

        2. Client sends Registration request that contains:
            - client cert and common name (Client_CN)
            - signature on the received Server Nonce + Client_CN

        The Server then validates to make sure:
            - the Client_CN is the same as presented in the client cert
            - the signature is valid

        NOTE: we do not explicitly validate certs' expiration time. This is because currently the same certs are
        also used for SSL connections, which already validate expiration.

        Args:
            client_name: client name
            project_name: FL study project name
            fl_ctx: FLContext

        Returns:
            A tuple (token, signature, ssid) issued by the server.

        Raises:
            RuntimeError: if the cell is not available within about 15 seconds, or (secure mode) if the
                expected server identity or the client config cannot be determined.
        """
        start = time.time()
        while not self.cell:
            self.logger.info("Waiting for the client cell to be created.")
            if time.time() - start > 15.0:
                raise RuntimeError("Client cell could not be created. Failed to login the client.")
            time.sleep(0.5)

        shared_fl_ctx = gen_new_peer_ctx(fl_ctx)

        private_key_file = None
        root_cert_file = None
        cert_file = None

        secure_mode = fl_ctx.get_prop(FLContextKey.SECURE_MODE, False)
        expected_host = None
        if secure_mode:
            # explicitly authenticate with the Server
            server_config = fl_ctx.get_prop(FLContextKey.SERVER_CONFIG)
            if server_config:
                server0 = server_config[0]
                expected_host = server0.get("identity")

            if not expected_host:
                # the provision was done with an old version
                # to be backward compatible, we expect the host to be the server host we connected to
                # we get the host name from DataBus!
                expected_host = get_scope_property(scope_name=client_name, key=FLContextKey.SERVER_HOST_NAME)

            if not expected_host:
                raise RuntimeError("cannot determine expected_host")

            client_config = fl_ctx.get_prop(FLContextKey.CLIENT_CONFIG)
            if not client_config:
                raise RuntimeError(f"missing {FLContextKey.CLIENT_CONFIG} in FL Context")
            private_key_file = client_config.get(SecureTrainConst.PRIVATE_KEY)
            cert_file = client_config.get(SecureTrainConst.SSL_CERT)
            root_cert_file = client_config.get(SecureTrainConst.SSL_ROOT_CERT)

        authenticator = Authenticator(
            cell=self.cell,
            project_name=project_name,
            client_name=client_name,
            client_type=ClientType.REGULAR,
            expected_sp_identity=expected_host,
            secure_mode=secure_mode,
            root_cert_file=root_cert_file,
            private_key_file=private_key_file,
            cert_file=cert_file,
            msg_timeout=self.maint_msg_timeout,
            retry_interval=self.client_register_interval,
        )

        token, signature, ssid, token_verifier = authenticator.authenticate(shared_fl_ctx, self.abort_signal)
        self.token_verifier = token_verifier
        self.set_auth(client_name, token, signature, ssid)
        return token, signature, ssid

    def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None):
        """Get a task from server.

        Args:
            project_name: FL study project name
            token: client token
            ssid: service session ID
            fl_ctx: FLContext
            timeout: how long to wait for response from server; falls back to self.timeout when falsy

        Returns:
            The reply cell message on success. Returns None on authentication failure, on any
            other non-OK return code, or when the reply payload is not a Shareable.
        """
        if not self.engine:
            self.engine = fl_ctx.get_engine()
            self._peer_ctx = gen_new_peer_ctx(fl_ctx)

        start_time = time.time()
        shareable = Shareable()
        shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
        shareable.set_peer_context(shared_fl_ctx)

        if self.last_task_id:
            # lets the parent avoid handing us the same task again
            shareable.set_header(ServerCommandKey.LAST_TASK_ID, self.last_task_id)

        task_message = new_cell_message(
            {
                CellMessageHeaderKeys.PROJECT_NAME: project_name,
            },
            shareable,
        )
        job_id = fl_ctx.get_job_id()

        if not timeout:
            timeout = self.timeout

        parent_fqcn = determine_parent_fqcn(self.client_config, fl_ctx)
        self.logger.debug(f"pulling task from parent FQCN: {parent_fqcn}")
        fqcn = FQCN.join([parent_fqcn, job_id])
        task = self.cell.send_request(
            target=fqcn,
            channel=CellChannel.SERVER_COMMAND,
            topic=ServerCommandNames.GET_TASK,
            request=task_message,
            timeout=timeout,
            optional=True,
            abort_signal=fl_ctx.get_run_abort_signal(),
        )
        end_time = time.time()
        return_code = task.get_header(MessageHeaderKey.RETURN_CODE)

        if return_code == ReturnCode.OK:
            size = task.get_header(MessageHeaderKey.PAYLOAD_LEN)
            task_data = task.payload
            if not isinstance(task_data, Shareable):
                # the code below would fail on task_data.get_header(); treat as a failed pull
                self.logger.error(f"bad task from {parent_fqcn}: expect Shareable but got {type(task_data)}")
                return None

            task_name = task_data.get_header(ServerCommandKey.TASK_NAME)
            self.logger.debug(f"received task from parent {parent_fqcn}: {task_name=}")
            fl_ctx.set_prop(FLContextKey.SSID, ssid, sticky=False)
            if task_name not in [SpecialTaskName.END_RUN, SpecialTaskName.TRY_AGAIN]:
                self.logger.info(
                    f"Received from {parent_fqcn}. getTask: {task_name} size: {format_size(size)} "
                    f"({size} Bytes) time: {end_time - start_time:.6f} seconds"
                )
                # only real tasks update the tracking state, so a TRY_AGAIN reply does not clear
                # last_task_id (which would let the parent re-send the task we already ran)
                self.last_task_id = task_data.get_header(FLContextKey.TASK_ID)
                self.pending_task = task_data
        elif return_code == ReturnCode.AUTHENTICATION_ERROR:
            self.logger.warning("get_task request authentication failed.")
            return None
        else:
            task = None
            self.logger.warning(f"Failed to get_task from {parent_fqcn}. Will try it again.")

        return task

    def submit_update(
        self, project_name, token, ssid, fl_ctx: FLContext, client_name, shareable, execute_task_name, timeout=None
    ):
        """Submit the task execution result back to the server.

        Args:
            project_name: server project name
            token: client token
            ssid: service session ID
            fl_ctx: fl_ctx
            client_name: client name
            shareable: execution task result shareable
            execute_task_name: execution task name
            timeout: how long to wait for response from server; falls back to self.timeout when falsy

        Returns:
            ReturnCode: the reply's RETURN_CODE header, or ReturnCode.INVALID_SESSION when the
            SSID in fl_ctx does not match ``ssid``.
        """
        # Set the pending_task to None immediately to reduce the chance that we send this task to a child
        # while we are still processing.
        self.pending_task = None

        start_time = time.time()
        shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
        shareable.set_peer_context(shared_fl_ctx)

        shareable.set_header(FLContextKey.TASK_NAME, execute_task_name)
        task_ssid = fl_ctx.get_prop(FLContextKey.SSID)
        if task_ssid != ssid:
            self.logger.warning("submit_update request failed because SSID mismatch.")
            return ReturnCode.INVALID_SESSION

        # a result for an aborted task is sent as an optional message
        rc = shareable.get_return_code()
        optional = rc == ShareableRC.TASK_ABORTED
        task_message = new_cell_message(
            {
                CellMessageHeaderKeys.PROJECT_NAME: project_name,
            },
            shareable,
        )
        job_id = fl_ctx.get_job_id()

        if not timeout:
            timeout = self.timeout

        parent_fqcn = determine_parent_fqcn(self.client_config, fl_ctx)
        self.logger.debug(f"submitting update to parent FQCN: {parent_fqcn}")
        fqcn = FQCN.join([parent_fqcn, job_id])
        result = self.cell.send_request(
            target=fqcn,
            channel=CellChannel.SERVER_COMMAND,
            topic=ServerCommandNames.SUBMIT_UPDATE,
            request=task_message,
            timeout=timeout,
            optional=optional,
            abort_signal=fl_ctx.get_run_abort_signal(),
        )
        end_time = time.time()
        return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
        size = task_message.get_header(MessageHeaderKey.PAYLOAD_LEN)
        self.logger.info(
            f"SubmitUpdate to: {parent_fqcn}. size: {format_size(size)} ({size} Bytes). time: {end_time - start_time:.6f} seconds"
        )

        return return_code

    def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext):
        """Sending the last message to the server before leaving.

        Triggers this communicator's abort signal first.

        Args:
            servers: FL servers (unused here)
            task_name: project name (sent as the PROJECT_NAME header)
            token: FL client token
            ssid: service session ID
            fl_ctx: FLContext

        Returns:
            server's reply to the last message

        Raises:
            FLCommunicationError: if sending the quit message or reading the reply fails.
        """
        self.abort_signal.trigger(True)

        shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
        shareable = Shareable()
        shareable.set_peer_context(shared_fl_ctx)
        quit_message = new_cell_message(
            {
                CellMessageHeaderKeys.PROJECT_NAME: task_name,
            },
            shareable,
        )
        try:
            result = self.cell.send_request(
                target=FQCN.ROOT_SERVER,
                channel=CellChannel.SERVER_MAIN,
                topic=CellChannelTopic.Quit,
                request=quit_message,
                timeout=self.maint_msg_timeout,
            )
            return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
            if return_code == ReturnCode.UNAUTHENTICATED:
                # never log the token itself: it is a bearer credential for this client
                self.logger.info(f"Client {self.client_name}: token has been removed from the server.")

            server_message = result.get_header(CellMessageHeaderKeys.MESSAGE)
        except Exception as ex:
            raise FLCommunicationError("error:client_quit", ex) from ex

        return server_message

    def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: ClientEngineInternalSpec, interval):
        """Send heartbeats to the root server until ``self.heartbeat_done`` becomes True.

        Outside simulate mode, jobs listed in the reply's ABORT_JOBS header are aborted locally.
        In simulate mode, the loop stops on the first non-OK reply. Errors are logged, and the
        next heartbeat is attempted after 5 seconds.

        Args:
            servers: FL servers (unused here)
            task_name: project name (sent as the PROJECT_NAME header)
            token: FL client token (unused here)
            ssid: service session ID (unused here)
            client_name: client name, used in log messages
            engine: client engine used for events, job IDs and aborting jobs
            interval: approximate time between heartbeats, waited in 2-second slices
        """
        fl_ctx = engine.new_context()
        simulate_mode = fl_ctx.get_prop(FLContextKey.SIMULATE_MODE, False)
        # Wait in 2-second slices so a stop request is noticed promptly. Use at least one slice:
        # int(interval / 2) is 0 for interval < 2, which would send heartbeats in a tight loop.
        wait_times = max(1, int(interval / 2))
        num_heartbeats_sent = 0
        heartbeats_log_interval = 10
        while not self.heartbeat_done:
            try:
                engine.fire_event(EventType.BEFORE_CLIENT_HEARTBEAT, fl_ctx)
                shareable = Shareable()
                shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
                shareable.set_peer_context(shared_fl_ctx)

                job_ids = engine.get_all_job_ids()
                heartbeat_message = new_cell_message(
                    {
                        CellMessageHeaderKeys.PROJECT_NAME: task_name,
                        CellMessageHeaderKeys.JOB_IDS: job_ids,
                    },
                    shareable,
                )

                try:
                    result = self.cell.send_request(
                        target=FQCN.ROOT_SERVER,
                        channel=CellChannel.SERVER_MAIN,
                        topic=CellChannelTopic.HEART_BEAT,
                        request=heartbeat_message,
                        timeout=self.maint_msg_timeout,
                    )
                    return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
                    if return_code == ReturnCode.UNAUTHENTICATED:
                        unauthenticated = result.get_header(MessageHeaderKey.ERROR)
                        # f-string: the ERROR header may be missing (None)
                        raise FLCommunicationError(f"error:client_quit {unauthenticated}")

                    num_heartbeats_sent += 1
                    if num_heartbeats_sent % heartbeats_log_interval == 0:
                        self.logger.debug(f"Client: {client_name} has sent {num_heartbeats_sent} heartbeats.")

                    if not simulate_mode:
                        abort_jobs = result.get_header(CellMessageHeaderKeys.ABORT_JOBS, [])
                        self._clean_up_runs(engine, abort_jobs)
                    elif return_code != ReturnCode.OK:
                        # simulator: stop heartbeating once the server stops answering OK
                        break
                except Exception as ex:
                    raise FLCommunicationError("error:client_quit", ex) from ex

                engine.fire_event(EventType.AFTER_CLIENT_HEARTBEAT, fl_ctx)
                for _ in range(wait_times):
                    time.sleep(2)
                    if self.heartbeat_done:
                        break

            except Exception as e:
                self.logger.info(f"Failed to send heartbeat. Will try again. Exception: {secure_format_exception(e)}")
                time.sleep(5)

    def _clean_up_runs(self, engine, abort_runs):
        """Abort the given jobs locally, as requested by the server's heartbeat reply.

        Best effort: a failure to abort one job is logged and does not stop the others.

        Args:
            engine: client engine providing ``abort_app``
            abort_runs: job IDs to abort; may be empty or None
        """
        if not abort_runs:
            return

        aborted = []
        for job in abort_runs:
            try:
                engine.abort_app(job)
                aborted.append(job)
            except Exception as e:
                # not bare except: KeyboardInterrupt / SystemExit must still propagate
                self.logger.debug(f"Failed to clean up the runs: {job}. Exception: {secure_format_exception(e)}")

        if aborted:
            display_runs = ",".join(str(job) for job in aborted)
            self.logger.debug(f"These runs: {display_runs} are not running on the server. Aborted them.")