Source code for nvflare.security.study_registry

# Copyright (c) 2026, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
from copy import deepcopy
from typing import Dict, Optional


class StudyRegistry:
    """Validated, read-only view of a study registry configuration.

    The constructor validates ``studies_config`` and normalizes every study
    definition into three lookup tables (admins, org -> sites, and the flat
    set of sites) plus a serializable per-study summary dict.

    Every getter returns a copy, so callers cannot mutate the registry's
    internal state through a returned value.
    """

    # The only configuration format this class accepts.
    FORMAT_VERSION = "1.0"

    def __init__(self, studies_config: dict):
        """Validate and index a study registry configuration.

        Args:
            studies_config: dict with a ``"format_version"`` entry equal to
                ``FORMAT_VERSION`` and a ``"studies"`` dict mapping each study
                name to its definition. A definition may be falsy (treated as
                empty) or a dict with optional ``"admins"`` (list of str) and
                ``"site_orgs"`` (dict mapping an org to a list of site names).

        Raises:
            ValueError: if any part of the configuration has the wrong type,
                the format version is missing or unsupported, or a site is
                listed more than once within one study.
        """
        if not isinstance(studies_config, dict):
            raise ValueError(f"studies_config must be dict but got {type(studies_config)}")

        format_version = studies_config.get("format_version")
        if format_version != self.FORMAT_VERSION:
            raise ValueError(f"missing or invalid study registry format_version: must be {self.FORMAT_VERSION}")

        studies = studies_config.get("studies")
        if not isinstance(studies, dict):
            raise ValueError(f"study registry 'studies' must be dict but got {type(studies)}")

        self._admins: Dict[str, set] = {}
        self._site_orgs: Dict[str, dict] = {}
        self._sites: Dict[str, set] = {}
        self._studies: Dict[str, dict] = {}

        for study_name, study_def in studies.items():
            # A falsy definition (None, {}, ...) means "study with no admins or sites".
            study_def = study_def or {}
            if not isinstance(study_def, dict):
                # Without this check a list/str definition would surface as an
                # AttributeError from .get() instead of a configuration error.
                raise ValueError(f"study '{study_name}' definition must be dict but got {type(study_def)}")

            admins = study_def.get("admins", [])
            if admins is None:
                admins = []
            if not isinstance(admins, list):
                raise ValueError(f"study '{study_name}' admins must be list but got {type(admins)}")

            site_orgs = study_def.get("site_orgs", {})
            if site_orgs is None:
                site_orgs = {}
            if not isinstance(site_orgs, dict):
                raise ValueError(f"study '{study_name}' site_orgs must be dict but got {type(site_orgs)}")

            # Container types are checked for both fields before any entry is
            # inspected, so the order in which errors are reported is stable.
            admin_list = self._normalize_admins(study_name, admins)
            normalized_site_orgs, sites = self._normalize_site_orgs(study_name, site_orgs)

            self._admins[study_name] = set(admin_list)
            self._site_orgs[study_name] = normalized_site_orgs
            self._sites[study_name] = sites
            self._studies[study_name] = {
                "site_orgs": deepcopy(normalized_site_orgs),
                "sites": sorted(sites),
                "admins": list(admin_list),
            }

    @staticmethod
    def _normalize_admins(study_name, admins: list) -> list:
        """Return ``admins`` with duplicates dropped, keeping first-seen order."""
        for admin in admins:
            if not isinstance(admin, str):
                raise ValueError(f"study '{study_name}' admin entries must be str but got {type(admin)}")
        # dict preserves insertion order, so this de-duplicates without reordering.
        return list(dict.fromkeys(admins))

    @staticmethod
    def _normalize_site_orgs(study_name, site_orgs: dict):
        """Validate the org -> sites mapping of one study.

        Returns:
            A ``(normalized_site_orgs, sites)`` tuple: a fresh dict mapping each
            org to a fresh list of its sites, and the set of all sites.
        """
        normalized_site_orgs = {}
        sites = set()
        for org, org_sites in site_orgs.items():
            if not isinstance(org_sites, list):
                raise ValueError(f"study '{study_name}' site_orgs[{org}] must be list but got {type(org_sites)}")

            normalized_sites = []
            for site in org_sites:
                if not isinstance(site, str):
                    raise ValueError(
                        f"study '{study_name}' site entry for org '{org}' must be str but got {type(site)}"
                    )
                # A site may appear only once per study, whether the repeat is
                # in another org's list or in the same list.
                if site in sites:
                    raise ValueError(f"study '{study_name}' contains duplicate site '{site}' across org groups")
                sites.add(site)
                normalized_sites.append(site)
            normalized_site_orgs[org] = normalized_sites
        return normalized_site_orgs, sites

    def has_user(self, user_name: str, study: str) -> bool:
        """Return True if ``user_name`` is listed as an admin of ``study``."""
        return user_name in self._admins.get(study, set())

    def get_sites(self, study: str) -> Optional[set]:
        """Return a copy of the set of sites in ``study``, or None if the study is unknown."""
        sites = self._sites.get(study)
        # Copy so that callers cannot alter the registry, matching the other getters.
        return set(sites) if sites is not None else None

    def has_study(self, study: str) -> bool:
        """Return True if ``study`` is defined in the registry."""
        return study in self._studies

    def has_org(self, study: str, org: str) -> bool:
        """Return True if ``org`` has an entry in the ``site_orgs`` of ``study``."""
        return org in self._site_orgs.get(study, {})

    def get_site_orgs(self, study: str) -> Optional[dict]:
        """Return a deep copy of the org -> sites mapping of ``study``, or None if unknown."""
        site_orgs = self._site_orgs.get(study)
        return deepcopy(site_orgs) if site_orgs is not None else None

    def get_studies(self) -> Dict[str, dict]:
        """Return a deep copy of all study summaries, keyed by study name.

        Each summary has the keys ``"site_orgs"``, ``"sites"`` (sorted list)
        and ``"admins"`` (de-duplicated list in configuration order).
        """
        return deepcopy(self._studies)

    def get_study(self, study: str) -> Optional[dict]:
        """Return a deep copy of the summary of ``study``, or None if unknown."""
        study_def = self._studies.get(study)
        return deepcopy(study_def) if study_def is not None else None
class StudyRegistryService:
    """Class-level holder for one shared ``StudyRegistry`` and one shared lock.

    All state lives on the class itself and every method is static, so no
    instance is needed.
    """

    # The currently installed registry; None until initialize() is called
    # or after reset().
    _registry: Optional[StudyRegistry] = None
    # One lock shared by all callers. Judging by its name it is meant to
    # serialize registry mutations - NOTE(review): confirm against callers.
    _mutation_lock = threading.Lock()

    @staticmethod
    def initialize(registry: Optional[StudyRegistry]):
        """Install ``registry`` as the shared registry (None clears it)."""
        StudyRegistryService._registry = registry

    @staticmethod
    def get_registry() -> Optional[StudyRegistry]:
        """Return the shared registry, or None if none is installed."""
        current = StudyRegistryService._registry
        return current

    @staticmethod
    def reset():
        """Drop the shared registry; the lock is left untouched."""
        StudyRegistryService._registry = None

    @staticmethod
    def acquire_lock(timeout: float) -> bool:
        """Try to take the shared lock, waiting up to ``timeout`` seconds.

        Returns:
            True if the lock was acquired, False if the wait timed out.
        """
        lock = StudyRegistryService._mutation_lock
        acquired = lock.acquire(timeout=timeout)
        return acquired

    @staticmethod
    def release_lock():
        """Release the shared lock.

        Raises:
            RuntimeError: from ``threading.Lock.release`` if the lock is not held.
        """
        lock = StudyRegistryService._mutation_lock
        lock.release()