# Source code for nvflare.private.fed.utils.identity_utils

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from cryptography.x509.oid import NameOID

from nvflare.lighter.utils import (
    load_crt,
    load_crt_bytes,
    load_private_key_file,
    sign_content,
    verify_cert,
    verify_content,
)
from nvflare.security.logging import secure_format_exception


class CNMismatch(Exception):
    """Raised when the Common Name found in a certificate differs from the asserted one."""
class MissingCN(Exception):
    """Raised when a certificate subject carries no COMMON_NAME attribute."""
class InvalidAsserterCert(Exception):
    """Raised when an asserter's certificate fails verification against the root CA public key."""
class InvalidCNSignature(Exception):
    """Raised when the signature over a Common Name plus nonce cannot be verified."""
def get_cn_from_cert(cert):
    """Return the Common Name (CN) from the subject of ``cert``.

    Args:
        cert: certificate object exposing ``subject.get_attributes_for_oid``
            (presumably a ``cryptography.x509.Certificate`` -- confirm with callers).

    Returns:
        The ``value`` of the first COMMON_NAME attribute in the subject.

    Raises:
        MissingCN: if the subject has no COMMON_NAME attribute.
    """
    cn_attributes = cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)
    if cn_attributes:
        # Only the first CN is used if the subject happens to carry several.
        return cn_attributes[0].value
    raise MissingCN()
def load_cert_file(path: str):
    """Load a certificate from the file at ``path``.

    Thin wrapper around ``load_crt`` from ``nvflare.lighter.utils``; the
    loaded certificate object is returned unchanged.
    """
    certificate = load_crt(path)
    return certificate
def load_cert_bytes(data: bytes):
    """Load a certificate from its raw encoded bytes.

    Thin wrapper around ``load_crt_bytes`` from ``nvflare.lighter.utils``; the
    loaded certificate object is returned unchanged.
    """
    certificate = load_crt_bytes(data)
    return certificate
class IdentityAsserter:
    """Holds a private key and its certificate, and signs identity challenges.

    The certificate's Common Name (CN) is extracted at construction time; a
    challenge is answered by signing ``cn + nonce`` with the private key.
    """

    def __init__(self, private_key_file: str, cert_file: str):
        """Load the certificate and private key used for asserting identity.

        Args:
            private_key_file: path to the private key file.
            cert_file: path to the certificate file; it is read as raw bytes.

        Raises:
            OSError: if ``cert_file`` cannot be opened or read.
            MissingCN: if the certificate subject has no Common Name.
        """
        # Keep the raw certificate bytes as well as the parsed object.
        with open(cert_file, "rb") as f:
            self.cert_data = f.read()
        self.private_key_file = private_key_file
        self.pri_key = load_private_key_file(private_key_file)
        self.cert_file = cert_file
        self.cert = load_cert_bytes(self.cert_data)
        self.cn = get_cn_from_cert(self.cert)

    def sign_common_name(self, nonce: str) -> bytes:
        """Sign ``self.cn + nonce`` with this asserter's private key.

        Args:
            nonce: challenge string appended to the Common Name before signing.

        Returns:
            The signature produced by ``sign_content`` with ``return_str=False``,
            i.e. the raw (non-string) signature form. The previous ``-> str``
            annotation contradicted that flag.
        """
        return sign_content(self.cn + nonce, self.pri_key, return_str=False)
class IdentityVerifier:
    """Verifies identity assertions against a trusted root CA certificate."""

    def __init__(self, root_cert_file: str):
        """Load the root CA certificate and cache its public key.

        Args:
            root_cert_file: path to the root CA certificate file.
        """
        self.root_cert = load_cert_file(root_cert_file)
        self.root_public_key = self.root_cert.public_key()

    def verify_common_name(self, asserted_cn: str, nonce: str, asserter_cert, signature) -> bool:
        """Check that the asserter owns ``asserted_cn``.

        Three checks are performed in order:

        1. ``asserter_cert`` verifies against the root CA public key.
        2. The Common Name in ``asserter_cert`` equals ``asserted_cn``.
        3. ``signature`` verifies over ``cn + nonce`` with the public key of
           ``asserter_cert``.

        Args:
            asserted_cn: the Common Name the asserter claims.
            nonce: the challenge string the asserter was asked to sign.
            asserter_cert: the asserter's certificate object.
            signature: signature produced by the asserter over ``cn + nonce``.

        Returns:
            True when all checks pass. Failures raise; this never returns False.

        Raises:
            InvalidAsserterCert: if ``asserter_cert`` fails root CA verification.
            MissingCN: if ``asserter_cert`` has no Common Name.
            CNMismatch: if the certificate's CN differs from ``asserted_cn``.
            InvalidCNSignature: if the signature cannot be verified.
        """
        # Verify asserter_cert against the trusted root CA public key.
        try:
            verify_cert(
                cert_to_be_verified=asserter_cert,
                root_ca_public_key=self.root_public_key,
            )
        except Exception as ex:
            # Narrowed from a bare ``except:``, which also caught
            # KeyboardInterrupt/SystemExit. The cause is chained for diagnostics.
            raise InvalidAsserterCert(f"cannot verify asserter cert: {secure_format_exception(ex)}") from ex

        # Verify the signature provided by the asserter.
        asserter_public_key = asserter_cert.public_key()
        cn = get_cn_from_cert(asserter_cert)
        if cn != asserted_cn:
            raise CNMismatch()
        # Type-narrowing sanity check before string concatenation below.
        assert isinstance(cn, str)
        try:
            verify_content(content=cn + nonce, signature=signature, public_key=asserter_public_key)
        except Exception as ex:
            raise InvalidCNSignature(f"cannot verify common name signature: {secure_format_exception(ex)}") from ex
        return True