# Source code for nvflare.recipe.poc_env

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import threading
import time
from typing import Optional

from pydantic import BaseModel, conint, model_validator

from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.job_config.api import FedJob
from nvflare.recipe.spec import ExecEnv
from nvflare.recipe.utils import collect_non_local_scripts
from nvflare.tool.poc.poc_commands import (
    _clean_poc,
    _start_poc,
    _stop_poc,
    get_poc_workspace,
    get_prod_dir,
    is_poc_running,
    prepare_poc_provision,
    setup_service_config,
)
from nvflare.tool.poc.service_constants import FlareServiceConstants as SC

from .session_mgr import SessionManager

# Maximum number of one-second polls PocEnv.stop() performs while waiting for
# the POC services to report that they are no longer running.
STOP_POC_TIMEOUT = 10
# Seconds PocEnv.deploy() sleeps after starting the POC services before it
# submits the job (a fixed delay, not a readiness check).
SERVICE_START_TIMEOUT = 3
# Default admin username used by both the validator and PocEnv.
DEFAULT_ADMIN_USER = "admin@nvidia.com"


# Internal — not part of the public API
class _PocEnvValidator(BaseModel):
    """Validate the constructor arguments of the POC execution environment.

    Field-level constraints are declared on the fields themselves (for example
    ``num_clients`` must be a positive integer when given); the cross-field
    rules between ``num_clients`` and ``clients`` are enforced by
    :meth:`check_client_configuration`.
    """

    num_clients: Optional[conint(gt=0)] = None
    clients: Optional[list[str]] = None
    gpu_ids: Optional[list[int]] = None
    use_he: bool = False
    docker_image: Optional[str] = None
    project_conf_path: str = ""
    username: str = DEFAULT_ADMIN_USER

    @model_validator(mode="after")
    def check_client_configuration(self):
        """Check that ``num_clients`` and ``clients`` describe a usable client set.

        Returns:
            The validated model instance (required by pydantic "after" validators).

        Raises:
            ValueError: If ``clients`` is an empty list, if both ``clients`` and
                ``num_clients`` are given but disagree on the number of clients,
                or if neither yields a positive number of clients.
        """
        if self.clients is not None:
            if not self.clients:
                raise ValueError("clients list cannot be empty")

            # num_clients is optional: only compare when it was actually supplied.
            # (Comparing None with ``> 0`` would raise TypeError instead of
            # accepting an explicit client list on its own.)
            if self.num_clients is not None and len(self.clients) != self.num_clients:
                raise ValueError(
                    f"Inconsistent: num_clients={self.num_clients} but clients list has {len(self.clients)} entries"
                )
        elif self.num_clients is None or self.num_clients <= 0:
            # Without an explicit client list, the count is the only source of truth.
            raise ValueError("num_clients must be greater than 0")

        return self


class PocEnv(ExecEnv):
    """Proof of Concept execution environment for local testing and development.

    This environment sets up a POC deployment on a single machine with multiple
    processes representing the server, clients, and admin console.
    """

    def __init__(
        self,
        *,
        num_clients: Optional[int] = 2,
        clients: Optional[list[str]] = None,
        gpu_ids: Optional[list[int]] = None,
        use_he: bool = False,
        docker_image: Optional[str] = None,
        project_conf_path: str = "",
        username: str = DEFAULT_ADMIN_USER,
        extra: Optional[dict] = None,
    ):
        """Initialize POC execution environment.

        Args:
            num_clients (int, optional): Number of clients to use in POC mode. Defaults to 2.
            clients (list[str], optional): List of client names. If None, will generate
                site-1, site-2, etc. Defaults to None. The list must not be empty, and when
                ``num_clients`` is also given its value must equal ``len(clients)``;
                otherwise validation fails. Because ``num_clients`` defaults to 2, pass a
                matching ``num_clients`` for client lists of any other size.
            gpu_ids (list[int], optional): List of GPU IDs to assign to clients. If None,
                uses CPU only. Defaults to None.
            use_he (bool, optional): Whether to use HE. Defaults to False.
            docker_image (str, optional): Docker image to use for POC. Defaults to None.
            project_conf_path (str, optional): Path to the project configuration file.
                Defaults to "". If specified, the client count, 'clients' and docker
                specific options are expected to be ignored by provisioning.
            username (str, optional): Admin user. Defaults to "admin@nvidia.com".
            extra: extra env info.
        """
        super().__init__(extra)
        self.logger = get_obj_logger(self)

        v = _PocEnvValidator(
            num_clients=num_clients,
            clients=clients,
            gpu_ids=gpu_ids,
            use_he=use_he,
            docker_image=docker_image,
            project_conf_path=project_conf_path,
            username=username,
        )

        self.clients = v.clients
        # An explicit client list determines the client count.
        self.num_clients = len(v.clients) if v.clients is not None else v.num_clients
        self.poc_workspace = get_poc_workspace()
        self.gpu_ids = v.gpu_ids or []
        self.use_he = v.use_he
        self.project_conf_path = v.project_conf_path
        self.docker_image = v.docker_image
        self.username = v.username
        self._session_manager: Optional[SessionManager] = None  # Lazy initialization
        self._session_manager_lock = threading.Lock()

    def deploy(self, job: FedJob) -> str:
        """Deploy a FedJob to the POC environment.

        Any POC that is currently running is stopped and cleaned first, then a
        fresh POC is provisioned and started before the job is submitted.

        Args:
            job (FedJob): The FedJob to deploy.

        Returns:
            str: Job ID.

        Raises:
            ValueError: If scripts do not exist locally.
        """
        # Validate scripts exist locally for POC
        non_local_scripts = collect_non_local_scripts(job)
        if non_local_scripts:
            raise ValueError(
                f"The following scripts do not exist locally: {non_local_scripts}. "
                f"For PocEnv, all scripts must be present on the local machine."
            )

        if self._check_poc_running():
            self.stop(clean_up=True)

        self.logger.info("Preparing and starting fresh POC services...")
        prepare_poc_provision(
            clients=self.clients or [],  # Empty list if None, let prepare_clients generate
            number_of_clients=self.num_clients,
            workspace=self.poc_workspace,
            docker_image=self.docker_image,
            use_he=self.use_he,
            project_conf_path=self.project_conf_path,
            examples_dir=None,
        )

        _start_poc(
            poc_workspace=self.poc_workspace,
            gpu_ids=self.gpu_ids,
            excluded=[self.username],
            services_list=[],
        )
        self.logger.info("POC services started successfully")

        # Give services time to start up
        time.sleep(SERVICE_START_TIMEOUT)

        # Submit job using SessionManager
        return self._get_session_manager().submit_job(job)

    def _check_poc_running(self) -> bool:
        """Check if POC services are currently running.

        Returns:
            bool: True if POC is running, False otherwise (including when the
            service configuration cannot be loaded from the workspace).
        """
        try:
            project_config, service_config = setup_service_config(self.poc_workspace)
        except Exception:
            # POC workspace is not initialized yet, so we don't need to stop and clean it
            return False

        return bool(is_poc_running(self.poc_workspace, service_config, project_config))

    def _wait_for_poc_stop(self, project_config, service_config) -> bool:
        """Poll once per second until the POC stops or STOP_POC_TIMEOUT polls elapse.

        Returns:
            bool: True if the POC is still considered running after the wait,
            False once it is reported stopped. A failing status check is
            treated as "stopped".
        """
        for _ in range(STOP_POC_TIMEOUT):
            try:
                if not is_poc_running(self.poc_workspace, service_config, project_config):
                    return False
            except Exception:
                return False
            time.sleep(1)
        return True

    def stop(self, clean_up: bool = False) -> None:
        """Try to stop and clean existing POC.

        This method is idempotent - safe to call multiple times. Failures while
        stopping are logged rather than raised, and the cached session manager
        is always discarded.

        Args:
            clean_up (bool, optional): Whether to clean the POC workspace. Defaults to False.
        """
        # Check if already stopped (idempotent)
        if not self._check_poc_running():
            # POC already stopped or workspace doesn't exist
            if clean_up and os.path.exists(self.poc_workspace):
                self.logger.info(f"Removing POC workspace: {self.poc_workspace}")
                shutil.rmtree(self.poc_workspace, ignore_errors=True)
            self._session_manager = None  # Clear stale session manager
            return

        try:
            project_config, service_config = setup_service_config(self.poc_workspace)

            self.logger.info("Stopping existing POC services...")
            _stop_poc(
                poc_workspace=self.poc_workspace,
                excluded=[self.username],  # Exclude admin console (consistent with start)
                services_list=[],
            )

            poc_running = self._wait_for_poc_stop(project_config, service_config)

            if clean_up:
                if poc_running:
                    self.logger.warning(
                        f"POC still running after {STOP_POC_TIMEOUT} seconds, cannot clean workspace. Skipping cleanup."
                    )
                else:
                    try:
                        _clean_poc(self.poc_workspace)
                    except Exception as e:
                        self.logger.debug(f"Failed to clean POC: {e}")
        except Exception as e:
            self.logger.warning(f"Failed to stop and clean existing POC: {e}")
        finally:
            self._session_manager = None  # Clear stale session manager

    def get_job_status(self, job_id: str) -> Optional[str]:
        """Get the status of a job.

        Args:
            job_id: The job ID to check status for.

        Returns:
            Optional[str]: The status of the job, or None if not available.
        """
        return self._get_session_manager().get_job_status(job_id)

    def abort_job(self, job_id: str) -> None:
        """Abort a running job.

        Args:
            job_id: The job ID to abort.
        """
        self._get_session_manager().abort_job(job_id)

    def get_job_result(self, job_id: str, timeout: float = 0.0) -> Optional[str]:
        """Get the result workspace of a job.

        Args:
            job_id: The job ID to get results for.
            timeout: The timeout for the job to complete. Defaults to 0.0 (no timeout).

        Returns:
            Optional[str]: The result workspace path if job completed, None otherwise.
        """
        return self._get_session_manager().get_job_result(job_id, timeout)

    def _get_admin_startup_kit_path(self) -> str:
        """Get the path to the admin startup kit for POC.

        Returns:
            str: Path to admin startup kit directory.

        Raises:
            RuntimeError: If the service configuration cannot be loaded or the
                admin startup kit directory does not exist.
        """
        try:
            project_config, service_config = setup_service_config(self.poc_workspace)
            project_name = project_config.get("name")
            prod_dir = get_prod_dir(self.poc_workspace, project_name)

            # POC admin directory structure: {workspace}/{project_name}/prod_00/admin@nvidia.com
            project_admin_dir = service_config.get(SC.FLARE_PROJ_ADMIN, SC.FLARE_PROJ_ADMIN)
            admin_dir = os.path.join(prod_dir, project_admin_dir)

            if not os.path.exists(admin_dir):
                raise RuntimeError(f"Admin startup kit not found at: {admin_dir}")

            return admin_dir
        except Exception as e:
            raise RuntimeError(f"Failed to locate admin startup kit: {e}") from e

    def _get_session_manager(self) -> SessionManager:
        """Get or create SessionManager with lazy initialization (thread-safe)."""
        with self._session_manager_lock:
            if self._session_manager is None:
                session_params = {
                    "username": self.username,
                    "startup_kit_location": self._get_admin_startup_kit_path(),
                    "timeout": self.get_extra_prop("login_timeout", 10),
                }
                self._session_manager = SessionManager(session_params)
            return self._session_manager