# Source code for nvflare.private.fed.server.client_manager

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import time
import uuid
from typing import Optional

from nvflare.apis.client import Client
from nvflare.apis.fl_constant import FLContextKey, ReservedKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.fuel.f3.cellnet.defs import IdentityChallengeKey, MessageHeaderKey
from nvflare.fuel.utils.admin_name_utils import is_valid_admin_client_name
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.private.defs import CellMessageHeaderKeys, ClientRegSession, ClientType, InternalFLContextKey
from nvflare.private.fed.server.cred_keeper import CredKeeper
from nvflare.private.fed.utils.identity_utils import load_crt_bytes
from nvflare.security.logging import secure_format_exception


class ClientManager:
    def __init__(self, project_name=None, min_num_clients=2, max_num_clients=10):
        """Manages client adding and removing.

        Args:
            project_name: project name
            min_num_clients: minimum number of clients allowed.
            max_num_clients: maximum number of clients allowed.
        """
        self.project_name = project_name
        # TODO:: remove min num clients
        self.min_num_clients = min_num_clients
        self.max_num_clients = max_num_clients
        self.clients = dict()  # token => Client
        self.name_to_clients = dict()  # name => Client
        self.cred_keeper = CredKeeper()
        # held while this class mutates the client maps and num_relays
        self.lock = threading.Lock()
        # incremented on each accepted RELAY login; never decremented in this class
        self.num_relays = 0
        self.logger = get_obj_logger(self)

    def set_clients(self, clients: dict):
        """Replace the registered clients and rebuild the name => Client index.

        Args:
            clients: dict of {client_token: Client}
        """
        self.clients = clients
        self.name_to_clients = {}
        for c in clients.values():
            self.name_to_clients[c.name] = c

    def authenticate(self, request, fl_ctx: FLContext) -> Optional[Client]:
        """Log in a client from its registration request.

        REGULAR clients are recorded in both client maps; other client types are
        returned but not recorded.

        Args:
            request: the client's registration request message.
            fl_ctx: FLContext

        Returns:
            The new Client object, or None if login failed.
        """
        client_type = request.get_header(CellMessageHeaderKeys.CLIENT_TYPE)
        client = self.login_client(request, fl_ctx, client_type)
        if not client:
            return None

        # client_ip = context.peer().split(":")[1]
        client_ip = request.get_header(CellMessageHeaderKeys.CLIENT_IP)

        # new client join
        with self.lock:
            if client_type == ClientType.REGULAR:
                self.name_to_clients[client.name] = client
                self.clients.update({client.token: client})
                client_kind = "client"
            else:
                # do not update self.clients for non-regular clients
                client_kind = client_type
            # f-string rather than "+" so a missing CLIENT_IP header (presumably
            # returned as None) is logged instead of raising TypeError.
            self.logger.info(
                "Client: New {} {} joined. Sent token: {}. Total clients: {}".format(
                    client_kind, f"{client.name}@{client_ip}", client.token, len(self.clients)
                )
            )
        return client

    def remove_client(self, token):
        """Remove a registered client.

        Args:
            token: client token

        Returns:
            The removed Client object, or None if the token was not registered.
        """
        with self.lock:
            client = self.clients.pop(token, None)
            if client:
                self.name_to_clients.pop(client.name, None)
                self.logger.info(
                    "Client Name:{} \tToken: {} left.  Total clients: {}".format(client.name, token, len(self.clients))
                )
            return client

    def login_client(self, client_login, fl_ctx: FLContext, client_type):
        """Check the requested project name, then authenticate the client.

        Args:
            client_login: the client's login request message.
            fl_ctx: FLContext; UNAUTHENTICATED is set on it when the project does not match.
            client_type: type of the client

        Returns:
            The new Client object, or None if login failed.
        """
        proj_name = client_login.get_header(CellMessageHeaderKeys.PROJECT_NAME)
        if not self.is_valid_task(proj_name):
            fl_ctx.set_prop(
                FLContextKey.UNAUTHENTICATED, "Requested task does not match the current server task", sticky=False
            )
            self.logger.error(f"login_client failed: {proj_name}")
            return None

        return self.authenticated_client(client_login, fl_ctx, client_type)

    def has_relays(self):
        """Return True if at least one RELAY client has logged in."""
        return self.num_relays > 0

    def validate_client(self, request, fl_ctx: FLContext, allow_new=False):
        """Validate the client state message.

        Args:
            request: A request from client.
            fl_ctx: FLContext; UNAUTHENTICATED is set on it when validation fails.
            allow_new: whether to allow new client. Note that its task should still match server's.

        Returns:
            The registered Client object for the request's token, or None if the
            request is invalid or (with allow_new=True) the token is not registered.
        """
        # token = client_state.token
        token = request.get_header(CellMessageHeaderKeys.TOKEN)
        if not token:
            fl_ctx.set_prop(FLContextKey.UNAUTHENTICATED, "Could not read client uid from the payload", sticky=False)
            client = None
        elif not self.is_valid_task(request.get_header(CellMessageHeaderKeys.PROJECT_NAME)):
            fl_ctx.set_prop(
                FLContextKey.UNAUTHENTICATED, "Requested task does not match the current server task", sticky=False
            )
            client = None
        elif not (allow_new or self.is_from_authorized_client(token)):
            fl_ctx.set_prop(FLContextKey.UNAUTHENTICATED, "Unknown client identity", sticky=False)
            client = None
        else:
            client = self.clients.get(token)
        return client

    def _get_id_verifier(self, fl_ctx: FLContext):
        """Return the identity verifier obtained from the credential keeper."""
        return self.cred_keeper.get_id_verifier(fl_ctx)

    def authenticated_client(self, request, fl_ctx: FLContext, client_type) -> Optional[Client]:
        """Use SSL certificate for authenticate the client.

        In secure mode, the client's cert and signature from the request payload are
        checked against the nonce of the registration session in fl_ctx. Any existing
        registration with the same client name is then dropped, and a Client with a
        fresh uuid4 token is created.

        Args:
            request: client login request Message
            fl_ctx: FL_Context
            client_type: type of the client

        Returns:
            Client object, or None if authentication failed or the client limit was reached.
        """
        client_name = request.get_header(CellMessageHeaderKeys.CLIENT_NAME)
        shareable = request.payload
        if not isinstance(shareable, Shareable):
            self.logger.error(f"payload must be Shareable but got {type(shareable)}")
            return None

        secure_mode = fl_ctx.get_prop(FLContextKey.SECURE_MODE, False)
        if secure_mode:
            # verify client identity
            asserter_cert_data = shareable.get(IdentityChallengeKey.CERT)
            if not asserter_cert_data:
                self.logger.error("missing client cert in register request")
                return None

            signature = shareable.get(IdentityChallengeKey.SIGNATURE)
            if not signature:
                self.logger.error("missing signature in register request")
                return None

            # The cert bytes come from the client; reject malformed data cleanly
            # instead of letting the parse error escape this method.
            try:
                asserter_cert = load_crt_bytes(asserter_cert_data)
            except Exception as ex:
                self.logger.error(f"failed to load client cert: {secure_format_exception(ex)}")
                return None

            id_verifier = self._get_id_verifier(fl_ctx)

            reg = fl_ctx.get_prop(InternalFLContextKey.CLIENT_REG_SESSION)
            if not reg:
                self.logger.error(f"missing {InternalFLContextKey.CLIENT_REG_SESSION} in FLContext!")
                return None

            if not isinstance(reg, ClientRegSession):
                self.logger.error(f"reg should be ClientRegSession but got {type(reg)}")
                return None

            try:
                id_verifier.verify_common_name(
                    asserted_cn=client_name,
                    asserter_cert=asserter_cert,
                    signature=signature,
                    nonce=reg.nonce,
                )
            except Exception as ex:
                self.logger.error(f"failed to verify client identity: {secure_format_exception(ex)}")
                return None

            self.logger.debug(f"identity verified for client '{client_name}'")

        # NOTE(review): this eviction also runs for non-REGULAR logins, which
        # authenticate() never adds back to self.clients; confirm names cannot
        # collide across client types.
        with self.lock:
            clients_to_be_removed = [token for token, client in self.clients.items() if client.name == client_name]
            for item in clients_to_be_removed:
                client = self.clients.pop(item, None)
                if client:
                    self.name_to_clients.pop(client.name, None)
                    self.logger.info(f"Client: {client_name} already registered. Re-login the client with a new token.")

        client = Client(client_name, str(uuid.uuid4()))
        client_fqcn = request.get_header(MessageHeaderKey.ORIGIN)
        self._set_client_props(client, client_fqcn, fl_ctx)
        self.logger.debug(f"authenticated client {client_name}: {client_fqcn=}")

        if client_type == ClientType.REGULAR and len(self.clients) >= self.max_num_clients:
            # only impose the limit to REGULAR clients
            fl_ctx.set_prop(FLContextKey.UNAUTHENTICATED, "Maximum number of clients reached", sticky=False)
            self.logger.info(f"Maximum number of clients reached. Reject client: {client_name} login.")
            return None

        if client_type == ClientType.RELAY:
            # "+=" is a read-modify-write; take the lock like the other shared-state updates.
            with self.lock:
                self.num_relays += 1
        return client

    def is_from_authorized_client(self, token):
        """Check if a client is authorized.

        Args:
            token: client token

        Returns:
            True if it is a recognised client
        """
        return token in self.clients

    def is_valid_task(self, task):
        """Check whether the requested task matches the server's project_name.

        Returns:
            True if task name is the same as server's project name.
        """
        # TODO: change the name of this method
        return task == self.project_name

    def heartbeat(self, token, client_name, client_fqcn, fl_ctx: FLContext):
        """Update the heartbeat of the client.

        A known token only has its last_connect_time refreshed. An unknown token is
        re-registered under client_name, unless another token already holds that name.

        Args:
            token: client token
            client_name: client name
            client_fqcn: FQCN of the client
            fl_ctx: FLContext

        Returns:
            True if a new Client was created for the token, otherwise False.
        """
        with self.lock:
            client = self.clients.get(token)
            if client:
                client.last_connect_time = time.time()
                self.logger.debug(f"Receive heartbeat from Client:{token}")
                return False

            for _token, _client in self.clients.items():
                if _client.name == client_name:
                    fl_ctx.set_prop(
                        FLContextKey.COMMUNICATION_ERROR,
                        "Client ID already registered as a client: {}".format(client_name),
                        sticky=False,
                    )
                    self.logger.info(
                        f"Failed to re-activate the client:{client_name} with token: {token}. "
                        f"Client already exist with token: {_token}."
                    )
                    return False

            client = Client(client_name, token)
            self._set_client_props(client, client_fqcn, fl_ctx)
            self.clients.update({token: client})
            self.name_to_clients[client.name] = client
            self.logger.info(f"Re-activate the client: {client_name} at {client_fqcn} with token: {token}")
            return True

    @staticmethod
    def _set_client_props(client: Client, fqcn: str, fl_ctx: FLContext):
        """Set fqcn and last_connect_time on client, plus fqsn/is_leaf from the peer context if any."""
        client.set_fqcn(fqcn)
        client.last_connect_time = time.time()
        peer_ctx = fl_ctx.get_peer_context()
        if peer_ctx:
            client.set_fqsn(peer_ctx.get_prop(ReservedKey.FQSN, "?"))
            client.set_is_leaf(peer_ctx.get_prop(ReservedKey.IS_LEAF, "?"))

    def get_clients(self):
        """Get the list of registered clients.

        Returns:
            A dict of {client_token: client} (the internal dict, not a copy)
        """
        return self.clients

    def get_min_clients(self):
        """Return the configured minimum number of clients."""
        return self.min_num_clients

    def get_max_clients(self):
        """Return the configured maximum number of clients."""
        return self.max_num_clients

    def get_all_clients_from_inputs(self, inputs):
        """Resolve each input (a client token or client name) to a Client.

        Args:
            inputs: iterable of client tokens and/or client names.

        Returns:
            A tuple (clients, invalid_inputs): the resolved Client objects, and the
            inputs that matched neither a token nor a name.
        """
        clients = []
        invalid_inputs = []
        for item in inputs:
            # try as a token first, then as a name
            client = self.clients.get(item) or self.get_client_from_name(item)
            if client:
                clients.append(client)
            else:
                invalid_inputs.append(item)
        return clients, invalid_inputs

    def get_client_from_name(self, client_name):
        """Look up a client by name, creating a transient Client for valid admin names.

        Args:
            client_name: name of the client

        Returns:
            The Client object, or None if the name is unknown and not a valid admin client name.
        """
        result = self.name_to_clients.get(client_name)
        if not result:
            # Check whether this is a valid admin client.
            # Note that since admin clients are not kept in name_to_clients, we assume that the admin client
            # is valid and dynamically create the Client object as the result.
            if is_valid_admin_client_name(client_name):
                result = Client(client_name, None)
                result.set_fqcn(client_name)
            else:
                self.logger.debug(
                    f"no client for {client_name}: I have {self.name_to_clients.keys()} {self.clients.keys()}"
                )
        return result