# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Union
from nvflare.apis.utils.format_check import name_check
from .constants import DEFINED_PARTICIPANT_TYPES, DEFINED_ROLES, ConnSecurity, ParticipantType, PropKey
[docs]
class ListeningHost:
def __init__(self, scheme, host_names, default_host, port, conn_sec):
self.scheme = scheme
self.host_names = host_names
self.default_host = default_host
self.port = port
self.conn_sec = conn_sec
def __str__(self):
scheme, host_names, default_host, port, conn_sec = (
self.scheme,
self.host_names,
self.default_host,
self.port,
self.conn_sec,
)
return f"ListeningHost[{scheme=} {host_names=} {default_host=} {port=} {conn_sec=}]"
[docs]
class ConnectTo:
def __init__(self, name, host, port, conn_sec):
self.name = name
self.host = host
self.port = port
self.conn_sec = conn_sec
def __str__(self):
name, host, port, conn_sec = self.name, self.host, self.port, self.conn_sec
return f"ConnectTo[{name=} {host=} {port=} {conn_sec=}]"
def _check_host_name(scope: str, prop_key: str, value):
err, reason = name_check(value, "host_name")
if err:
raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: {reason}")
def _check_host_names(scope: str, prop_key: str, value):
if isinstance(value, str):
_check_host_name(scope, prop_key, value)
elif isinstance(value, list):
for v in value:
_check_host_name(scope, prop_key, v)
def _check_admin_role(scope: str, prop_key: str, value):
if not isinstance(value, str):
raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: must be str but got {type(value)}")
if not value:
raise ValueError(f"empty value for {prop_key} '{value}' in {scope}")
[docs]
def parse_connect_to(value, scope=None, prop_key=None) -> ConnectTo:
"""Parse the "connect_to" property.
Args:
value: value to be parsed. It is either a str or a dict.
scope: scope of the property
prop_key: key of the property
Returns: a ConnectTo object
"""
if isinstance(value, str):
# old format - for server only
return ConnectTo(None, value, None, None)
elif isinstance(value, dict):
name = value.get(PropKey.NAME)
host = value.get(PropKey.HOST)
port = value.get(PropKey.PORT)
conn_sec = value.get(PropKey.CONN_SECURITY)
return ConnectTo(name, host, port, conn_sec)
else:
raise ValueError(
f"bad value for {prop_key} '{value}' in {scope}: invalid type {type(value)}; must be str or dict"
)
def _check_connect_to(scope: str, prop_key: str, value):
ct = parse_connect_to(value, scope, prop_key)
if ct.host:
err, reason = name_check(ct.host, "host_name")
if err:
raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: {reason}")
if ct.port is not None:
if not isinstance(ct.port, int):
raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: port {ct.port} must be int")
if ct.port < 0:
raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: invalid port {ct.port}")
def _check_conn_security(scope: str, prop_key: str, value):
if not isinstance(value, str):
raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: must be a str but got {type(value)}")
valid_conn_secs = [ConnSecurity.CLEAR, ConnSecurity.MTLS, ConnSecurity.TLS]
if value.lower() not in valid_conn_secs:
raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: must be one of {valid_conn_secs}")
[docs]
def parse_listening_host(value, scope=None, prop_key=None) -> ListeningHost:
"""Parse the "listening_host" property. It must be either str or a dict
Args:
value: value to be parsed
scope: scope of the prop
prop_key: key of the property
Returns: a ListeningHost object
"""
if isinstance(value, str):
# old format - for server only
return ListeningHost(None, None, value, None, None)
elif isinstance(value, dict):
scheme = value.get(PropKey.SCHEME)
host_names = value.get(PropKey.HOST_NAMES)
default_host = value.get(PropKey.DEFAULT_HOST)
port = value.get(PropKey.PORT)
conn_sec = value.get(PropKey.CONN_SECURITY)
return ListeningHost(scheme, host_names, default_host, port, conn_sec)
else:
raise ValueError(
f"bad value for {prop_key} '{value}' in {scope}: invalid type {type(value)}; must be str or dict"
)
def _check_listening_host(scope: str, prop_key: str, value):
h = parse_listening_host(value, scope, prop_key)
if h.host_names:
_check_host_names(scope, prop_key, h.host_names)
if h.default_host:
_check_host_name(scope, prop_key, h.default_host)
if h.port is not None:
if not isinstance(h.port, int):
raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: port {h.port} must be int")
if h.port < 0:
raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: invalid port {h.port}")
# validator functions for common properties
# Validator function must follow this signature:
# func(scope: str, prop_key: str, value)
_PROP_VALIDATORS = {
PropKey.HOST_NAMES: _check_host_names,
PropKey.CONNECT_TO: _check_connect_to,
PropKey.LISTENING_HOST: _check_listening_host,
PropKey.DEFAULT_HOST: _check_host_name,
PropKey.ROLE: _check_admin_role,
PropKey.CONN_SECURITY: _check_conn_security,
}
[docs]
class Entity:
def __init__(self, scope: str, name: str, props: dict, parent=None):
if not props:
props = {}
for k, v in props.items():
validator = _PROP_VALIDATORS.get(k)
if validator is not None:
validator(scope, k, v)
self.name = name
self.props = props
self.parent = parent
[docs]
def get_prop(self, key: str, default=None):
return self.props.get(key, default)
[docs]
def set_prop(self, key: str, value: Any):
self.props[key] = value
[docs]
def get_prop_fb(self, key: str, fb_key=None, default=None):
"""Get property value with fallback.
If I have the property, then return it.
If not, I return the fallback property of my parent. If I don't have parent, return default.
Args:
key: key of the property
fb_key: key of the fallback property.
default: value to return if no one has the property
Returns: property value
"""
value = self.get_prop(key)
if value:
return value
elif not self.parent:
return default
else:
# get the value from the parent
if not fb_key:
fb_key = key
return self.parent.get_prop(fb_key, default)
[docs]
class Participant(Entity):
def __init__(self, type: str, name: str, org: str, props: dict = None, project: Entity = None):
"""Class to represent a participant.
Each participant communicates to other participant. Therefore, each participant has its
own name, type, organization it belongs to, rules and other information.
Args:
type (str): server, client, admin, relay or other string that builders can handle
name (str): system-wide unique name
org (str): system-wide unique organization
props (dict): properties
project: the project that the participant belongs to
Raises:
ValueError: if name or org is not compliant with characters or format specification.
"""
Entity.__init__(self, f"{type}::{name}", name, props, parent=project)
if type in DEFINED_PARTICIPANT_TYPES:
err, reason = name_check(name, type)
if err:
raise ValueError(reason)
else:
err, reason = name_check(type, "simple_name")
if err:
raise ValueError(reason)
print(f"Warning: participant type '{type}' of {name} is not a defined type {DEFINED_PARTICIPANT_TYPES}")
err, reason = name_check(org, "org")
if err:
raise ValueError(reason)
if type == ParticipantType.ADMIN:
if not props:
raise ValueError(f"missing role for admin '{name}'")
role = props.get(PropKey.ROLE)
if not role:
raise ValueError(f"missing role for admin '{name}'")
err, reason = name_check(role, "simple_name")
if err:
raise ValueError(f"bad role value '{role}' for admin '{name}': {reason}")
if role not in DEFINED_ROLES:
print(f"Warning: '{role}' of admin '{name}' is not a defined role {DEFINED_ROLES}")
self.type = type
self.org = org
self.subject = name
[docs]
def get_default_host(self) -> str:
"""Get the default host name for accessing this participant (server).
If the "default_host" attribute is explicitly specified, then it's the default host.
If the "default_host" attribute is not explicitly specified, then use the "name" attribute.
Returns: a host name
"""
h = self.get_prop(PropKey.DEFAULT_HOST)
if h:
return h
else:
return self.name
[docs]
def get_listening_host(self) -> Optional[ListeningHost]:
"""Get listening host property of the participant
Returns: a ListeningHost object, or None if the property is not defined.
"""
h = self.get_prop(PropKey.LISTENING_HOST)
if not h:
return None
lh = parse_listening_host(h)
if not lh.scheme:
lh.scheme = "tcp"
if not lh.port:
lh.port = 0 # any port
if not lh.conn_sec:
lh.conn_sec = ConnSecurity.CLEAR
if not lh.default_host:
if self.type == ParticipantType.SERVER:
lh.default_host = self.get_default_host()
else:
lh.default_host = "localhost"
return lh
[docs]
def get_connect_to(self) -> Optional[ConnectTo]:
"""Get the connect_to property of the participant
Returns: a ConnectTo object
"""
h = self.get_prop(PropKey.CONNECT_TO)
if not h:
return None
else:
return parse_connect_to(h)
def _must_get(d: dict, key: str):
"""Must get property of the specified key from the dict
Args:
d: the dict that contains participant properties
key: key of the property to get
Returns: the value of the property. If the property does not exist, ValueError exception is raised.
"""
v = d.pop(key, None)
if not v:
raise ValueError(f"missing participant {key}")
return v
[docs]
def participant_from_dict(participant_def: dict) -> Participant:
"""Create a Participant from a dict that contains participant property definitions.
Args:
participant_def: the dict that contains participant definition
Returns: a Participant object
"""
if not isinstance(participant_def, dict):
raise ValueError(f"participant_def must be dict but got {type(participant_def)}")
name = _must_get(participant_def, PropKey.NAME)
t = _must_get(participant_def, PropKey.TYPE)
org = _must_get(participant_def, PropKey.ORG)
return Participant(type=t, name=name, org=org, props=participant_def)
[docs]
class Project(Entity):
def __init__(
self,
name: str,
description: str,
participants=None,
props: dict = None,
serialized_root_cert=None,
root_private_key=None,
):
"""A container class to hold information about this FL project.
This class only holds information. It does not drive the workflow.
Args:
name (str): the project name
description (str): brief description on this name
participants: if provided, list of participants of the project
props: properties of the project
serialized_root_cert: if provided, the root cert to be used for the project
root_private_key: if provided, the root private key for signing certs of sites and admins
Raises:
ValueError: when participant criteria is violated
"""
Entity.__init__(self, "project", name, props)
if serialized_root_cert:
if not root_private_key:
raise ValueError("missing root_private_key while serialized_root_cert is provided")
self.description = description
self.serialized_root_cert = serialized_root_cert
self.root_private_key = root_private_key
self.server = None
self.overseer = None
self._participants_by_types = {} # participant type => list of participants
self._all_names = {} # name => participant
if participants:
if not isinstance(participants, list):
raise ValueError(f"participants must be a list of Participant but got {type(participants)}")
for p in participants:
if not isinstance(p, Participant):
raise ValueError(f"bad item in participants: must be Participant but got {type(p)}")
self.add_participant(p)
[docs]
def set_server(self, name: str, org: str, props: dict) -> Participant:
"""Set the server of the project.
Args:
name: name of the server.
org: org of the server
props: additional server properties.
Returns: a Participant object for the server
"""
return self.add_participant(Participant(ParticipantType.SERVER, name, org, props))
[docs]
def get_server(self) -> Optional[Participant]:
"""Get the server definition. Only one server is supported!
Returns: server participant
"""
return self.server
[docs]
def set_overseer(self, name: str, org: str, props: dict) -> Participant:
return self.add_participant(Participant(ParticipantType.OVERSEER, name, org, props))
[docs]
def get_overseer(self) -> Optional[Participant]:
"""Get the overseer definition. Only one overseer is supported!
Returns: overseer participant
"""
return self.overseer
[docs]
def add_participant(self, participant: Participant) -> Participant:
"""Add a participant to the project.
Before adding the participant, this method checks the following conditions:
- All participants in the project must have unique names
- Only one server is allowed in the project
- Only one overseer is allowed in the project
- Role must be specified for admin type of participant
Args:
participant: the participant to be added.
Returns: the participant object added.
"""
if participant.name in self._all_names:
raise ValueError(f"the project {self.name} already has a participant with the name '{participant.name}'")
participant.parent = self
if participant.type == ParticipantType.SERVER:
if self.server:
raise ValueError(f"cannot add participant {participant.name} as server - server already exists")
self.server = participant
elif participant.type == ParticipantType.OVERSEER:
if self.overseer:
raise ValueError(f"cannot add participant {participant.name} as overseer - overseer already exists")
self.overseer = participant
participants = self._participants_by_types.get(participant.type)
if not participants:
participants = []
self._participants_by_types[participant.type] = participants
participants.append(participant)
self._all_names[participant.name] = participant
return participant
[docs]
def add_client(self, name: str, org: str, props: dict) -> Participant:
"""Add a client to the project
Args:
name: name of the client
org: org of the client
props: additional properties of the client
Returns: the Participant object of the client
"""
return self.add_participant(Participant(ParticipantType.CLIENT, name, org, props))
[docs]
def get_clients(self) -> List[Participant]:
"""Get all clients of the project
Returns: a list of clients
"""
return self.get_all_participants(ParticipantType.CLIENT)
[docs]
def add_relay(self, name: str, org: str, props: dict) -> Participant:
"""Add a relay to the project
Args:
name: name of the relay
org: org of the relay
props: additional properties of the relay
Returns: the relay Participant object
"""
return self.add_participant(Participant(ParticipantType.RELAY, name, org, props))
[docs]
def get_relays(self) -> List[Participant]:
"""Get all relays of the project
Returns: the list of relays of the project
"""
return self.get_all_participants(ParticipantType.RELAY)
[docs]
def add_admin(self, name: str, org: str, props: dict) -> Participant:
"""Add an admin user to the project
Args:
name: name of the admin user.
org: org of the admin user.
props: properties of the user definition
Returns: a Participant object of the admin user
"""
return self.add_participant(Participant(ParticipantType.ADMIN, name, org, props))
[docs]
def get_admins(self) -> List[Participant]:
"""Get the list of admin users
Returns: list of admin users
"""
return self.get_all_participants(ParticipantType.ADMIN)
[docs]
def get_all_participants(self, types: Union[None, str, List[str]] = None):
"""Get all participants of the project of specified types.
Args:
types: types of the participants to be returned.
Returns: all participants of the project of specified types.
If 'types' is not specified (None), it returns all participants of the project;
If 'types' is a str, it is treated as a single type and participants of this type is returned;
If 'types' is a list of types, participants of these types are returned;
"""
if not types:
# get all types
return list(self._all_names.values())
if isinstance(types, str):
types = [types]
elif not isinstance(types, list):
raise ValueError(f"types must be a str or List[str] but got {type(types)}")
result = []
processed_types = [] # in case 'types' contains duplicates
for t in types:
if t not in processed_types:
ps = self._participants_by_types.get(t)
if ps:
result.extend(ps)
processed_types.append(t)
return result