# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile
import threading
import time
import uuid
from contextlib import suppress
from typing import Optional
from nvflare.apis.client import Client, ClientPropKey
from nvflare.apis.fl_constant import FLContextKey, ReservedKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.fuel.f3.cellnet.defs import IdentityChallengeKey, MessageHeaderKey
from nvflare.fuel.utils.admin_name_utils import is_valid_admin_client_name
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.private.defs import CellMessageHeaderKeys, ClientRegSession, ClientType, InternalFLContextKey
from nvflare.private.fed.server.cred_keeper import CredKeeper
from nvflare.private.fed.utils.identity_utils import get_org_from_cert, load_crt_bytes
from nvflare.security.logging import secure_format_exception
class ClientManager:
    """Server-side registry of client sessions and of disabled client names.

    Two indexes are kept for active sessions: ``clients`` (token => Client) and
    ``name_to_clients`` (name => Client). A set of disabled client names is kept alongside
    and, when a file path is configured, persisted to a JSON file of the form
    ``{"disabled_clients": [...]}``.

    ``self.lock`` is a plain ``threading.Lock`` (non-reentrant): code that already holds it
    must not call a method that acquires it again.
    """

    def __init__(self, project_name=None, min_num_clients=2, max_num_clients=10):
        """Manages client adding and removing.

        Args:
            project_name: project name
            min_num_clients: minimum number of clients allowed.
            max_num_clients: maximum number of clients allowed.
        """
        self.project_name = project_name
        # TODO:: remove min num clients
        self.min_num_clients = min_num_clients
        self.max_num_clients = max_num_clients
        self.clients = dict()  # token => Client
        self.name_to_clients = dict()  # name => Client
        self.disabled_clients = set()
        self.disabled_clients_file = None
        self.cred_keeper = CredKeeper()
        self.lock = threading.Lock()
        self.num_relays = 0
        self.logger = get_obj_logger(self)

    def set_disabled_clients_file(self, file_path: str):
        """Set the file used to persist disabled client names and load its current content.

        Args:
            file_path: path of the JSON file.

        Raises:
            Exception: whatever ``_load_disabled_clients`` raises for an unreadable or malformed file.
        """
        self.disabled_clients_file = file_path
        self._load_disabled_clients()

    def _load_disabled_clients(self):
        """Replace the in-memory disabled set with the content of the configured file.

        Does nothing when no file is configured or the file does not exist. Any read or
        format error is logged as critical and re-raised, so a corrupt file is never
        silently treated as "no client is disabled".
        """
        if not self.disabled_clients_file or not os.path.exists(self.disabled_clients_file):
            return

        try:
            with open(self.disabled_clients_file, encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, dict):
                raise ValueError("disabled clients file must be a JSON object")

            clients = data.get("disabled_clients")
            if not isinstance(clients, list):
                raise ValueError("disabled_clients must be a list")

            with self.lock:
                # Falsy entries (None, "", 0) are dropped; everything else is stored as str.
                self.disabled_clients = {str(client_name) for client_name in clients if client_name}
        except Exception as ex:
            self.logger.critical(
                f"failed to load disabled clients from {self.disabled_clients_file}: {ex}; "
                "refusing to start to preserve disable-client policy"
            )
            raise

    def _save_disabled_clients(self, disabled_clients=None):
        """Write the disabled client names to the configured file.

        The data is written to a temporary file in the target directory and then moved over
        the target with ``os.replace``, so readers never see a partially written file.

        Args:
            disabled_clients: the names to write. If None, a snapshot of ``self.disabled_clients``
                is taken under ``self.lock``; callers that already hold the lock MUST pass a
                snapshot, because the lock is not reentrant.

        Raises:
            Exception: any error from creating the directory, writing, or replacing the file.
                The temporary file is removed (best effort) before the error propagates.
        """
        if not self.disabled_clients_file:
            return

        if disabled_clients is None:
            with self.lock:
                disabled_clients = set(self.disabled_clients)

        dirname = os.path.dirname(self.disabled_clients_file)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        data = {"disabled_clients": sorted(disabled_clients)}
        tmp_path = None
        try:
            # The temp file lives next to the target so os.replace stays on one filesystem.
            fd, tmp_path = tempfile.mkstemp(
                prefix=f"{os.path.basename(self.disabled_clients_file)}.",
                suffix=".tmp",
                dir=dirname or ".",
                text=True,
            )
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, self.disabled_clients_file)
        except Exception:
            if tmp_path:
                with suppress(OSError):
                    os.unlink(tmp_path)
            raise

    def is_client_disabled(self, client_name: str) -> bool:
        """Return True if the client name is currently in the disabled set."""
        with self.lock:
            return client_name in self.disabled_clients

    def disable_client(self, client_name: str) -> list:
        """Disable a client name and drop all of its active sessions.

        The in-memory change and the file write happen under one lock acquisition, so the
        file always reflects the in-memory state and concurrent calls cannot write stale
        snapshots out of order. If the write fails, the in-memory change is rolled back.

        Args:
            client_name: name of the client to disable.

        Returns:
            The list of tokens that were active for this client name and have been removed.

        Raises:
            Exception: the persistence error, after the in-memory state has been restored.
        """
        with self.lock:
            already_disabled = client_name in self.disabled_clients
            self.disabled_clients.add(client_name)

            removed_clients = []
            for token, client in list(self.clients.items()):
                if client.name == client_name:
                    removed_clients.append((token, client))
                    self.clients.pop(token, None)
            self.name_to_clients.pop(client_name, None)

            try:
                # Pass a snapshot: the lock is held here and is not reentrant.
                self._save_disabled_clients(set(self.disabled_clients))
            except Exception as ex:
                # Only undo the "disabled" mark if this call was the one that set it.
                if not already_disabled:
                    self.disabled_clients.discard(client_name)
                for token, client in removed_clients:
                    self.clients[token] = client
                    self.name_to_clients[client.name] = client
                self.logger.error(f"failed to persist disabled-client state for {client_name}: {ex}")
                raise

        removed_tokens = [token for token, _client in removed_clients]
        self.logger.info(f"Client {client_name} disabled. Removed active tokens: {removed_tokens}")
        return removed_tokens

    def enable_client(self, client_name: str) -> bool:
        """Remove a client name from the disabled set.

        Nothing is written when the name was not disabled. If persisting the change fails,
        the name is put back into the disabled set before the error propagates.

        Args:
            client_name: name of the client to enable.

        Returns:
            True if the client was disabled before this call, False otherwise.

        Raises:
            Exception: the persistence error, after the in-memory state has been restored.
        """
        with self.lock:
            was_disabled = client_name in self.disabled_clients
            if was_disabled:
                self.disabled_clients.remove(client_name)
                try:
                    # Pass a snapshot: the lock is held here and is not reentrant.
                    self._save_disabled_clients(set(self.disabled_clients))
                except Exception as ex:
                    self.disabled_clients.add(client_name)
                    self.logger.error(f"failed to persist enabled-client state for {client_name}: {ex}")
                    raise

        self.logger.info(f"Client {client_name} enabled. Was disabled: {was_disabled}")
        return was_disabled

    def set_clients(self, clients: dict):
        """Replace the registry with the given token => Client dict and rebuild the name index.

        The passed dict object itself becomes ``self.clients`` (it is not copied).
        """
        self.clients = clients
        self.name_to_clients = {}
        for c in clients.values():
            self.name_to_clients[c.name] = c

    def authenticate(self, request, fl_ctx: "FLContext") -> "Optional[Client]":
        """Log a client in and, for REGULAR clients, register the new session.

        Args:
            request: the client login request message.
            fl_ctx: FLContext; ``FLContextKey.UNAUTHENTICATED`` is set on it when login is rejected.

        Returns:
            The new Client, or None if the login was rejected.
        """
        client_type = request.get_header(CellMessageHeaderKeys.CLIENT_TYPE)
        client = self.login_client(request, fl_ctx, client_type)
        if not client:
            return None

        # client_ip = context.peer().split(":")[1]
        client_ip = request.get_header(CellMessageHeaderKeys.CLIENT_IP)

        # new client join
        with self.lock:
            # The lock was released after login_client() checked the disabled set, so
            # disable_client() may have run since then. Check again in the same critical
            # section as the insert so a disabled client can never end up registered.
            if client.name in self.disabled_clients:
                fl_ctx.set_prop(FLContextKey.UNAUTHENTICATED, f"Client '{client.name}' is disabled", sticky=False)
                self.logger.warning(f"Reject disabled client registration: {client.name}")
                return None

            if client_type == ClientType.REGULAR:
                self.name_to_clients[client.name] = client
                self.clients.update({client.token: client})
                client_kind = "client"
            else:
                # do not update self.clients for non-regular clients
                client_kind = client_type
            total_clients = len(self.clients)

        # Format-based composition: a missing CLIENT_IP header (None) must not raise TypeError.
        self.logger.info(
            "Client: New {} {} joined. Sent token: {}. Total clients: {}".format(
                client_kind, "{}@{}".format(client.name, client_ip), client.token, total_clients
            )
        )
        return client

    def remove_client(self, token):
        """Remove a registered client's active token entry.

        Args:
            token: client token

        Returns:
            The removed Client object, if the token was active
        """
        with self.lock:
            client = self.clients.pop(token, None)
            if client:
                # Drop the name index entry only if it still points at this session; a newer
                # session registered under the same name must stay reachable by name.
                if self.name_to_clients.get(client.name) is client:
                    self.name_to_clients.pop(client.name, None)
                self.logger.info(
                    "Client Name:{} \tToken: {} left. Total clients: {}".format(client.name, token, len(self.clients))
                )
            else:
                self.logger.warning("remove_client: unknown token %s", token)
        return client

    def login_client(self, client_login, fl_ctx: "FLContext", client_type):
        """Check the project name of a login request, then authenticate the client.

        Args:
            client_login: the client login request message.
            fl_ctx: FLContext; ``FLContextKey.UNAUTHENTICATED`` is set on a project-name mismatch.
            client_type: type of the client.

        Returns:
            The Client from ``authenticated_client``, or None if the login was rejected.
        """
        proj_name = client_login.get_header(CellMessageHeaderKeys.PROJECT_NAME)
        if not self.is_valid_task(proj_name):
            fl_ctx.set_prop(
                FLContextKey.UNAUTHENTICATED, "Requested task does not match the current server task", sticky=False
            )
            self.logger.error(f"login_client failed: {proj_name}")
            return None
        return self.authenticated_client(client_login, fl_ctx, client_type)

    def has_relays(self):
        """Return True if at least one relay login has been counted."""
        return self.num_relays > 0

    def validate_client(self, request, fl_ctx: "FLContext", allow_new=False):
        """Validate the client state message.

        Args:
            request: A request from client.
            fl_ctx: FLContext
            allow_new: whether to allow new client. Note that its task should still match server's.

        Returns:
            The registered Client for the request's token, or None. None is returned when the
            token is missing, the project name does not match, or the token is unknown; with
            ``allow_new=True`` an unknown token also yields None, but without
            ``FLContextKey.UNAUTHENTICATED`` being set.
        """
        # token = client_state.token
        token = request.get_header(CellMessageHeaderKeys.TOKEN)
        if not token:
            fl_ctx.set_prop(FLContextKey.UNAUTHENTICATED, "Could not read client uid from the payload", sticky=False)
            client = None
        elif not self.is_valid_task(request.get_header(CellMessageHeaderKeys.PROJECT_NAME)):
            fl_ctx.set_prop(
                FLContextKey.UNAUTHENTICATED, "Requested task does not match the current server task", sticky=False
            )
            client = None
        elif not (allow_new or self.is_from_authorized_client(token)):
            fl_ctx.set_prop(FLContextKey.UNAUTHENTICATED, "Unknown client identity", sticky=False)
            client = None
        else:
            client = self.clients.get(token)
        return client

    def _get_id_verifier(self, fl_ctx: "FLContext"):
        """Return the identity verifier provided by the credential keeper."""
        return self.cred_keeper.get_id_verifier(fl_ctx)

    def authenticated_client(self, request, fl_ctx: "FLContext", client_type) -> "Optional[Client]":
        """Use SSL certificate for authenticate the client.

        In secure mode the request must carry a certificate and a signature, and the asserted
        client name is verified against them using the nonce of the registration session. In
        non-secure mode a certificate, if present, is only used to read the client's org.

        On success any earlier session registered under the same name is dropped and a new
        Client with a fresh token is returned. The new client is NOT added to the registry here.

        Args:
            request: client login request Message
            fl_ctx: FL_Context
            client_type: type of the client

        Returns:
            Client object, or None if the client is rejected.
        """
        client_name = request.get_header(CellMessageHeaderKeys.CLIENT_NAME)

        # Fast path: reject a disabled client before doing any certificate work.
        if self.is_client_disabled(client_name):
            fl_ctx.set_prop(FLContextKey.UNAUTHENTICATED, f"Client '{client_name}' is disabled", sticky=False)
            self.logger.warning(f"Reject disabled client registration: {client_name}")
            return None

        shareable = request.payload
        if not isinstance(shareable, Shareable):
            self.logger.error(f"payload must be Shareable but got {type(shareable)}")
            return None

        secure_mode = fl_ctx.get_prop(FLContextKey.SECURE_MODE, False)
        client_org = ""
        asserter_cert_data = shareable.get(IdentityChallengeKey.CERT)

        if secure_mode:
            # verify client identity
            if not asserter_cert_data:
                self.logger.error("missing client cert in register request")
                return None

            signature = shareable.get(IdentityChallengeKey.SIGNATURE)
            if not signature:
                self.logger.error("missing signature in register request")
                return None

            # The certificate bytes come from the client: a malformed certificate is a failed
            # login, not an unhandled exception in the server's login path.
            try:
                asserter_cert = load_crt_bytes(asserter_cert_data)
            except Exception as ex:
                self.logger.error(f"failed to load client cert: {secure_format_exception(ex)}")
                return None

            id_verifier = self._get_id_verifier(fl_ctx)

            reg = fl_ctx.get_prop(InternalFLContextKey.CLIENT_REG_SESSION)
            if not reg:
                self.logger.error(f"missing {InternalFLContextKey.CLIENT_REG_SESSION} in FLContext!")
                return None

            if not isinstance(reg, ClientRegSession):
                self.logger.error(f"reg should be ClientRegSession but got {type(reg)}")
                return None

            try:
                id_verifier.verify_common_name(
                    asserted_cn=client_name,
                    asserter_cert=asserter_cert,
                    signature=signature,
                    nonce=reg.nonce,
                )
            except Exception as ex:
                self.logger.error(f"failed to verify client identity: {secure_format_exception(ex)}")
                return None

            self.logger.debug(f"identity verified for client '{client_name}'")
            client_org = get_org_from_cert(asserter_cert)
        elif asserter_cert_data:
            # Best effort only: in non-secure mode the org is informational, so a bad
            # certificate leaves client_org empty instead of failing the login.
            try:
                asserter_cert = load_crt_bytes(asserter_cert_data)
                client_org = get_org_from_cert(asserter_cert)
            except Exception as ex:
                self.logger.debug(f"could not read org from client cert: {secure_format_exception(ex)}")

        with self.lock:
            # Recheck under lock so disable_client cannot race with registration after the fast-path checks above.
            if client_name in self.disabled_clients:
                fl_ctx.set_prop(FLContextKey.UNAUTHENTICATED, f"Client '{client_name}' is disabled", sticky=False)
                self.logger.warning(f"Reject disabled client registration: {client_name}")
                return None

            # A re-login replaces every earlier session registered under the same name.
            stale_tokens = [token for token, client in self.clients.items() if client.name == client_name]
            for stale_token in stale_tokens:
                stale_client = self.clients.pop(stale_token, None)
                if stale_client:
                    self.name_to_clients.pop(stale_client.name, None)
                self.logger.info(f"Client: {client_name} already registered. Re-login the client with a new token.")

            client = Client(client_name, str(uuid.uuid4()))
            client.set_prop(ClientPropKey.ORG, client_org)
            client_fqcn = request.get_header(MessageHeaderKey.ORIGIN)
            self._set_client_props(client, client_fqcn, fl_ctx)
            self.logger.debug(f"authenticated client {client_name}: {client_fqcn=}")

            if client_type == ClientType.REGULAR and len(self.clients) >= self.max_num_clients:
                # only impose the limit to REGULAR clients
                fl_ctx.set_prop(FLContextKey.UNAUTHENTICATED, "Maximum number of clients reached", sticky=False)
                self.logger.info(f"Maximum number of clients reached. Reject client: {client_name} login.")
                return None

            if client_type == ClientType.RELAY:
                self.num_relays += 1

        return client

    def is_from_authorized_client(self, token):
        """Check if a client is authorized.

        Args:
            token: client token

        Returns:
            True if it is a recognised client
        """
        return token in self.clients

    def is_valid_task(self, task):
        """Check whether the requested task matches the server's project_name.

        Returns:
            True if task name is the same as server's project name.
        """
        # TODO: change the name of this method
        return task == self.project_name

    def heartbeat(self, token, client_name, client_fqcn, fl_ctx: "FLContext"):
        """Update the heartbeat of the client.

        Args:
            token: client token
            client_name: client name
            client_fqcn: FQCN of the client
            fl_ctx: FLContext

        Returns:
            True if the token was unknown and the client has been re-activated with a new
            Client record. False otherwise: the client is disabled, the token is already
            registered (its last connect time is refreshed), or another token is already
            registered under the same client name.
        """
        with self.lock:
            if client_name in self.disabled_clients:
                fl_ctx.set_prop(FLContextKey.UNAUTHENTICATED, f"Client '{client_name}' is disabled", sticky=False)
                self.logger.warning(f"Reject disabled client heartbeat: {client_name}")
                return False

            client = self.clients.get(token)
            if client:
                client.last_connect_time = time.time()
                self.logger.debug(f"Receive heartbeat from Client:{token}")
                return False

            # Unknown token: refuse to re-activate if the name is held by another session.
            for _token, _client in self.clients.items():
                if _client.name == client_name:
                    fl_ctx.set_prop(
                        FLContextKey.COMMUNICATION_ERROR,
                        "Client ID already registered as a client: {}".format(client_name),
                        sticky=False,
                    )
                    self.logger.info(
                        f"Failed to re-activate the client:{client_name} with token: {token}. "
                        f"Client already exist with token: {_token}."
                    )
                    return False

            client = Client(client_name, token)
            self._set_client_props(client, client_fqcn, fl_ctx)
            self.clients.update({token: client})
            self.name_to_clients[client.name] = client
            self.logger.info(f"Re-activate the client: {client_name} at {client_fqcn} with token: {token}")
            return True

    @staticmethod
    def _set_client_props(client: "Client", fqcn: str, fl_ctx: "FLContext"):
        """Fill in connection-related properties of a Client.

        Sets the FQCN and the last connect time; when a peer context is available, also
        copies FQSN and the is-leaf flag from it ("?" when absent); and attaches the client
        site config from ``fl_ctx`` when one is present.
        """
        client.set_fqcn(fqcn)
        client.last_connect_time = time.time()

        peer_ctx = fl_ctx.get_peer_context()
        if peer_ctx:
            client.set_fqsn(peer_ctx.get_prop(ReservedKey.FQSN, "?"))
            client.set_is_leaf(peer_ctx.get_prop(ReservedKey.IS_LEAF, "?"))

        site_config = fl_ctx.get_prop(FLContextKey.CLIENT_SITE_CONFIG)
        if site_config is not None:
            client.set_site_config(site_config)

    def get_clients(self):
        """Get the list of registered clients.

        Returns:
            A dict of {client_token: client}
        """
        return self.clients

    def get_min_clients(self):
        """Return the configured minimum number of clients."""
        return self.min_num_clients

    def get_max_clients(self):
        """Return the configured maximum number of clients."""
        return self.max_num_clients

    def get_client_from_name(self, client_name):
        """Look up a client by name.

        Args:
            client_name: name of the client.

        Returns:
            The registered Client with that name; a dynamically created Client (token None,
            FQCN set to the name) if the name is a valid admin client name; otherwise None.
        """
        result = self.name_to_clients.get(client_name)
        if not result:
            # Check whether this is a valid admin client.
            # Note that since admin clients are not kept in name_to_clients, we assume that the admin client
            # is valid and dynamically create the Client object as the result.
            if is_valid_admin_client_name(client_name):
                result = Client(client_name, None)
                result.set_fqcn(client_name)
            else:
                self.logger.debug(
                    f"no client for {client_name}: I have {self.name_to_clients.keys()} {self.clients.keys()}"
                )
        return result