# Source code for nvflare.private.fed.server.client_manager

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import threading
import time
import uuid

import grpc

from nvflare.apis.client import Client


class ClientManager:
    """Registry of the clients connected to the server, keyed by client token."""

    def __init__(self, project_name=None, min_num_clients=2, max_num_clients=10):
        """Manages client adding and removing.

        Args:
            project_name: project name; requests whose ``meta.project.name`` differs are rejected.
            min_num_clients: minimum number of clients allowed (stored and returned by ``get_min_clients``).
            max_num_clients: maximum number of clients allowed; ``authenticate`` refuses new clients beyond it.
        """
        self.project_name = project_name
        # TODO:: remove min num clients
        self.min_num_clients = min_num_clients
        self.max_num_clients = max_num_clients
        self.clients = dict()  # token => Client
        # Guards mutation of and iteration over self.clients
        # (presumably shared by concurrent gRPC handler threads).
        self.lock = threading.Lock()
        self.logger = logging.getLogger(self.__class__.__name__)

    @staticmethod
    def _peer_host(peer):
        """Extract the host part of a gRPC peer string.

        Handles "ipv4:1.2.3.4:5678", "ipv6:[::1]:5678" and port-less forms such as "unix:/path".

        Args:
            peer: the string returned by ``context.peer()``

        Returns:
            The address with the scheme prefix and the trailing port (if any) removed.
        """
        _scheme, sep, address = peer.partition(":")
        if not sep:
            return peer
        host, sep, _port = address.rpartition(":")
        return host if sep else address

    def authenticate(self, request, context):
        """Log a client in and register it.

        Args:
            request: client login request
            context: gRPC connection context

        Returns:
            The client token, or None if the login was rejected.
        """
        client = self.login_client(request, context)
        if not client:
            return None

        client_ip = self._peer_host(context.peer())
        with self.lock:
            # Check and insert under the same lock so concurrent logins cannot overshoot the limit.
            # A client re-authenticating with a token it already holds does not take a new slot.
            if client.token not in self.clients and len(self.clients) >= self.max_num_clients:
                context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, "Maximum number of clients reached")
            # new client will join the current round immediately
            self.clients[client.token] = client
            self.logger.info(
                "Client: New client {} joined. Sent token: {}. Total clients: {}".format(
                    client.name + "@" + client_ip, client.token, len(self.clients)
                )
            )
        return client.token

    def remove_client(self, token):
        """Remove a registered client.

        Args:
            token: client token

        Returns:
            The removed Client object

        Raises:
            KeyError: if no client is registered under ``token``.
        """
        with self.lock:
            client = self.clients.pop(token)
            self.logger.info(
                "Client Name:{} \tToken: {} left. Total clients: {}".format(client.name, token, len(self.clients))
            )
            return client

    def login_client(self, client_login, context):
        """Validate the login's project and resolve it to a Client.

        Args:
            client_login: client login request
            context: gRPC connection context

        Returns:
            The Client (existing or newly created, not yet registered), or None if rejected.
        """
        if not self.is_valid_task(client_login.meta.project):
            context.abort(grpc.StatusCode.INVALID_ARGUMENT, "Requested task does not match the current server task")
        return self.authenticated_client(client_login, context)

    def validate_client(self, client_state, context, allow_new=False):
        """Validate the client state message.

        Args:
            client_state: A ClientState message received by server
            context: gRPC connection context
            allow_new: whether to allow new client. Note that its task should still match server's.

        Returns:
            The registered Client for the token, or None if validation failed
            (or, with ``allow_new``, if the token is not registered yet).
        """
        token = client_state.token
        if not token:
            context.abort(grpc.StatusCode.INVALID_ARGUMENT, "Could not read client uid from the payload")
            client = None
        elif not self.is_valid_task(client_state.meta.project):
            context.abort(grpc.StatusCode.INVALID_ARGUMENT, "Requested task does not match the current server task")
            client = None
        elif not (allow_new or self.is_from_authorized_client(token)):
            context.abort(grpc.StatusCode.UNAUTHENTICATED, "Unknown client identity")
            client = None
        else:
            client = self.clients.get(token)
        return client

    def authenticated_client(self, client_login, context) -> "Client":
        """Use SSL certificate for authenticate the client.

        If the login token is already registered, the existing client is returned. Otherwise the client
        name is taken from the certificate CN (which must match ``client_login.client_name`` when both are
        present), or from the request when no CN is available, and a new Client with a fresh UUID token
        is created. The new client is NOT added to the registry here.

        Args:
            client_login: client login request
            context: gRPC connection context

        Returns:
            Client object, or None if the request was aborted.
        """
        client = self.clients.get(client_login.token)
        if client:
            return client

        cn_names = context.auth_context().get("x509_common_name")
        if cn_names:
            client_name = cn_names[0].decode("utf-8")
            if client_login.client_name and client_login.client_name != client_name:
                context.abort(grpc.StatusCode.UNAUTHENTICATED, "client ID does not match the SSL certificate CN")
                return None
        else:
            client_name = client_login.client_name

        # Scan under the lock: other methods mutate self.clients concurrently, and iterating a dict
        # while it changes size raises RuntimeError.
        with self.lock:
            name_taken = any(existing.name == client_name for existing in self.clients.values())
        if name_taken:
            context.abort(
                grpc.StatusCode.FAILED_PRECONDITION,
                "Client ID already registered as a client: {}".format(client_name),
            )
            return None

        return Client(client_name, str(uuid.uuid4()))

    def is_from_authorized_client(self, token):
        """Check if a client is authorized.

        Args:
            token: client token

        Returns:
            True if it is a recognised client
        """
        return token in self.clients

    def is_valid_task(self, task):
        """Check whether the requested task matches the server's project_name.

        Args:
            task: object with a ``name`` attribute (callers pass the request's ``meta.project``)

        Returns:
            True if task name is the same as server's project name.
        """
        return task.name == self.project_name

    def heartbeat(self, token, client_name, context):
        """Update the heartbeat of the client.

        Args:
            token: client token
            client_name: client name
            context: grpc context

        Returns:
            True if an unknown token was re-registered as a new Client; False if the client was
            already registered (or the re-registration was refused).
        """
        with self.lock:
            client = self.clients.get(token)
            if client:
                client.last_connect_time = time.time()
                self.logger.debug("Receive heartbeat from Client:%s", token)
                return False

            for existing_token, existing_client in self.clients.items():
                if existing_client.name == client_name:
                    # context.abort() raises to terminate the RPC, so log before calling it.
                    self.logger.info(
                        "Failed to re-activate dead client:{} with token: {}. Client already exist.".format(
                            client_name, existing_token
                        )
                    )
                    context.abort(
                        grpc.StatusCode.FAILED_PRECONDITION,
                        "Client ID already registered as a client: {}".format(client_name),
                    )
                    return False

            client = Client(client_name, token)
            client.last_connect_time = time.time()
            self.clients[token] = client
            self.logger.info("Re-activate dead client:{} with token: {}".format(client_name, token))
            return True

    def get_clients(self):
        """Get the registered clients.

        Returns:
            A dict of {client_token: client}
        """
        return self.clients

    def get_min_clients(self):
        """Return the configured minimum number of clients."""
        return self.min_num_clients

    def get_max_clients(self):
        """Return the configured maximum number of clients."""
        return self.max_num_clients