# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os.path
from typing import Optional
from pydantic import BaseModel, PositiveFloat, model_validator
from nvflare.apis.job_def import DEFAULT_STUDY
from nvflare.apis.utils.format_check import name_check
from nvflare.job_config.api import FedJob
from nvflare.recipe.spec import ExecEnv
from nvflare.recipe.utils import collect_non_local_scripts
from .session_mgr import SessionManager
logger = logging.getLogger(__name__)
DEFAULT_ADMIN_USER = "admin@nvidia.com"
# Internal — not part of the public API
class _ProdEnvValidator(BaseModel):
    """Pydantic model used to validate the constructor arguments of ``ProdEnv``.

    Field constraints:
        startup_kit_location: must be a string naming an existing filesystem path.
        login_timeout: must be a strictly positive float (``PositiveFloat``).
        username: string, defaults to ``DEFAULT_ADMIN_USER``.
        study: string, defaults to ``DEFAULT_STUDY``; must pass ``name_check``.
    """

    startup_kit_location: str
    login_timeout: PositiveFloat = 5.0
    username: str = DEFAULT_ADMIN_USER
    study: str = DEFAULT_STUDY

    @model_validator(mode="after")
    def check_startup_kit_location_exists(self) -> "_ProdEnvValidator":
        """Run the cross-field checks once the individual fields have been parsed.

        Returns:
            The validated model instance itself.

        Raises:
            ValueError: if ``startup_kit_location`` does not exist on disk, or if
                ``name_check`` flags the study name as invalid.
        """
        kit_path = self.startup_kit_location
        if not os.path.exists(kit_path):
            raise ValueError(f"startup_kit_location path does not exist: {kit_path}")

        study_name = self.study
        # The first element of name_check's result is treated as the "invalid" flag.
        study_is_invalid = name_check(study_name, "study")[0]
        if study_is_invalid:
            raise ValueError(
                f"study name '{study_name}' contains unsupported characters. "
                "Use only lowercase letters, numbers, underscores, and hyphens."
            )

        return self
class ProdEnv(ExecEnv):
    """Production execution environment for submitting and monitoring NVFlare jobs.

    Every job operation is delegated to a ``SessionManager`` that is created on first
    use from the validated constructor arguments.
    """

    def __init__(
        self,
        startup_kit_location: str,
        login_timeout: float = 5.0,
        username: str = DEFAULT_ADMIN_USER,
        study: str = DEFAULT_STUDY,
        extra: Optional[dict] = None,
    ):
        """Validate and store the connection settings for a production deployment.

        This environment uses the startup kit of an NVFlare deployment to submit jobs
        via the Flare API. No session is opened here; see ``_get_session_manager``.

        Args:
            startup_kit_location (str): Path to the admin's startup kit directory. The path must exist.
            login_timeout (float): Timeout (in seconds) for logging into the Flare API session. Must be > 0.
            username (str): Username to log in with.
            study (str): Study name to tag submitted jobs. Defaults to ``DEFAULT_STUDY``.
            extra: extra env info, passed through to ``ExecEnv.__init__``.

        Raises:
            Whatever pydantic raises when ``_ProdEnvValidator`` rejects the arguments
            (non-existent path, non-positive timeout, or invalid study name).
        """
        super().__init__(extra)

        # Route the raw arguments through the pydantic model so that invalid
        # settings are rejected before any attribute is stored.
        validated = _ProdEnvValidator(
            startup_kit_location=startup_kit_location,
            login_timeout=login_timeout,
            username=username,
            study=study,
        )

        self.startup_kit_location = validated.startup_kit_location
        self.login_timeout = validated.login_timeout
        self.username = validated.username
        self.study = validated.study

        # Created lazily by _get_session_manager().
        self._session_manager = None

    def get_job_status(self, job_id: str) -> Optional[str]:
        """Return the session manager's status for ``job_id``."""
        manager = self._get_session_manager()
        return manager.get_job_status(job_id)

    def abort_job(self, job_id: str) -> None:
        """Ask the session manager to abort ``job_id``."""
        manager = self._get_session_manager()
        manager.abort_job(job_id)

    def get_job_result(self, job_id: str, timeout: float = 0.0) -> Optional[str]:
        """Return the session manager's result for ``job_id``.

        ``timeout`` is forwarded unchanged to ``SessionManager.get_job_result``.
        """
        manager = self._get_session_manager()
        return manager.get_job_result(job_id, timeout)

    def deploy(self, job: FedJob) -> str:
        """Submit ``job`` through the session manager and return its result.

        Scripts that ``collect_non_local_scripts`` reports as missing locally only
        produce a warning; they are assumed to be pre-installed on the production
        system.

        Raises:
            RuntimeError: if creating the session manager or submitting the job fails;
                the original exception is chained as the cause.
        """
        for missing_script in collect_non_local_scripts(job):
            logger.warning(
                f"Script '{missing_script}' not found locally. Assuming it is pre-installed on the production system."
            )

        try:
            manager = self._get_session_manager()
            return manager.submit_job(job)
        except Exception as e:
            raise RuntimeError(f"Failed to submit job via Flare API: {e}") from e

    def _get_session_manager(self):
        """Return the cached ``SessionManager``, creating it on first call."""
        if self._session_manager is not None:
            return self._session_manager

        self._session_manager = SessionManager(
            {
                "username": self.username,
                "startup_kit_location": self.startup_kit_location,
                "timeout": self.login_timeout,
                "study": self.study,
            }
        )
        return self._session_manager