Source code for nvflare.fuel.f3.cellnet.cell_cipher

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from cryptography.exceptions import InvalidKey, InvalidSignature
from cryptography.hazmat.primitives import asymmetric, ciphers, hashes, padding
from cryptography.x509 import Certificate

HASH_LENGTH = 4  # Adjustable to avoid collision
NONCE_LENGTH = 16  # For AES, this is 128 bits (i.e. block size)
KEY_LENGTH = 32  # AES 256.  Choose from 16, 24, 32
HEADER_LENGTH = HASH_LENGTH + NONCE_LENGTH
PADDING_LENGTH = NONCE_LENGTH * 8  # in bits
KEY_ENC_LENGTH = 256
SIGNATURE_LENGTH = 256
SIMPLE_HEADER_LENGTH = NONCE_LENGTH + KEY_ENC_LENGTH + SIGNATURE_LENGTH


[docs]def get_hash(value): hash = hashes.Hash(hashes.SHA256()) hash.update(value) return hash.finalize()
[docs]class SessionKeyUnavailable(Exception): pass
[docs]class InvalidCertChain(Exception): pass
def _asym_enc(k, m): return k.encrypt( m, asymmetric.padding.OAEP( mgf=asymmetric.padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None ), ) def _asym_dec(k, m): return k.decrypt( m, asymmetric.padding.OAEP( mgf=asymmetric.padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None ), ) def _sign(k, m): return k.sign( data=m, padding=asymmetric.padding.PSS( mgf=asymmetric.padding.MGF1(hashes.SHA256()), salt_length=asymmetric.padding.PSS.MAX_LENGTH, ), algorithm=hashes.SHA256(), ) def _verify(k, m, s): if not isinstance(m, bytes): m = bytes(m) if not isinstance(s, bytes): s = bytes(s) k.verify( s, m, asymmetric.padding.PSS( mgf=asymmetric.padding.MGF1(hashes.SHA256()), salt_length=asymmetric.padding.PSS.MAX_LENGTH ), hashes.SHA256(), ) def _sym_enc(k, n, m): cipher = ciphers.Cipher(ciphers.algorithms.AES(k), ciphers.modes.CBC(n)) encryptor = cipher.encryptor() padder = padding.PKCS7(PADDING_LENGTH).padder() padded_data = padder.update(m) + padder.finalize() return encryptor.update(padded_data) + encryptor.finalize() def _sym_dec(k, n, m): cipher = ciphers.Cipher(ciphers.algorithms.AES(k), ciphers.modes.CBC(n)) decryptor = cipher.decryptor() plain_text = decryptor.update(m) plain_text = plain_text + decryptor.finalize() unpadder = padding.PKCS7(PADDING_LENGTH).unpadder() return unpadder.update(plain_text) + unpadder.finalize()
[docs]class SessionKeyManager: def __init__(self, root_ca): self.key_hash_dict = dict() self.root_ca = root_ca self.root_ca_pub_key = root_ca.public_key()
[docs] def validate_cert_chain(self, cert): self.root_ca_pub_key.verify( cert.signature, cert.tbs_certificate_bytes, asymmetric.padding.PKCS1v15(), cert.signature_hash_algorithm )
[docs] def key_request(self, remote_cert, local_cert, local_pri_key): session_key = os.urandom(KEY_LENGTH) signature = _sign(local_pri_key, session_key) try: self.validate_cert_chain(remote_cert) except InvalidSignature: return False remote_pub_key = remote_cert.public_key() key_enc = _asym_enc(remote_pub_key, session_key) self.key_hash_dict[get_hash(session_key)[-HASH_LENGTH:]] = session_key key_response = key_enc + signature return key_response
[docs] def process_key_response(self, remote_cert, local_cert, local_pri_key, key_response): key_enc, signature = key_response[:KEY_ENC_LENGTH], key_response[KEY_ENC_LENGTH:] try: session_key = _asym_dec(local_pri_key, key_enc) self.validate_cert_chain(remote_cert) public_key = remote_cert.public_key() _verify(public_key, session_key, signature) self.key_hash_dict[get_hash(session_key)[-HASH_LENGTH:]] = session_key except (InvalidKey, InvalidSignature): return False return True
[docs] def key_available(self): return bool(self.key_hash_dict)
[docs] def get_key(self, key_hash): return self.key_hash_dict.get(key_hash)
[docs] def get_latest_key(self): try: k, last_value = _, self.key_hash_dict[k] = self.key_hash_dict.popitem() except KeyError as e: raise SessionKeyUnavailable("No session key established yet") return last_value
[docs]class CellCipher: def __init__(self, session_key_manager: SessionKeyManager): self.session_key_manager = session_key_manager
[docs] def encrypt(self, message): key = self.session_key_manager.get_latest_key() key_hash = get_hash(key) nonce = os.urandom(NONCE_LENGTH) return nonce + key_hash[-HASH_LENGTH:] + _sym_enc(key, nonce, message)
[docs] def decrypt(self, message): nonce, key_hash, message = ( message[:NONCE_LENGTH], message[NONCE_LENGTH:HEADER_LENGTH], message[HEADER_LENGTH:], ) key = self.session_key_manager.get_key(key_hash) if key is None: raise SessionKeyUnavailable("No session key found for received message") return _sym_dec(key, nonce, message)
[docs]class SimpleCellCipher: def __init__(self, root_ca: Certificate, pri_key: asymmetric.rsa.RSAPrivateKey, cert: Certificate): self._root_ca = root_ca self._root_ca_pub_key = root_ca.public_key() self._pri_key = pri_key self._cert = cert self._pub_key = cert.public_key() self._validate_cert_chain(self._cert) self._cached_enc = dict() self._cached_dec = dict() def _validate_cert_chain(self, cert: Certificate): self._root_ca_pub_key.verify( cert.signature, cert.tbs_certificate_bytes, asymmetric.padding.PKCS1v15(), cert.signature_hash_algorithm )
[docs] def encrypt(self, message: bytes, target_cert: Certificate): cert_hash = hash(target_cert) secret = self._cached_enc.get(cert_hash) if secret is None: self._validate_cert_chain(target_cert) key = os.urandom(KEY_LENGTH) remote_pub_key = target_cert.public_key() key_enc = _asym_enc(remote_pub_key, key) signature = _sign(self._pri_key, key_enc) self._cached_enc[cert_hash] = (key, key_enc, signature) else: (key, key_enc, signature) = secret nonce = os.urandom(NONCE_LENGTH) ct = nonce + key_enc + signature + _sym_enc(key, nonce, message) return ct
[docs] def decrypt(self, message: bytes, origin_cert: Certificate): nonce, key_enc, signature = ( message[:NONCE_LENGTH], message[NONCE_LENGTH : NONCE_LENGTH + KEY_ENC_LENGTH], message[NONCE_LENGTH + KEY_ENC_LENGTH : SIMPLE_HEADER_LENGTH], ) if not isinstance(key_enc, bytes): key_enc = bytes(key_enc) key_hash = hash(key_enc) dec = self._cached_dec.get(key_hash) if dec is None: self._validate_cert_chain(origin_cert) public_key = origin_cert.public_key() _verify(public_key, key_enc, signature) key = _asym_dec(self._pri_key, key_enc) self._cached_dec[key_hash] = key else: key = dec return _sym_dec(key, nonce, message[SIMPLE_HEADER_LENGTH:])