Source code for nvflare.lighter.impl.cert

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import json
import os

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

from nvflare.lighter.spec import Builder, Participant


def serialize_pri_key(pri_key):
    """Serialize a private key to PEM bytes.

    The key is written in the traditional OpenSSL private-key format without
    any password protection.

    Args:
        pri_key: a ``cryptography`` private key object exposing ``private_bytes``.

    Returns:
        bytes: the PEM-encoded, unencrypted private key.
    """
    pem_encoding = serialization.Encoding.PEM
    key_format = serialization.PrivateFormat.TraditionalOpenSSL
    no_password = serialization.NoEncryption()
    return pri_key.private_bytes(encoding=pem_encoding, format=key_format, encryption_algorithm=no_password)
def serialize_cert(cert):
    """Serialize an X.509 certificate to PEM bytes.

    Args:
        cert: a ``cryptography`` certificate object exposing ``public_bytes``.

    Returns:
        bytes: the PEM-encoded certificate.
    """
    pem = serialization.Encoding.PEM
    return cert.public_bytes(pem)
def load_crt(path):
    """Load a PEM-encoded X.509 certificate from disk.

    Args:
        path: filesystem path of the PEM certificate file.

    Returns:
        The certificate parsed by ``x509.load_pem_x509_certificate``.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left open until garbage collection.
    with open(path, "rb") as f:
        serialized_cert = f.read()
    return x509.load_pem_x509_certificate(serialized_cert, default_backend())
class CertBuilder(Builder):
    def __init__(self):
        """Build certificate chain for every participant.

        Handles building (creating and self-signing) the root CA certificates, creating server, client and
        admin certificates, and having them signed by the root CA for secure communication. If the state folder has
        information about previously generated certs, it loads them back and reuses them.
        """
        self.root_cert = None
        # Maps "root_cert"/"root_pri_key" and each participant subject to PEM strings;
        # loaded from and saved to <state_dir>/cert.json.
        self.persistent_state = dict()
        # Root CA material, populated by initialize() (from saved state) or _build_root().
        self.pri_key = None
        self.pub_key = None
        self.serialized_cert = None
        # Common name of the root CA, used as the issuer name of every signed cert.
        self.issuer = None
        # Subject of a root cert loaded from saved state (set only by initialize()).
        self.subject = None

    def initialize(self, ctx):
        """Load previously generated root CA material from ``cert.json`` if it exists.

        Args:
            ctx: provisioning context; passed to ``get_state_dir`` (defined on the
                ``Builder`` base class, not shown here) to locate the state folder.
        """
        state_dir = self.get_state_dir(ctx)
        cert_file = os.path.join(state_dir, "cert.json")
        if os.path.exists(cert_file):
            # Close the handle deterministically (previously leaked via json.load(open(...))).
            with open(cert_file, "rt") as f:
                self.persistent_state = json.load(f)
            self.serialized_cert = self.persistent_state["root_cert"].encode("ascii")
            self.root_cert = x509.load_pem_x509_certificate(self.serialized_cert, default_backend())
            self.pri_key = serialization.load_pem_private_key(
                self.persistent_state["root_pri_key"].encode("ascii"), password=None, backend=default_backend()
            )
            self.pub_key = self.pri_key.public_key()
            self.subject = self.root_cert.subject
            self.issuer = self.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value

    def _build_root(self, subject, subject_org):
        """Create a self-signed root CA unless one was already loaded from saved state.

        Args:
            subject: common name for the root CA (also used as its issuer name).
            subject_org: organization name for the root CA subject, or None.
        """
        if not self.persistent_state:
            pri_key, pub_key = self._generate_keys()
            self.issuer = subject
            # Self-signed: the root signs its own certificate with its own private key.
            self.root_cert = self._generate_cert(subject, subject_org, self.issuer, pri_key, pub_key, ca=True)
            self.pri_key = pri_key
            self.pub_key = pub_key
            self.serialized_cert = serialize_cert(self.root_cert)
            self.persistent_state["root_cert"] = self.serialized_cert.decode("ascii")
            self.persistent_state["root_pri_key"] = serialize_pri_key(self.pri_key).decode("ascii")

    def _build_write_cert_pair(self, participant, base_name, ctx):
        """Create (or reuse) a participant's cert/key pair and write it into the participant's kit.

        Writes ``<base_name>.crt``, ``<base_name>.key``, ``<base_name>.pfx`` and ``rootCA.pem``
        into the kit directory. For clients with a ``listening_host`` prop, an extra
        ``server.crt``/``server.key`` pair is also generated.

        Args:
            participant: the participant to issue a certificate for.
            base_name: file name stem for the written files (e.g. "server", "client").
            ctx: provisioning context; passed to ``get_kit_dir`` (defined on the
                ``Builder`` base class, not shown here).

        Raises:
            RuntimeError: if an admin's role stored in the previously issued cert differs
                from the role now requested.
        """
        subject = self.get_subject(participant)
        if self.persistent_state and subject in self.persistent_state:
            # Reuse the cert/key previously issued for this subject.
            cert = x509.load_pem_x509_certificate(
                self.persistent_state[subject]["cert"].encode("ascii"), default_backend()
            )
            pri_key = serialization.load_pem_private_key(
                self.persistent_state[subject]["pri_key"].encode("ascii"), password=None, backend=default_backend()
            )
            if participant.type == "admin":
                # The admin role is embedded in the cert as UNSTRUCTURED_NAME; a reused
                # cert cannot reflect a changed role, so refuse to continue.
                new_role = participant.props.get("role")
                cn_list = cert.subject.get_attributes_for_oid(NameOID.UNSTRUCTURED_NAME)
                for cn in cn_list:
                    role = cn.value
                    if role != new_role:
                        err_msg = (
                            f"{participant.name}'s previous role is {role} but is now {new_role}.\n"
                            + "Please delete existing workspace and provision from scratch."
                        )
                        raise RuntimeError(err_msg)
        else:
            pri_key, cert = self.get_pri_key_cert(participant)
            self.persistent_state[subject] = dict(
                cert=serialize_cert(cert).decode("ascii"), pri_key=serialize_pri_key(pri_key).decode("ascii")
            )
        dest_dir = self.get_kit_dir(participant, ctx)
        with open(os.path.join(dest_dir, f"{base_name}.crt"), "wb") as f:
            f.write(serialize_cert(cert))
        with open(os.path.join(dest_dir, f"{base_name}.key"), "wb") as f:
            f.write(serialize_pri_key(pri_key))
        if base_name == "client" and (listening_host := participant.props.get("listening_host")):
            # A client that also listens needs a server-side cert for its listening host.
            tmp_participant = Participant("server", listening_host, participant.org)
            tmp_pri_key, tmp_cert = self.get_pri_key_cert(tmp_participant)
            with open(os.path.join(dest_dir, "server.crt"), "wb") as f:
                f.write(serialize_cert(tmp_cert))
            with open(os.path.join(dest_dir, "server.key"), "wb") as f:
                f.write(serialize_pri_key(tmp_pri_key))
        # NOTE(review): the PKCS#12 bundle is encrypted with the subject name as its password,
        # which is not a secret; confirm this is intended.
        pkcs12 = serialization.pkcs12.serialize_key_and_certificates(
            subject.encode("ascii"), pri_key, cert, None, serialization.BestAvailableEncryption(subject.encode("ascii"))
        )
        with open(os.path.join(dest_dir, f"{base_name}.pfx"), "wb") as f:
            f.write(pkcs12)
        with open(os.path.join(dest_dir, "rootCA.pem"), "wb") as f:
            f.write(self.serialized_cert)

    def build(self, project, ctx):
        """Build the root CA and issue certificates for every overseer, server, client and admin.

        Also publishes the root cert and root private key into ``ctx`` under
        ``"root_cert"`` and ``"root_pri_key"``.
        """
        self._build_root(project.name, subject_org=None)
        ctx["root_cert"] = self.root_cert
        ctx["root_pri_key"] = self.pri_key
        overseer = project.get_participants_by_type("overseer")
        if overseer:
            self._build_write_cert_pair(overseer, "overseer", ctx)

        servers = project.get_participants_by_type("server", first_only=False)
        for server in servers:
            self._build_write_cert_pair(server, "server", ctx)

        for client in project.get_participants_by_type("client", first_only=False):
            self._build_write_cert_pair(client, "client", ctx)

        # Admin kits use the "client" file names.
        for admin in project.get_participants_by_type("admin", first_only=False):
            self._build_write_cert_pair(admin, "client", ctx)

    def get_pri_key_cert(self, participant):
        """Generate a new key pair for ``participant`` and a cert signed by the root CA key.

        Returns:
            tuple: ``(private_key, certificate)`` for the participant.
        """
        pri_key, pub_key = self._generate_keys()
        subject = self.get_subject(participant)
        subject_org = participant.org
        if participant.type == "admin":
            role = participant.props.get("role")
        else:
            role = None
        cert = self._generate_cert(subject, subject_org, self.issuer, self.pri_key, pub_key, role=role)
        return pri_key, cert

    def get_subject(self, participant):
        """Return the certificate subject (common name) to use for ``participant``."""
        return participant.subject

    def _generate_keys(self):
        """Generate a 2048-bit RSA key pair (public exponent 65537).

        Returns:
            tuple: ``(private_key, public_key)``.
        """
        pri_key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend())
        pub_key = pri_key.public_key()
        return pri_key, pub_key

    def _generate_cert(
        self, subject, subject_org, issuer, signing_pri_key, subject_pub_key, valid_days=360, ca=False, role=None
    ):
        """Build and sign an X.509 certificate.

        Args:
            subject: common name of the certificate subject; also used as its DNS SAN.
            subject_org: organization name of the subject, or None.
            issuer: common name of the issuer.
            signing_pri_key: private key used to sign the certificate.
            subject_pub_key: public key embedded in the certificate.
            valid_days: validity period in days, starting now.
            ca: if True, add CA extensions (key identifiers and BasicConstraints ca=True).
            role: optional role stored as UNSTRUCTURED_NAME in the subject.

        Returns:
            The signed certificate (SHA-256 signature).
        """
        x509_subject = self._x509_name(subject, subject_org, role)
        x509_issuer = self._x509_name(issuer)
        # Take a single timezone-aware UTC timestamp (utcnow() is deprecated) so the
        # validity window is exactly valid_days long.
        now = datetime.datetime.now(datetime.timezone.utc)
        builder = (
            x509.CertificateBuilder()
            .subject_name(x509_subject)
            .issuer_name(x509_issuer)
            .public_key(subject_pub_key)
            .serial_number(x509.random_serial_number())
            .not_valid_before(now)
            .not_valid_after(now + datetime.timedelta(days=valid_days))
            .add_extension(x509.SubjectAlternativeName([x509.DNSName(subject)]), critical=False)
        )
        if ca:
            # In this class ca=True is only used for the self-signed root (_build_root), so the
            # authority key identifier is derived from the subject's own public key.
            builder = (
                builder.add_extension(
                    x509.SubjectKeyIdentifier.from_public_key(subject_pub_key),
                    critical=False,
                )
                .add_extension(
                    x509.AuthorityKeyIdentifier.from_issuer_public_key(subject_pub_key),
                    critical=False,
                )
                .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=False)
            )
        # Sign with the issuer's private key (the root's own key when self-signing the root).
        return builder.sign(signing_pri_key, hashes.SHA256(), default_backend())

    def _x509_name(self, cn_name, org_name=None, role=None):
        """Build an ``x509.Name`` from a common name plus optional organization and role."""
        name = [x509.NameAttribute(NameOID.COMMON_NAME, cn_name)]
        if org_name is not None:
            name.append(x509.NameAttribute(NameOID.ORGANIZATION_NAME, org_name))
        if role:
            name.append(x509.NameAttribute(NameOID.UNSTRUCTURED_NAME, role))
        return x509.Name(name)

    def finalize(self, ctx):
        """Persist the root CA and all issued certs/keys to ``<state_dir>/cert.json``."""
        state_dir = self.get_state_dir(ctx)
        cert_file = os.path.join(state_dir, "cert.json")
        # Close (and therefore flush) the file deterministically; the original left the
        # handle from open() unclosed.
        with open(cert_file, "wt") as f:
            json.dump(self.persistent_state, f)