# Source code for nvflare.app_opt.confidential_computing.snp_authorizer

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import logging
import os
import random
import re
import shutil
import subprocess
import time
import uuid

from filelock import FileLock

from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer

from .utils import NonceHistory

SNP_NAMESPACE = "x-snp"
REPORT_PATH = "report.bin"
REQUEST_PATH = "request.bin"

AMD_ARK = "ark.pem"
AMD_ASK = "ask.pem"
AMD_VCEK = "vcek.pem"


def parse_chip_id(report_text: str) -> str:
    """Extract the Chip ID from ``snpguest display report`` text output.

    The Chip ID is expected after a ``Chip ID:`` label as one or more rows of
    exactly sixteen whitespace-separated hex byte pairs.

    Args:
        report_text: Decoded text of the report display output.

    Returns:
        The Chip ID bytes as one contiguous lowercase hex string, or ``""`` when
        no ``Chip ID:`` block can be located.
    """
    # One row is 16 hex bytes separated by whitespace; further rows follow a newline.
    row = r"(?:[0-9A-Fa-f]{2}\s+){15}[0-9A-Fa-f]{2}"
    pattern = rf"Chip ID:\s*({row}(?:\s*\n\s*{row})*)"
    found = re.search(pattern, report_text, re.MULTILINE)
    if found is None:
        return ""
    # Drop every whitespace separator between bytes and normalize to lowercase.
    return "".join(found.group(1).split()).lower()


def parse_reported_tcb(report_text: str) -> dict:
    """Extract the "Reported TCB" version fields from ``snpguest display report`` output.

    Args:
        report_text: Decoded text of the report display output.

    Returns:
        A dict with integer ``"Microcode"``, ``"SNP"``, ``"TEE"`` and
        ``"Boot Loader"`` entries plus ``"FMC"``: the printed FMC token, or None
        when it is printed as ``None`` or when no FMC line follows Boot Loader.
        Returns ``{}`` when the Reported TCB block cannot be found.
    """
    match = re.search(
        r"Reported TCB:\s*"
        r"TCB Version:\s*"
        r"Microcode:\s*(\d+)\s*"
        r"SNP:\s*(\d+)\s*"
        r"TEE:\s*(\d+)\s*"
        r"Boot Loader:\s*(\d+)"
        # The FMC line is optional so that output without it (presumably from
        # snpguest builds that predate FMC support - confirm) still yields the
        # fields callers need; when present it is captured exactly as before.
        r"(?:\s*FMC:\s*(\w+))?",
        report_text,
        re.MULTILINE,
    )
    if not match:
        return {}
    microcode, snp, tee, boot_loader, fmc = match.groups()
    return {
        "Microcode": int(microcode),
        "SNP": int(snp),
        "TEE": int(tee),
        "Boot Loader": int(boot_loader),
        # FMC is printed as the literal "None" when absent; a missing line maps to None too.
        "FMC": None if fmc in (None, "None") else fmc,
    }


class SNPAuthorizer(CCAuthorizer):
    """AMD SEV-SNP Authorizer.

    Produces attestation tokens (base64-encoded SNP reports) and verifies tokens
    received from peers by invoking the ``snpguest`` command-line tool.

    NOTE(review): ``generate`` writes REQUEST_PATH/REPORT_PATH, which are relative
    to the current working directory and shared by every caller in that directory.
    """

    def __init__(
        self,
        max_nonce_history=1000,
        amd_certs_dir="/opt/certs",
        snpguest_binary="snpguest",
        cpu_model="milan",
        max_retries=5,
        retry_interval=10,
        cmd_timeout=60,
    ):
        """Initialize the SNPAuthorizer instance.

        Args:
            max_nonce_history (int, optional): Maximum number of nonces to keep in history
                for replay protection. Defaults to 1000.
            amd_certs_dir (str, optional): Directory where AMD certificates (ARK, ASK and
                cached VCEKs) are stored. Defaults to "/opt/certs".
            snpguest_binary (str, optional): Path or name of the `snpguest` binary used for
                generating and verifying reports; a bare name is looked up on PATH.
                Defaults to "snpguest".
            cpu_model (str, optional): CPU model identifier used when fetching the AMD CA
                certificates. Defaults to "milan".
            max_retries (int, optional): Maximum number of attempts for commands that may
                fail transiently; values below 1 are treated as 1. Defaults to 5.
            retry_interval (int, optional): Base wait time in seconds between retries; it is
                doubled after each failed attempt and capped at 60 seconds. Defaults to 10.
            cmd_timeout (int, optional): Timeout in seconds for each `snpguest` invocation.
                Defaults to 60.
        """
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)
        # Nonces embedded in reports this instance generated.
        self.my_nonce_history = NonceHistory(max_nonce_history)
        # Report-data values already accepted from peers (replay protection).
        self.seen_nonce_history = NonceHistory(max_nonce_history)
        self.amd_certs_dir = amd_certs_dir
        self.snpguest_binary = snpguest_binary
        self.cpu_model = cpu_model
        self.max_retries = max_retries
        self.retry_interval = retry_interval
        self.cmd_timeout = cmd_timeout

    def _run_with_retry(self, cmd: list[str], action_name: str) -> subprocess.CompletedProcess:
        """Run ``cmd`` until it exits with code 0, retrying with exponential backoff.

        Intended for operations whose failures may be transient (e.g. fetching
        certificates). Commands whose failure is a definitive answer should not use it.

        Args:
            cmd: Command and arguments, executed without a shell.
            action_name: Label used in log and error messages.

        Returns:
            The ``CompletedProcess`` of the first successful attempt.

        Raises:
            RuntimeError: If every attempt fails or times out.
        """
        attempts = max(1, self.max_retries)
        last_error = ""
        for attempt in range(1, attempts + 1):
            self.logger.info(f"[{action_name}] Attempt {attempt}/{attempts}: running {cmd}")
            try:
                result = subprocess.run(cmd, capture_output=True, timeout=self.cmd_timeout)
            except subprocess.TimeoutExpired:
                last_error = "command timed out"
                self.logger.warning(f"[{action_name}] Command timed out.")
            else:
                if result.returncode == 0:
                    return result
                # errors="replace": stderr is only logged, never let decoding abort the retry loop.
                stderr = result.stderr.decode(errors="replace").strip()
                last_error = f"return code {result.returncode}, stderr: {stderr}"
                self.logger.warning(f"[{action_name}] Failed with return code {result.returncode}. stderr: {stderr}")
            if attempt < attempts:
                # Exponential backoff, capped at 60 seconds.
                time.sleep(min(self.retry_interval * 2 ** (attempt - 1), 60))
        raise RuntimeError(f"[{action_name}] Failed after {attempts} attempts. Last error: {last_error}")

    def generate(self):
        """Generate an attestation token for this machine.

        A fresh 64-byte nonce is written to REQUEST_PATH and passed to
        ``snpguest report`` as the request data; the report it writes to
        REPORT_PATH is returned base64-encoded.

        Returns:
            bytes: The base64-encoded SNP report.

        Raises:
            RuntimeError: If ``snpguest report`` keeps failing.
        """
        # Nonces must be unpredictable: use the OS CSPRNG instead of the `random` module.
        nonce = bytearray(os.urandom(64))
        with open(REQUEST_PATH, "wb") as request_file:
            request_file.write(nonce)
        cmd = [self.snpguest_binary, "report", REPORT_PATH, REQUEST_PATH]
        self._run_with_retry(cmd, "generate_report")
        with open(REPORT_PATH, "rb") as report_file:
            token = base64.b64encode(report_file.read())
        self.my_nonce_history.add(nonce)
        return token

    def verify(self, token):
        """Verify an attestation token produced by ``generate`` on a peer.

        Ensures the AMD CA certificates and the VCEK matching the report are in
        ``amd_certs_dir``, checks the report with ``snpguest verify attestation``,
        and finally rejects replayed reports via their Report Data.

        Args:
            token: Base64-encoded SNP report (``bytes`` or ASCII ``str``).

        Returns:
            bool: True only if attestation and the nonce check both pass; False on
            any failure (errors are logged, not raised).
        """
        tmp_bin_file = uuid.uuid4().hex
        try:
            self._ensure_amd_ca_certs()
            report_bin = base64.b64decode(token)
            with open(tmp_bin_file, "wb") as report_file:
                report_file.write(report_bin)

            vcek_cache_key = self._parse_report(tmp_bin_file)
            # NOTE(review): vcek.pem is shared by all cache keys but the lock in
            # _ensure_amd_vcek is per key, so a concurrent verification for another
            # chip may overwrite vcek.pem before the command below reads it.
            self._ensure_amd_vcek(vcek_cache_key, tmp_bin_file)

            # All certificates are already in amd_certs_dir, so a non-zero exit is a
            # rejection of the report, not a transient error: run once, don't retry.
            cmd = [self.snpguest_binary, "verify", "attestation", self.amd_certs_dir, tmp_bin_file]
            result = subprocess.run(cmd, capture_output=True, timeout=self.cmd_timeout)
            if result.returncode != 0:
                self.logger.warning(
                    f"Attestation verification failed. stderr: {result.stderr.decode(errors='replace').strip()}"
                )
                return False
            self.logger.info("Attestation passed")

            if not self._check_nonce(tmp_bin_file):
                self.logger.info("Check nonce failed")
                return False
            self.logger.info("Check nonce passed")
            return True
        except Exception as e:
            self.logger.error(f"Token verification failed: {e}")
            return False
        finally:
            if os.path.exists(tmp_bin_file):
                os.remove(tmp_bin_file)

    def _ensure_amd_ca_certs(self):
        """Ensure the AMD CA certificates (ARK and ASK) exist in ``amd_certs_dir``.

        Fetches them with ``snpguest fetch ca`` when either file is missing.

        Raises:
            RuntimeError: If fetching fails or does not produce both files.
        """
        ask_path = os.path.join(self.amd_certs_dir, AMD_ASK)
        ark_path = os.path.join(self.amd_certs_dir, AMD_ARK)
        if os.path.exists(ark_path) and os.path.exists(ask_path):
            self.logger.info("AMD CA certs already exist.")
            return
        self.logger.info("AMD CA certs not found. Fetching...")
        cmd = [self.snpguest_binary, "fetch", "ca", "pem", self.amd_certs_dir, self.cpu_model]
        self._run_with_retry(cmd, "fetch_ca_certs")
        missing = [path for path in (ark_path, ask_path) if not os.path.exists(path)]
        if missing:
            raise RuntimeError(f"AMD CA certs not generated at expected paths: {missing}")

    def _ensure_amd_vcek(self, vcek_cache_key, report_bin_file, timeout=60):
        """Place the VCEK for ``vcek_cache_key`` at ``amd_certs_dir/vcek.pem``.

        Fetched VCEKs are cached in ``amd_certs_dir`` under ``vcek_cache_key`` so the
        fetch runs only once per key. Work on a cache entry is serialized with a
        file lock next to it.

        Args:
            vcek_cache_key: Cache file name, as returned by ``_parse_report``.
            report_bin_file: Path of the binary report passed to ``snpguest fetch vcek``.
            timeout: Seconds to wait for the cache-entry lock.

        Raises:
            RuntimeError: If fetching fails or produces no VCEK file.
        """
        cache_path = os.path.join(self.amd_certs_dir, vcek_cache_key)
        vcek_file = os.path.join(self.amd_certs_dir, AMD_VCEK)
        lock_file = cache_path + ".lock"
        with FileLock(lock_file, timeout=timeout):
            if os.path.exists(cache_path):
                self.logger.info("Using cached AMD VCEK")
            else:
                self.logger.info("AMD VCEK not cached. Fetching and caching...")
                cmd = [self.snpguest_binary, "fetch", "vcek", "pem", self.amd_certs_dir, report_bin_file]
                self._run_with_retry(cmd, "fetch_vcek")
                if not os.path.exists(vcek_file):
                    raise RuntimeError(f"VCEK file not generated at expected path: {vcek_file}")
                # Store the fetched vcek.pem under the cache file name.
                shutil.move(vcek_file, cache_path)
            # In both paths, put the cached VCEK where `snpguest verify attestation` looks.
            shutil.copy(cache_path, vcek_file)

    def _display_report(self, report_bin_file):
        """Return the decoded output of ``snpguest display report``, or None if it exits non-zero.

        Raises:
            subprocess.TimeoutExpired: If the command exceeds ``cmd_timeout``.
        """
        cmd = [self.snpguest_binary, "display", "report", report_bin_file]
        cp = subprocess.run(cmd, capture_output=True, timeout=self.cmd_timeout)
        if cp.returncode != 0:
            return None
        return cp.stdout.decode("utf-8")

    def _parse_report(self, report_bin_file):
        """Build the VCEK cache key from the report's Chip ID and Reported TCB.

        The key identifies which VCEK verifies the report, so a VCEK can be reused
        from the cache instead of being fetched again (AMD KDS is rate-limited).

        Returns:
            str: ``"<chip_id>-<microcode>-<snp>-<tee>-<boot_loader>"``.

        Raises:
            RuntimeError: If the report cannot be displayed or either field cannot be parsed.
        """
        report_text = self._display_report(report_bin_file)
        if report_text is None:
            self.logger.error("Can't display SNP report")
            raise RuntimeError("Can't display SNP report")
        chip_id = parse_chip_id(report_text)
        if not chip_id:
            # An empty chip ID would make reports from different chips share a cache
            # key and be paired with the wrong VCEK.
            raise RuntimeError("Failed to parse Chip ID from report")
        reported_tcb = parse_reported_tcb(report_text)
        if not reported_tcb:
            raise RuntimeError("Failed to parse Reported TCB from report")
        return (
            f"{chip_id}-{reported_tcb['Microcode']}-{reported_tcb['SNP']}"
            f"-{reported_tcb['TEE']}-{reported_tcb['Boot Loader']}"
        )

    @staticmethod
    def _extract_report_data(report_text):
        """Return the hex bytes listed under "Report Data:", concatenated without whitespace.

        Collects the consecutive hex-byte rows following the label and stops at the
        first other line (e.g. a blank line). Returns "" when the label is absent.
        """
        lines = report_text.splitlines()
        for index, line in enumerate(lines):
            if line.strip() != "Report Data:":
                continue
            hex_rows = []
            for row in lines[index + 1 :]:
                row = row.strip()
                if not re.fullmatch(r"(?:[0-9A-Fa-f]{2}\s*)+", row):
                    break
                hex_rows.append("".join(row.split()))
            return "".join(hex_rows)
        return ""

    def _check_nonce(self, report_bin_file):
        """Check the report's Report Data (its nonce) against ``seen_nonce_history``.

        Returns:
            bool: False if the report cannot be displayed or has no Report Data;
            otherwise the result of ``seen_nonce_history.add`` (presumably False when
            the value was already recorded - see NonceHistory).
        """
        report_text = self._display_report(report_bin_file)
        if report_text is None:
            return False
        report_data = self._extract_report_data(report_text)
        if not report_data:
            # Without this guard an unparsable report would be accepted once as nonce "".
            self.logger.warning("Report Data not found in SNP report")
            return False
        return self.seen_nonce_history.add(report_data)

    def get_namespace(self) -> str:
        """Return this authorizer's namespace (``SNP_NAMESPACE``)."""
        return SNP_NAMESPACE