# Source code for nvflare.fuel.f3.cellnet.credential_manager

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading

from cryptography import x509
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.x509 import Certificate

from nvflare.fuel.f3.cellnet.cell_cipher import SimpleCellCipher
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.endpoint import Endpoint

# Module-level logger used by CredentialManager.
log = logging.getLogger(__name__)

# Dictionary keys of the certificate-exchange request / reply payloads, used by
# CredentialManager.create_request, process_request and process_response.
CERT_TARGET = "cert_target"
CERT_ORIGIN = "cert_origin"
CERT_CONTENT = "cert_content"
CERT_CA_CONTENT = "cert_ca_content"
CERT_ERROR = "cert_error"

# Timeout for a certificate request.
# NOTE(review): the unit is not shown in this file (presumably seconds) and the
# constant is not referenced here — confirm against its callers.
CERT_REQ_TIMEOUT = 10


class CredentialManager:
    """Helper class for secure messages.

    It holds the local credentials (CA certificate, local certificate and local
    private key, all as PEM bytes read from the endpoint's connection properties)
    and a cache of peer certificates keyed by root FQCN.
    """

    def __init__(self, local_endpoint: Endpoint):
        """Load the local credentials configured on ``local_endpoint``.

        The server cert/key pair is used when ``DriverParams.SERVER_CERT`` is set in
        ``conn_props``; otherwise the client cert/key pair is used. When no local
        certificate path is configured, the manager runs in non-secure mode and
        ``cell_cipher`` is None.

        Args:
            local_endpoint: endpoint whose ``conn_props`` hold the credential file paths.

        Raises:
            RuntimeError: if a local certificate is configured but the CA certificate
                or the private key is missing.
        """
        self.local_endpoint = local_endpoint
        self.cert_cache = {}
        # Guards cert_cache reads and writes.
        self.lock = threading.Lock()

        conn_props = self.local_endpoint.conn_props
        ca_cert_path = conn_props.get(DriverParams.CA_CERT)
        server_cert_path = conn_props.get(DriverParams.SERVER_CERT)
        if server_cert_path:
            local_cert_path = server_cert_path
            local_key_path = conn_props.get(DriverParams.SERVER_KEY)
        else:
            local_cert_path = conn_props.get(DriverParams.CLIENT_CERT)
            local_key_path = conn_props.get(DriverParams.CLIENT_KEY)

        if not local_cert_path:
            log.debug("Certificate is not configured, secure message is not supported")
            self.ca_cert = None
            self.local_cert = None
            self.local_key = None
            self.cell_cipher = None
        else:
            self.ca_cert = self.read_file(ca_cert_path)
            self.local_cert = self.read_file(local_cert_path)
            self.local_key = self.read_file(local_key_path)
            # Fail with an actionable message instead of an opaque error from the
            # PEM loaders when part of the credential set is missing.
            if not self.ca_cert:
                raise RuntimeError("Local certificate is configured but CA certificate is missing")
            if not self.local_key:
                raise RuntimeError("Local certificate is configured but private key is missing")
            # Build the cipher exactly once.
            self.cell_cipher = SimpleCellCipher(self.get_ca_cert(), self.get_local_key(), self.get_local_cert())

    def encrypt(self, target_cert: bytes, payload: bytes) -> bytes:
        """Encrypt ``payload`` for the peer owning ``target_cert``.

        Args:
            target_cert: PEM-encoded certificate of the receiving peer.
            payload: plaintext bytes.

        Returns:
            The ciphertext produced by the cell cipher.

        Raises:
            RuntimeError: if this cell is not running in secure mode.
        """
        if not self.cell_cipher:
            raise RuntimeError("Secure message not supported, Cell not running in secure mode")
        return self.cell_cipher.encrypt(payload, x509.load_pem_x509_certificate(target_cert))

    def decrypt(self, origin_cert: bytes, cipher: bytes) -> bytes:
        """Decrypt ``cipher`` received from the peer owning ``origin_cert``.

        Args:
            origin_cert: PEM-encoded certificate of the sending peer.
            cipher: ciphertext bytes.

        Returns:
            The plaintext produced by the cell cipher.

        Raises:
            RuntimeError: if this cell is not running in secure mode.
        """
        if not self.cell_cipher:
            raise RuntimeError("Secure message not supported, Cell not running in secure mode")
        return self.cell_cipher.decrypt(cipher, x509.load_pem_x509_certificate(origin_cert))

    def get_certificate(self, fqcn: str) -> bytes:
        """Return the cached certificate of the root cell of ``fqcn``.

        Returns:
            The cached PEM certificate bytes, or None if nothing is cached.

        Raises:
            RuntimeError: if this cell is not running in secure mode.
        """
        if not self.cell_cipher:
            raise RuntimeError("This cell doesn't support certificate exchange, not running in secure mode")
        target = FQCN.get_root(fqcn)
        with self.lock:
            return self.cert_cache.get(target)

    def save_certificate(self, fqcn: str, cert: bytes):
        """Cache ``cert`` under the root cell of ``fqcn``."""
        target = FQCN.get_root(fqcn)
        with self.lock:
            self.cert_cache[target] = cert

    def create_request(self, target: str) -> dict:
        """Build a certificate-exchange request carrying this cell's certificates.

        Args:
            target: the cell whose certificate is requested.

        Returns:
            A dict with the target, this cell's root FQCN as origin, and the local
            and CA certificates (None when not running in secure mode).
        """
        req = {
            CERT_TARGET: target,
            CERT_ORIGIN: FQCN.get_root(self.local_endpoint.name),
            CERT_CONTENT: self.local_cert,
            CERT_CA_CONTENT: self.ca_cert,
        }
        return req

    def process_request(self, request: dict) -> dict:
        """Handle an incoming certificate-exchange request.

        The requester's certificate is cached under its origin, and the reply carries
        this cell's local and CA certificates. If this cell has no local certificate,
        the reply carries an error message under ``CERT_ERROR`` instead.

        Args:
            request: a dict as produced by ``create_request``.

        Returns:
            The reply dict.
        """
        target = request.get(CERT_TARGET)
        origin = request.get(CERT_ORIGIN)
        reply = {CERT_TARGET: target, CERT_ORIGIN: origin}
        if not self.local_cert:
            reply[CERT_ERROR] = f"Target {target} is not running in secure mode"
        else:
            cert = request.get(CERT_CONTENT)
            # Save cert from requester in the cache
            with self.lock:
                self.cert_cache[origin] = cert
            reply[CERT_CONTENT] = self.local_cert
            reply[CERT_CA_CONTENT] = self.ca_cert
        return reply

    @staticmethod
    def process_response(reply: dict) -> bytes:
        """Extract the peer certificate from a certificate-exchange reply.

        Args:
            reply: a dict as produced by ``process_request``.

        Returns:
            The PEM certificate bytes in the reply (None if absent).

        Raises:
            RuntimeError: if the reply carries an error.
        """
        error = reply.get(CERT_ERROR)
        if error:
            # This is a staticmethod, so the target must come from the reply itself.
            target = reply.get(CERT_TARGET)
            raise RuntimeError(f"Request to get certificate from {target} failed: {error}")
        return reply.get(CERT_CONTENT)

    def get_local_cert(self) -> Certificate:
        """Parse and return the local PEM certificate."""
        return x509.load_pem_x509_certificate(self.local_cert)

    def get_local_key(self) -> RSAPrivateKey:
        """Parse and return the local (unencrypted) PEM private key."""
        return serialization.load_pem_private_key(self.local_key, password=None)

    def get_ca_cert(self) -> Certificate:
        """Parse and return the CA PEM certificate."""
        return x509.load_pem_x509_certificate(self.ca_cert)

    @staticmethod
    def read_file(file_name: str):
        """Return the raw bytes of ``file_name``, or None if no name is given."""
        if not file_name:
            return None
        with open(file_name, "rb") as f:
            return f.read()