Source code for nvflare.private.fed.client.communicator

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import socket
import time
from typing import List, Optional

from nvflare.apis.event_type import EventType
from nvflare.apis.filter import Filter
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_constant import ReturnCode as ShareableRC
from nvflare.apis.fl_constant import ServerCommandKey, ServerCommandNames
from nvflare.apis.fl_context import FLContext
from nvflare.apis.fl_exception import FLCommunicationError
from nvflare.apis.shareable import Shareable
from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx
from nvflare.fuel.f3.cellnet.core_cell import FQCN, CoreCell
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
from nvflare.fuel.f3.cellnet.utils import format_size
from nvflare.private.defs import CellChannel, CellChannelTopic, CellMessageHeaderKeys, SpecialTaskName, new_cell_message
from nvflare.private.fed.client.client_engine_internal_spec import ClientEngineInternalSpec
from nvflare.security.logging import secure_format_exception


def _get_client_ip():
    """Return localhost IP.

    More robust than ``socket.gethostbyname(socket.gethostname())``. See
    https://stackoverflow.com/questions/166506/finding-local-ip-addresses-using-pythons-stdlib/28950776#28950776
    for more details.

    Returns:
        The host IP

    """
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect(("10.255.255.255", 1))  # doesn't even have to be reachable
        ip = s.getsockname()[0]
    except Exception:
        ip = "127.0.0.1"
    finally:
        s.close()
    return ip


class Communicator:
    """Client-side helper that sends register / task / heartbeat / quit requests to the FL server via a cell."""

    def __init__(
        self,
        ssl_args=None,
        secure_train=False,
        client_state_processors: Optional[List[Filter]] = None,
        compression=None,
        cell: Optional[CoreCell] = None,
        client_register_interval=2,
        timeout=5.0,
        maint_msg_timeout=5.0,
    ):
        """To init the Communicator.

        Args:
            ssl_args: SSL args
            secure_train: True/False to indicate if secure train
            client_state_processors: Client state processor filters
            compression: communicate compression algorithm
            cell: the cell used to send requests to the server; may be assigned after construction
            client_register_interval: seconds to sleep between registration attempts that return no token
            timeout: default timeout passed to ``send_request`` for get-task and submit-update requests
            maint_msg_timeout: timeout passed to ``send_request`` for register, heartbeat and quit requests
        """
        self.cell = cell
        self.ssl_args = ssl_args
        self.secure_train = secure_train

        self.verbose = False
        self.should_stop = False
        self.heartbeat_done = False
        self.client_state_processors = client_state_processors
        self.compression = compression
        self.client_register_interval = client_register_interval
        self.timeout = timeout
        self.maint_msg_timeout = maint_msg_timeout

        self.logger = logging.getLogger(self.__class__.__name__)

    def client_registration(self, client_name, project_name, fl_ctx: FLContext):
        """Client's metadata used to authenticate and communicate.

        Waits up to 15 seconds for ``self.cell`` to be set and up to 30 seconds
        (measured from the same start time) for the cell to be connected to the
        root server, then keeps sending the register request until the reply
        carries a token or ``self.should_stop`` is set.

        Args:
            client_name: client name
            project_name: FL study project name
            fl_ctx: FLContext

        Returns:
            A ``(token, ssid)`` tuple taken from the reply headers.

        Raises:
            RuntimeError: if the cell is not created in time.
            FLCommunicationError: if the server cannot be reached, rejects the
                client as unauthenticated, or the request fails for any other reason.
        """
        local_ip = _get_client_ip()
        shareable = Shareable()
        shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
        shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx)

        headers = {
            CellMessageHeaderKeys.CLIENT_NAME: client_name,
            CellMessageHeaderKeys.CLIENT_IP: local_ip,
            CellMessageHeaderKeys.PROJECT_NAME: project_name,
        }
        login_message = new_cell_message(headers, shareable)

        start = time.time()
        while not self.cell:
            self.logger.info("Waiting for the client cell to be created.")
            if time.time() - start > 15.0:
                raise RuntimeError("Client cell could not be created. Failed to login the client.")
            time.sleep(0.5)

        while not self.cell.is_cell_connected(FQCN.ROOT_SERVER):
            time.sleep(0.1)
            if time.time() - start > 30.0:
                raise FLCommunicationError("error:Could not connect to the server for client_registration.")

        while True:
            try:
                result = self.cell.send_request(
                    target=FQCN.ROOT_SERVER,
                    channel=CellChannel.SERVER_MAIN,
                    topic=CellChannelTopic.Register,
                    request=login_message,
                    timeout=self.maint_msg_timeout,
                )
                return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
                if return_code == ReturnCode.UNAUTHENTICATED:
                    unauthenticated = result.get_header(MessageHeaderKey.ERROR)
                    # f-string instead of "+" so a missing (None) error header
                    # does not turn into an unrelated TypeError.
                    raise FLCommunicationError(f"error:client_registration {unauthenticated}")

                token = result.get_header(CellMessageHeaderKeys.TOKEN)
                ssid = result.get_header(CellMessageHeaderKeys.SSID)
                if not token and not self.should_stop:
                    # No token yet: wait and send the register request again.
                    time.sleep(self.client_register_interval)
                else:
                    break

            except Exception as ex:
                raise FLCommunicationError("error:client_registration", ex) from ex

        return token, ssid

    def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None):
        """Get a task from server.

        Args:
            project_name: FL study project name
            token: client token
            ssid: service session ID
            fl_ctx: FLContext
            timeout: how long to wait for response from server; ``self.timeout`` is used when falsy

        Returns:
            The reply message from the server when its return code is OK, otherwise None.
        """
        start_time = time.time()
        shareable = Shareable()
        shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
        shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx)
        client_name = fl_ctx.get_identity_name()
        task_message = new_cell_message(
            {
                CellMessageHeaderKeys.TOKEN: token,
                CellMessageHeaderKeys.CLIENT_NAME: client_name,
                CellMessageHeaderKeys.SSID: ssid,
                CellMessageHeaderKeys.PROJECT_NAME: project_name,
            },
            shareable,
        )
        job_id = str(shared_fl_ctx.get_prop(FLContextKey.CURRENT_RUN))

        if not timeout:
            timeout = self.timeout

        # The request is addressed to the job-specific cell under the root server.
        fqcn = FQCN.join([FQCN.ROOT_SERVER, job_id])
        task = self.cell.send_request(
            target=fqcn,
            channel=CellChannel.SERVER_COMMAND,
            topic=ServerCommandNames.GET_TASK,
            request=task_message,
            timeout=timeout,
            optional=True,
        )
        end_time = time.time()
        return_code = task.get_header(MessageHeaderKey.RETURN_CODE)

        if return_code == ReturnCode.OK:
            size = task.get_header(MessageHeaderKey.PAYLOAD_LEN)
            task_name = task.payload.get_header(ServerCommandKey.TASK_NAME)
            # Remember the SSID the task came with; submit_update compares against it.
            fl_ctx.set_prop(FLContextKey.SSID, ssid, sticky=False)
            if task_name not in [SpecialTaskName.END_RUN, SpecialTaskName.TRY_AGAIN]:
                self.logger.info(
                    f"Received from {project_name} server. getTask: {task_name} size: {format_size(size)} "
                    f"({size} Bytes) time: {end_time - start_time:.6f} seconds"
                )
        elif return_code == ReturnCode.AUTHENTICATION_ERROR:
            self.logger.warning("get_task request authentication failed.")
            time.sleep(5.0)
            return None
        else:
            task = None
            self.logger.warning(f"Failed to get_task from {project_name} server. Will try it again.")

        return task

    def submit_update(
        self, project_name, token, ssid, fl_ctx: FLContext, client_name, shareable, execute_task_name, timeout=None
    ):
        """Submit the task execution result back to the server.

        Args:
            project_name: server project name
            token: client token
            ssid: service session ID
            fl_ctx: fl_ctx
            client_name: client name
            shareable: execution task result shareable
            execute_task_name: execution task name
            timeout: how long to wait for response from server; ``self.timeout`` is used when falsy

        Returns:
            ``ReturnCode.INVALID_SESSION`` when the SSID stored in ``fl_ctx`` differs from
            ``ssid`` (nothing is sent); otherwise the return code header of the server reply.
        """
        start_time = time.time()
        shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
        shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx)
        shareable.set_header(FLContextKey.TASK_NAME, execute_task_name)

        task_ssid = fl_ctx.get_prop(FLContextKey.SSID)
        if task_ssid != ssid:
            self.logger.warning("submit_update request failed because SSID mismatch.")
            return ReturnCode.INVALID_SESSION

        # The request is marked optional only when the task result reports TASK_ABORTED.
        rc = shareable.get_return_code()
        optional = rc == ShareableRC.TASK_ABORTED

        task_message = new_cell_message(
            {
                CellMessageHeaderKeys.TOKEN: token,
                CellMessageHeaderKeys.CLIENT_NAME: client_name,
                CellMessageHeaderKeys.SSID: ssid,
                CellMessageHeaderKeys.PROJECT_NAME: project_name,
            },
            shareable,
        )
        job_id = str(shared_fl_ctx.get_prop(FLContextKey.CURRENT_RUN))

        if not timeout:
            timeout = self.timeout

        fqcn = FQCN.join([FQCN.ROOT_SERVER, job_id])
        result = self.cell.send_request(
            target=fqcn,
            channel=CellChannel.SERVER_COMMAND,
            topic=ServerCommandNames.SUBMIT_UPDATE,
            request=task_message,
            timeout=timeout,
            optional=optional,
        )
        end_time = time.time()
        return_code = result.get_header(MessageHeaderKey.RETURN_CODE)

        size = task_message.get_header(MessageHeaderKey.PAYLOAD_LEN)
        self.logger.info(
            f" SubmitUpdate size: {format_size(size)} ({size} Bytes). time: {end_time - start_time:.6f} seconds"
        )

        return return_code

    def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext):
        """Sending the last message to the server before leaving.

        Args:
            servers: FL servers (not referenced by this method)
            task_name: sent to the server as the PROJECT_NAME header
            token: FL client token
            ssid: service session ID
            fl_ctx: FLContext

        Returns:
            server's reply to the last message (the MESSAGE header of the reply)

        Raises:
            FLCommunicationError: if sending the request or reading the reply fails.
        """
        shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
        shareable = Shareable()
        shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx)
        client_name = fl_ctx.get_identity_name()
        quit_message = new_cell_message(
            {
                CellMessageHeaderKeys.TOKEN: token,
                CellMessageHeaderKeys.CLIENT_NAME: client_name,
                CellMessageHeaderKeys.SSID: ssid,
                CellMessageHeaderKeys.PROJECT_NAME: task_name,
            },
            shareable,
        )
        try:
            result = self.cell.send_request(
                target=FQCN.ROOT_SERVER,
                channel=CellChannel.SERVER_MAIN,
                topic=CellChannelTopic.Quit,
                request=quit_message,
                timeout=self.maint_msg_timeout,
            )
            return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
            if return_code == ReturnCode.UNAUTHENTICATED:
                # NOTE(review): this writes the client token to the log in clear text -- confirm that is acceptable.
                self.logger.info(f"Client token: {token} has been removed from the server.")

            server_message = result.get_header(CellMessageHeaderKeys.MESSAGE)

        except Exception as ex:
            raise FLCommunicationError("error:client_quit", ex) from ex

        return server_message

    def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: ClientEngineInternalSpec, interval):
        """Keep sending heartbeat requests to the server until ``self.heartbeat_done`` is set.

        Each round fires BEFORE_CLIENT_HEARTBEAT, sends the heartbeat (including the
        engine's job ids), handles the reply, fires AFTER_CLIENT_HEARTBEAT and then
        sleeps in 2-second steps. Any failure in a round is logged and the loop
        retries after 5 seconds.

        Args:
            servers: FL servers (not referenced by this method)
            task_name: sent to the server as the PROJECT_NAME header
            token: FL client token
            ssid: service session ID
            client_name: client name
            engine: client engine used to create the FLContext, fire events, list job ids and abort apps
            interval: heartbeat interval; the loop sleeps ``int(interval / 2)`` times 2 seconds between rounds
        """
        fl_ctx = engine.new_context()
        simulate_mode = fl_ctx.get_prop(FLContextKey.SIMULATE_MODE, False)
        wait_times = int(interval / 2)
        num_heartbeats_sent = 0
        heartbeats_log_interval = 10
        while not self.heartbeat_done:
            try:
                engine.fire_event(EventType.BEFORE_CLIENT_HEARTBEAT, fl_ctx)
                shareable = Shareable()
                shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
                shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx)

                job_ids = engine.get_all_job_ids()
                heartbeat_message = new_cell_message(
                    {
                        CellMessageHeaderKeys.TOKEN: token,
                        CellMessageHeaderKeys.SSID: ssid,
                        CellMessageHeaderKeys.CLIENT_NAME: client_name,
                        CellMessageHeaderKeys.PROJECT_NAME: task_name,
                        CellMessageHeaderKeys.JOB_IDS: job_ids,
                    },
                    shareable,
                )

                try:
                    result = self.cell.send_request(
                        target=FQCN.ROOT_SERVER,
                        channel=CellChannel.SERVER_MAIN,
                        topic=CellChannelTopic.HEART_BEAT,
                        request=heartbeat_message,
                        timeout=self.maint_msg_timeout,
                    )
                    return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
                    if return_code == ReturnCode.UNAUTHENTICATED:
                        unauthenticated = result.get_header(MessageHeaderKey.ERROR)
                        # f-string so a missing (None) error header cannot raise TypeError.
                        raise FLCommunicationError(f"error:client_heartbeat {unauthenticated}")

                    num_heartbeats_sent += 1
                    if num_heartbeats_sent % heartbeats_log_interval == 0:
                        self.logger.debug(f"Client: {client_name} has sent {num_heartbeats_sent} heartbeats.")

                    if not simulate_mode:
                        abort_jobs = result.get_header(CellMessageHeaderKeys.ABORT_JOBS, [])
                        self._clean_up_runs(engine, abort_jobs)
                    else:
                        # In simulate mode a non-OK reply ends the heartbeat loop.
                        if return_code != ReturnCode.OK:
                            break

                except Exception as ex:
                    raise FLCommunicationError("error:client_heartbeat", ex) from ex

                engine.fire_event(EventType.AFTER_CLIENT_HEARTBEAT, fl_ctx)
                # Sleep in short steps so a stop request is noticed within about 2 seconds.
                for _ in range(wait_times):
                    time.sleep(2)
                    if self.heartbeat_done:
                        break
            except Exception as e:
                self.logger.info(f"Failed to send heartbeat. Will try again. Exception: {secure_format_exception(e)}")
                time.sleep(5)

    def _clean_up_runs(self, engine, abort_runs):
        """Abort every run listed in ``abort_runs`` through ``engine.abort_app``.

        Best effort: a failure while aborting is logged at debug level and not
        propagated. Nothing is done when ``abort_runs`` is empty or None.

        Args:
            engine: object providing ``abort_app(job)``
            abort_runs: iterable of job ids to abort, or None
        """
        if not abort_runs:
            return

        # str() so non-string job ids cannot break the log message.
        display_runs = ",".join(str(job) for job in abort_runs)
        try:
            for job in abort_runs:
                engine.abort_app(job)
            self.logger.debug(f"These runs: {display_runs} are not running on the server. Aborted them.")
        except Exception:
            # Exception (not a bare except) so SystemExit / KeyboardInterrupt still propagate.
            self.logger.debug(f"Failed to clean up the runs: {display_runs}")