# Source code for nvflare.fuel.f3.cellnet.credential_manager

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading

from cryptography import x509
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.x509 import Certificate

from nvflare.fuel.f3.cellnet.cell_cipher import SimpleCellCipher
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.endpoint import Endpoint
from nvflare.fuel.f3.message import Message

# Module-level logger for this file.
log = logging.getLogger(__name__)

# Keys of the certificate-exchange payload dicts built by
# CredentialManager.create_request / process_request and read by process_response.
CERT_ERROR = "cert_error"  # reply key carrying an error message instead of a cert
CERT_TARGET = "cert_target"  # not referenced in this module; presumably used by callers
CERT_ORIGIN = "cert_origin"  # not referenced in this module; presumably used by callers
CERT_CONTENT = "cert_content"  # PEM bytes of the sender's own certificate
CERT_CA_CONTENT = "cert_ca_content"  # PEM bytes of the sender's CA certificate
CERT_REQ_TIMEOUT = 10  # not referenced in this module; presumably seconds for a cert request — TODO confirm


class CredentialManager:
    """Helper class for secure message. It holds the local credentials and certificate cache.

    Credentials are read from file paths found in the local endpoint's ``conn_props``.
    When ``DriverParams.SERVER_CERT`` is configured, the server cert/key pair is used;
    otherwise the client cert/key pair (``DriverParams.CLIENT_CERT`` / ``CLIENT_KEY``)
    is used. The CA certificate always comes from ``DriverParams.CA_CERT``.

    If no local certificate path is configured, the manager runs in non-secure mode:
    ``cell_cipher`` is ``None`` and ``encrypt``, ``decrypt`` and ``get_certificate``
    raise ``RuntimeError``.
    """

    def __init__(self, local_endpoint: Endpoint):
        """Load the local credentials and build the cell cipher (once).

        Args:
            local_endpoint: endpoint whose ``conn_props`` provide the cert/key file paths.

        Raises:
            OSError: if a configured cert/key file cannot be opened.
            Errors from the cryptography loaders or ``SimpleCellCipher`` propagate
            unchanged when a local cert path is configured but the credentials are
            unusable (e.g. empty/malformed PEM, or no CA/key path configured).
        """
        self.local_endpoint = local_endpoint
        # Peer certificates (PEM bytes) keyed by the peer's FQCN.
        self.cert_cache = {}
        # Guards cert_cache, which is written by the request and response handlers.
        self.lock = threading.Lock()

        conn_props = self.local_endpoint.conn_props
        ca_cert_path = conn_props.get(DriverParams.CA_CERT)
        server_cert_path = conn_props.get(DriverParams.SERVER_CERT)
        if server_cert_path:
            local_cert_path = server_cert_path
            local_key_path = conn_props.get(DriverParams.SERVER_KEY)
        else:
            local_cert_path = conn_props.get(DriverParams.CLIENT_CERT)
            local_key_path = conn_props.get(DriverParams.CLIENT_KEY)

        if not local_cert_path:
            log.debug("Certificate is not configured, secure message is not supported")
            self.ca_cert = None
            self.local_cert = None
            self.local_key = None
            self.cell_cipher = None
            return

        self.ca_cert = self.read_file(ca_cert_path)
        self.local_cert = self.read_file(local_cert_path)
        self.local_key = self.read_file(local_key_path)
        # Build the cipher exactly once. Parsing happens inside the getters, so
        # unusable credentials fail here rather than being silently ignored.
        self.cell_cipher = SimpleCellCipher(self.get_ca_cert(), self.get_local_key(), self.get_local_cert())

    def encrypt(self, target_cert: bytes, payload: bytes) -> bytes:
        """Encrypt ``payload`` for the peer identified by ``target_cert``.

        Args:
            target_cert: PEM-encoded certificate of the receiving cell.
            payload: plaintext bytes.

        Returns:
            The ciphertext produced by ``SimpleCellCipher.encrypt``.

        Raises:
            RuntimeError: if this cell is not running in secure mode.
        """
        if not self.cell_cipher:
            raise RuntimeError("Secure message not supported, Cell not running in secure mode")
        return self.cell_cipher.encrypt(payload, x509.load_pem_x509_certificate(target_cert))

    def decrypt(self, origin_cert: bytes, cipher: bytes) -> bytes:
        """Decrypt ``cipher`` received from the peer identified by ``origin_cert``.

        Args:
            origin_cert: PEM-encoded certificate of the sending cell.
            cipher: ciphertext bytes.

        Returns:
            The plaintext produced by ``SimpleCellCipher.decrypt``.

        Raises:
            RuntimeError: if this cell is not running in secure mode.
        """
        if not self.cell_cipher:
            raise RuntimeError("Secure message not supported, Cell not running in secure mode")
        return self.cell_cipher.decrypt(cipher, x509.load_pem_x509_certificate(origin_cert))

    def get_certificate(self, fqcn: str) -> bytes:
        """Return the cached PEM certificate of cell ``fqcn``, or ``None`` if not cached.

        Raises:
            RuntimeError: if this cell is not running in secure mode.
        """
        if not self.cell_cipher:
            raise RuntimeError("This cell doesn't support certificate exchange, not running in secure mode")
        with self.lock:
            return self.cert_cache.get(fqcn)

    def create_request(self) -> dict:
        """Build a certificate-exchange request carrying this cell's cert and CA cert."""
        return {
            CERT_CONTENT: self.local_cert,
            CERT_CA_CONTENT: self.ca_cert,
        }

    def process_request(self, request: Message) -> dict:
        """Handle a certificate-exchange request from a peer cell.

        The requester's certificate (if present) is cached under the request's
        ORIGIN header. The reply carries this cell's cert and CA cert, or a
        ``CERT_ERROR`` entry when this cell is not in secure mode.

        Args:
            request: incoming message; its payload is expected to be a dict as
                built by ``create_request``.

        Returns:
            The reply payload dict.
        """
        origin = request.get_header(MessageHeaderKey.ORIGIN)
        target = request.get_header(MessageHeaderKey.DESTINATION)

        if not self.local_cert:
            return {CERT_ERROR: f"Target {target} is not running in secure mode"}

        cert = request.payload.get(CERT_CONTENT)
        # Don't let a request without a cert erase a previously cached valid one.
        if cert:
            with self.lock:
                self.cert_cache[origin] = cert

        return {
            CERT_CONTENT: self.local_cert,
            CERT_CA_CONTENT: self.ca_cert,
        }

    def process_response(self, message: Message) -> bytes:
        """Handle the reply to a certificate-exchange request.

        Args:
            message: reply message; its payload is expected to be a dict as built
                by ``process_request``.

        Returns:
            The peer's PEM certificate (``None`` if the reply carried none). A
            non-empty certificate is cached under the message's ORIGIN header.

        Raises:
            RuntimeError: if the reply carries a ``CERT_ERROR``.
        """
        origin = message.get_header(MessageHeaderKey.ORIGIN)
        reply = message.payload

        error = reply.get(CERT_ERROR)
        if error:
            raise RuntimeError(f"Request to get certificate from {origin} failed: {error}")

        cert = reply.get(CERT_CONTENT)
        # Only cache real certificates so a malformed reply can't clobber a good entry.
        if cert:
            with self.lock:
                self.cert_cache[origin] = cert
        return cert

    def get_local_cert(self) -> Certificate:
        """Parse and return this cell's own certificate from its PEM bytes."""
        return x509.load_pem_x509_certificate(self.local_cert)

    def get_local_key(self) -> RSAPrivateKey:
        """Parse and return this cell's private key (unencrypted PEM, no password)."""
        return serialization.load_pem_private_key(self.local_key, password=None)

    def get_ca_cert(self) -> Certificate:
        """Parse and return the CA certificate from its PEM bytes."""
        return x509.load_pem_x509_certificate(self.ca_cert)

    @staticmethod
    def read_file(file_name: str):
        """Return the raw bytes of ``file_name``, or ``None`` when no path is given."""
        if not file_name:
            return None
        with open(file_name, "rb") as f:
            return f.read()