# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.x509.oid import NameOID
from nvflare.lighter.constants import CertFileBasename, CtxKey, ParticipantType, PropKey
from nvflare.lighter.ctx import ProvisionContext
from nvflare.lighter.entity import Participant, Project
from nvflare.lighter.spec import Builder
from nvflare.lighter.utils import Identity, generate_cert, generate_keys, serialize_cert, serialize_pri_key
MAX_CN_LENGTH = 64
class _CertState:
CERT_STATE_FILE = "cert.json"
PROP_ROOT_CERT = CtxKey.ROOT_CERT
PROP_ROOT_PRI_KEY = CtxKey.ROOT_PRI_KEY
PROP_CERT = "cert"
PROP_PRI_KEY = "pri_key"
def __init__(self, state_dir: str):
self.is_available = False
self.state_dir = state_dir
self.content = {}
cert_file = os.path.join(state_dir, self.CERT_STATE_FILE)
if os.path.exists(cert_file):
self.is_available = True
with open(cert_file, "rt") as f:
self.content.update(json.load(f))
def get_root_cert(self):
return self.content.get(self.PROP_ROOT_CERT)
def set_root_cert(self, cert):
self.content[self.PROP_ROOT_CERT] = cert
def get_root_pri_key(self):
return self.content.get(self.PROP_ROOT_PRI_KEY)
def set_root_pri_key(self, key):
self.content[self.PROP_ROOT_PRI_KEY] = key
def has_subject(self, subject: str):
return subject in self.content
def _add_subject_prop(self, subject: str, key: str, value):
subject_data = self.content.get(subject)
if not subject_data:
subject_data = {}
self.content[subject] = subject_data
subject_data[key] = value
def _get_subject_prop(self, subject: str, key: str):
subject_data = self.content.get(subject)
if not subject_data:
return None
return subject_data.get(key)
def add_subject_cert(self, subject: str, cert):
self._add_subject_prop(subject, self.PROP_CERT, cert)
def get_subject_cert(self, subject: str):
return self._get_subject_prop(subject, self.PROP_CERT)
def add_subject_pri_key(self, subject: str, pri_key):
self._add_subject_prop(subject, self.PROP_PRI_KEY, pri_key)
def get_subject_pri_key(self, subject: str):
return self._get_subject_prop(subject, self.PROP_PRI_KEY)
def persist(self):
cert_file = os.path.join(self.state_dir, self.CERT_STATE_FILE)
with open(cert_file, "wt") as f:
json.dump(self.content, f)
[docs]
class CertBuilder(Builder):
def __init__(self):
"""Build certificate chain for every participant.
Handles building (creating and self-signing) the root CA certificates, creating server, client and
admin certificates, and having them signed by the root CA for secure communication. If the state folder has
information about previously generated certs, it loads them back and reuses them.
"""
self.root_cert = None
self.persistent_state = None
self.serialized_cert = None
self.pri_key = None
self.pub_key = None
self.subject = None
self.issuer = None
@staticmethod
def _fix_server_name(server: Participant):
"""Server Name is used as CN of the cert. But the CN cannot exceed 63 chars. So we have to truncate it
to make the cert.
Server Name also serves as the identity of the server for all clients to verify, and it must match the
CN in the server's cert.
Server Name is also the default host name (unless default host is explicitly specified) for clients to
connect to. Truncated name won't be a valid host name.
We have to accommodate all these factors:
- We truncate the server name and use it for both name and subject of the server. This will satisfy CN
requirement of the cert, and will satisfy server identity validation by clients.
- We check whether the DEFAULT_HOST property is explicitly specified in the server. If not, we explicitly
set it to the original name.
Args:
server: the server to be fixed.
Returns:
"""
original_name = server.name
if len(original_name) > MAX_CN_LENGTH:
truncated_name = original_name[:MAX_CN_LENGTH]
# both name and subject of the server must use the truncated name!
server.name = truncated_name
server.subject = truncated_name
# also make the original_name the default host
default_host = server.get_prop(PropKey.DEFAULT_HOST)
if not default_host:
# must use the original name as the default host
server.set_prop(PropKey.DEFAULT_HOST, original_name)
[docs]
def initialize(self, project: Project, ctx: ProvisionContext):
self._fix_server_name(project.get_server())
state_dir = ctx.get_state_dir()
self.persistent_state = _CertState(state_dir)
state = self.persistent_state
if project.root_private_key:
# using project provided credentials
self.serialized_cert = project.serialized_root_cert
self.root_cert = x509.load_pem_x509_certificate(self.serialized_cert, default_backend())
self.pri_key = project.root_private_key
self.pub_key = self.pri_key.public_key()
self.subject = self.root_cert.subject
self.issuer = self.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
state.is_available = True
elif state.is_available:
state_root_cert = state.get_root_cert()
self.serialized_cert = state_root_cert.encode("ascii")
self.root_cert = x509.load_pem_x509_certificate(self.serialized_cert, default_backend())
state_pri_key = state.get_root_pri_key()
self.pri_key = serialization.load_pem_private_key(
state_pri_key.encode("ascii"), password=None, backend=default_backend()
)
self.pub_key = self.pri_key.public_key()
self.subject = self.root_cert.subject
self.issuer = self.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
def _build_root(self, subject, subject_org):
assert isinstance(self.persistent_state, _CertState)
if not self.persistent_state.is_available:
pri_key, pub_key = generate_keys()
self.issuer = subject
self.root_cert = self._generate_cert(subject, subject_org, self.issuer, pri_key, pub_key, ca=True)
self.pri_key = pri_key
self.pub_key = pub_key
self.serialized_cert = serialize_cert(self.root_cert)
self.persistent_state.set_root_cert(self.serialized_cert.decode("ascii"))
self.persistent_state.set_root_pri_key(serialize_pri_key(self.pri_key).decode("ascii"))
def _build_write_cert_pair(self, participant: Participant, base_name, ctx: ProvisionContext):
assert isinstance(self.persistent_state, _CertState)
subject = participant.subject
if self.persistent_state.has_subject(subject):
subject_cert = self.persistent_state.get_subject_cert(subject)
cert = x509.load_pem_x509_certificate(subject_cert.encode("ascii"), default_backend())
subject_pri_key = self.persistent_state.get_subject_pri_key(subject)
pri_key = serialization.load_pem_private_key(
subject_pri_key.encode("ascii"), password=None, backend=default_backend()
)
if participant.type == ParticipantType.ADMIN:
cn_list = cert.subject.get_attributes_for_oid(NameOID.UNSTRUCTURED_NAME)
for cn in cn_list:
role = cn.value
new_role = participant.get_prop(PropKey.ROLE)
if role != new_role:
err_msg = (
f"{participant.name}'s previous role is {role} but is now {new_role}.\n"
+ "Please delete existing workspace and provision from scratch."
)
raise RuntimeError(err_msg)
else:
pri_key, cert = self.get_pri_key_cert(participant)
self.persistent_state.add_subject_cert(subject, serialize_cert(cert).decode("ascii"))
self.persistent_state.add_subject_pri_key(subject, serialize_pri_key(pri_key).decode("ascii"))
dest_dir = ctx.get_kit_dir(participant)
with open(os.path.join(dest_dir, f"{base_name}.crt"), "wb") as f:
f.write(serialize_cert(cert))
with open(os.path.join(dest_dir, f"{base_name}.key"), "wb") as f:
f.write(serialize_pri_key(pri_key))
if participant.type in [ParticipantType.CLIENT, ParticipantType.RELAY]:
self._build_internal_listener_cert(participant, ctx)
with open(os.path.join(dest_dir, "rootCA.pem"), "wb") as f:
f.write(self.serialized_cert)
def _build_internal_listener_cert(self, participant: Participant, ctx: ProvisionContext):
"""Build server cert if the participant has internal listeners.
Note that internal listener used to be only used for connecting SJ to SP, and CJ to SP, but now
relay hierarchy is connected to internal listeners.
Just like the FL Server, a relay could offer one or more hosts for other relays and clients to
connect to. Therefore, the relay's server cert must include all these host names and IP addresses
for others to make SSL-based connections using any one of these host names/addresses.
Args:
participant: the participant being provisioned
ctx: a ProvisionContext object
Returns: None
"""
lh = participant.get_listening_host()
if not lh:
return
dest_dir = ctx.get_kit_dir(participant)
project = ctx.get_project()
# make a fake/temp server participant to use the get_pri_key_cert() method!
tmp_participant = Participant(
type=ParticipantType.SERVER,
name=participant.name,
org=participant.org,
project=project,
props={
PropKey.HOST_NAMES: lh.host_names,
PropKey.DEFAULT_HOST: lh.default_host,
},
)
tmp_pri_key, tmp_cert = self.get_pri_key_cert(tmp_participant)
# The listener cert is a Server Cert.
bn = CertFileBasename.SERVER
with open(os.path.join(dest_dir, f"{bn}.crt"), "wb") as f:
f.write(serialize_cert(tmp_cert))
with open(os.path.join(dest_dir, f"{bn}.key"), "wb") as f:
f.write(serialize_pri_key(tmp_pri_key))
[docs]
def build(self, project: Project, ctx: ProvisionContext):
self._build_root(project.name, subject_org=None)
ctx[CtxKey.ROOT_CERT] = self.root_cert
ctx[CtxKey.ROOT_PRI_KEY] = self.pri_key
server = project.get_server()
if server:
self._build_write_cert_pair(server, CertFileBasename.SERVER, ctx)
for client in project.get_clients():
self._build_write_cert_pair(client, CertFileBasename.CLIENT, ctx)
for relay in project.get_relays():
self._build_write_cert_pair(relay, CertFileBasename.CLIENT, ctx)
for admin in project.get_admins():
self._build_write_cert_pair(admin, CertFileBasename.CLIENT, ctx)
[docs]
def get_pri_key_cert(self, participant: Participant):
pri_key, pub_key = generate_keys()
subject = participant.subject
subject_org = participant.org
if participant.type == ParticipantType.ADMIN:
role = participant.get_prop(PropKey.ROLE)
else:
role = None
server = participant if participant.type == ParticipantType.SERVER else None
cert = self._generate_cert(
subject,
subject_org,
self.issuer,
self.pri_key,
pub_key,
role=role,
server=server,
)
return pri_key, cert
@staticmethod
def _generate_cert(
subject,
subject_org,
issuer,
signing_pri_key,
subject_pub_key,
valid_days=360,
ca=False,
role=None,
server: Participant = None,
):
server_default_host = None
server_additional_hosts = None
if server:
# This is to generate a server cert.
# Use SubjectAlternativeName for all host names
server_default_host = server.get_default_host()
server_additional_hosts = server.get_prop(PropKey.HOST_NAMES)
return generate_cert(
subject=Identity(subject, subject_org, role),
issuer=Identity(issuer),
signing_pri_key=signing_pri_key,
subject_pub_key=subject_pub_key,
valid_days=valid_days,
ca=ca,
server_default_host=server_default_host,
server_additional_hosts=server_additional_hosts,
)
[docs]
def finalize(self, project: Project, ctx: ProvisionContext):
assert isinstance(self.persistent_state, _CertState)
self.persistent_state.persist()