Source code for nvflare.app_opt.confidential_computing.aci_authorizer

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time

import jwt
import requests

from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer

ACI_NAMESPACE = "x-az-aci"


class ACIAuthorizer(CCAuthorizer):
    """Confidential-computing authorizer backed by an Azure attestation (MAA) endpoint.

    Tokens are requested from an attestation service listening on
    ``http://localhost:8284`` and verified against the signing keys published at
    ``https://<maa_endpoint>/certs``.
    """

    # Only asymmetric JWS algorithms are accepted. verify() reads ``alg`` from the
    # token header *before* the signature is checked, so that value must never be
    # trusted blindly -- a forged token could otherwise ask for "none" or an HMAC
    # algorithm (algorithm-confusion attack).
    _ALLOWED_JWT_ALGORITHMS = (
        "RS256",
        "RS384",
        "RS512",
        "PS256",
        "PS384",
        "PS512",
        "ES256",
        "ES384",
        "ES512",
    )

    def __init__(
        self,
        maa_endpoint="sharedeus2.eus2.attest.azure.net",
        retry_count=5,
        retry_sleep=2,
        request_timeout=60,
    ):
        """Create the authorizer.

        Args:
            maa_endpoint: host name of the attestation endpoint. It is sent to the
                local attestation service by ``generate`` and used by ``verify`` to
                build the JWKS URL ``https://<maa_endpoint>/certs``.
            retry_count: number of additional attempts ``generate`` makes after
                the first attempt fails.
            retry_sleep: seconds to sleep between attempts in ``generate``.
            request_timeout: seconds to wait for a single response from the local
                attestation service; ``None`` waits indefinitely.
        """
        self.maa_endpoint = maa_endpoint
        self.retry_count = retry_count
        self.retry_sleep = retry_sleep
        self.request_timeout = request_timeout

    def generate(self):
        """Request an attestation token from the local attestation service.

        A failed attempt is any of: a connection error, a timeout, a non-OK HTTP
        status, or an unparsable response body. Failed attempts are retried until
        ``retry_count`` retries have been made, sleeping ``retry_sleep`` seconds
        between attempts.

        Returns:
            The value of the ``"token"`` field of the first successful response.
            This is ``None`` if that field is absent. Returns ``""`` if every
            attempt failed.
        """
        token = ""
        count = 0
        while True:
            count += 1
            try:
                r = requests.post(
                    "http://localhost:8284/attest/maa",
                    # "ewp9" is the base64 encoding of "{\n}", i.e. an empty JSON object.
                    data=json.dumps({"maa_endpoint": self.maa_endpoint, "runtime_data": "ewp9"}),
                    headers={"Content-Type": "application/json"},
                    timeout=self.request_timeout,
                )
                if r.status_code == requests.codes.ok:
                    token = r.json().get("token")
                    break
            except Exception:
                # Best effort: network errors and malformed responses are retried
                # below instead of propagating to the caller.
                pass
            # Reached for *every* failed attempt (including non-OK status codes),
            # so the loop is always bounded and always backs off between attempts.
            if count > self.retry_count:
                break
            time.sleep(self.retry_sleep)
        return token

    def verify(self, token):
        """Check a token's signature against the keys published by ``maa_endpoint``.

        Args:
            token: the encoded JWT to verify.

        Returns:
            True if the token carries an allowed asymmetric algorithm, decodes and
            verifies successfully, and has a non-empty claim set. Returns False
            otherwise, including on any error raised while fetching keys or
            decoding.
        """
        try:
            header = jwt.get_unverified_header(token)
            alg = header.get("alg")
            if alg not in self._ALLOWED_JWT_ALGORITHMS:
                return False
            jwks_client = jwt.PyJWKClient(f"https://{self.maa_endpoint}/certs")
            signing_key = jwks_client.get_signing_key_from_jwt(token)
            claims = jwt.decode(token, signing_key.key, algorithms=[alg])
        except Exception:
            # Any failure (bad format, unknown key id, bad signature, expired
            # token, network error) means the token is not accepted.
            return False
        return bool(claims)

    def get_namespace(self) -> str:
        """Return the namespace identifying tokens produced by this authorizer."""
        return ACI_NAMESPACE