# Source code for nvflare.app_opt.job_launcher.workspace_cell_transfer

# Copyright (c) 2026, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CellNet-based workspace transfer for launched jobs.

The parent process exposes a small transfer service on its existing CellNet cell.
Launched job pods create short-lived bootstrap child cells to:

1. request a workspace bundle from the parent
2. upload final job results back to the parent

The actual payload transfer uses the existing F3 file downloader infrastructure,
so large bundles move in chunks instead of being buffered into a single message.
"""

from __future__ import annotations

import hashlib
import logging
import os
import secrets
import shutil
import stat
import tempfile
import threading
import time
import zipfile
from dataclasses import dataclass
from pathlib import PurePosixPath

from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.f3.cellnet.net_agent import NetAgent
from nvflare.fuel.f3.cellnet.utils import make_reply, new_cell_message
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.streaming.download_service import DownloadService
from nvflare.fuel.f3.streaming.file_downloader import add_file, download_file
from nvflare.fuel.f3.streaming.obj_downloader import ObjectDownloader
from nvflare.fuel.sec.authn import set_add_auth_headers_filters
from nvflare.private.defs import AUTH_CLIENT_NAME_FOR_SJ
from nvflare.security.logging import secure_format_exception

logger = logging.getLogger(__name__)

# Environment variables read by download_workspace()/upload_results() in the launched job:
# the FQCN of the parent cell that owns the workspace, and the per-job transfer token
# (presumably the value returned by WorkspaceTransferManager.add_job() — confirm with the launcher).
ENV_WORKSPACE_OWNER_FQCN = "NVFL_WORKSPACE_OWNER_FQCN"
ENV_WORKSPACE_TRANSFER_TOKEN = "NVFL_WORKSPACE_TRANSFER_TOKEN"

# CellNet channel and topics served by WorkspaceTransferManager on the owner's cell.
WORKSPACE_TRANSFER_CHANNEL = "workspace_transfer"
TOPIC_PREPARE_DOWNLOAD = "prepare_download"
TOPIC_PUBLISH_RESULTS = "publish_results"

# Used as the ObjectDownloader transaction timeout and as the publish_results request timeout.
DOWNLOAD_TIMEOUT = 600.0
# Passed as download_file(per_request_timeout=...) and used as the prepare_download request timeout.
PER_REQUEST_TIMEOUT = 300.0
# _wait_for_bootstrap_ready() gives the bootstrap cell this many seconds to connect to its parent,
# polling every BOOTSTRAP_CONNECT_POLL_INTERVAL seconds.
BOOTSTRAP_CONNECT_TIMEOUT = 30.0
BOOTSTRAP_CONNECT_POLL_INTERVAL = 0.1

# Name prefix of the bootstrap child cell; see make_workspace_transfer_fqcn().
_BOOTSTRAP_CELL_PREFIX = "ws_transfer_"
# Workspace-relative (POSIX-style) paths that are never put into the workspace download bundle.
_WORKSPACE_DOWNLOAD_EXCLUDES = frozenset({"local/study_data.yaml"})


@dataclass
class _JobTransferRecord:
    """Per-job transfer state kept by WorkspaceTransferManager, keyed by job_id."""

    job_id: str
    # Directory the workspace bundle is zipped from (``local/`` and ``<job_id>/``)
    # and the results bundle is extracted into.
    workspace_root: str
    # Secret every transfer request for this job must carry.
    transfer_token: str
    # Transaction id and temporary zip path of the currently prepared workspace
    # download; both are "" when no download is outstanding.
    download_tx_id: str = ""
    download_bundle_path: str = ""

def _write_dir_to_zip(zf: zipfile.ZipFile, src: str, root: str, excluded_paths: frozenset[str] = frozenset()) -> None:
    if not os.path.isdir(src):
        return
    for dirpath, _dirs, files in os.walk(src):
        for fname in files:
            abs_path = os.path.join(dirpath, fname)
            rel_path = os.path.relpath(abs_path, root).replace(os.sep, "/")
            if rel_path in excluded_paths:
                continue
            zf.write(abs_path, rel_path)


def _zip_workspace_to_file(workspace_root: str, job_id: str, file_path: str) -> None:
    """Create a deflated zip at *file_path* holding the workspace bundle for *job_id*.

    The bundle contains ``local/`` (minus ``_WORKSPACE_DOWNLOAD_EXCLUDES``) and
    ``<job_id>/``, both named relative to *workspace_root*.
    """
    with zipfile.ZipFile(file_path, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
        local_dir = os.path.join(workspace_root, "local")
        _write_dir_to_zip(archive, local_dir, workspace_root, _WORKSPACE_DOWNLOAD_EXCLUDES)
        job_dir = os.path.join(workspace_root, job_id)
        _write_dir_to_zip(archive, job_dir, workspace_root)


def _zip_results_to_file(workspace_root: str, job_id: str, file_path: str) -> None:
    """Create a deflated zip at *file_path* holding only ``<workspace_root>/<job_id>/``."""
    with zipfile.ZipFile(file_path, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
        job_dir = os.path.join(workspace_root, job_id)
        _write_dir_to_zip(archive, job_dir, workspace_root)


def _validate_relative_zip_members(zf: zipfile.ZipFile) -> None:
    for info in zf.infolist():
        path = PurePosixPath(info.filename)
        if path.is_absolute() or ".." in path.parts:
            raise ValueError(f"unsafe zip member: {info.filename}")


def _validate_job_zip_members(zf: zipfile.ZipFile, job_id: str) -> None:
    """Validate a results archive before extraction.

    In addition to the relative-path checks, every member must be a non-symlink
    whose first path component is *job_id*. Raises ValueError on the first
    offending member.
    """
    _validate_relative_zip_members(zf)
    for member in zf.infolist():
        unix_mode = member.external_attr >> 16
        if stat.S_ISLNK(unix_mode):
            raise ValueError(f"symlink not allowed in results archive: {member.filename}")
        components = PurePosixPath(member.filename).parts
        inside_job = bool(components) and components[0] == job_id
        if not inside_job:
            raise ValueError(f"zip member outside job workspace: {member.filename}")


def _hash_file(path: str) -> str:
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            digest.update(chunk)
    return digest.hexdigest()


def make_workspace_transfer_fqcn(owner_fqcn: str, job_id: str) -> str:
    """Return the FQCN of *job_id*'s bootstrap transfer cell, a child of *owner_fqcn*."""
    child_name = f"{_BOOTSTRAP_CELL_PREFIX}{job_id}"
    return FQCN.join([owner_fqcn, child_name])
def _cleanup_files(paths) -> None:
    """Best-effort removal of every non-empty path in *paths*.

    A missing file is ignored. Any other ``OSError`` (e.g. permission denied) is
    logged and skipped rather than raised. This function runs inside ``finally``
    blocks and transfer callbacks, where raising would mask the caller's real
    outcome and skip the remaining cleanup.
    """
    for path in paths:
        if not path:
            continue
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning("failed to remove temporary file %s: %s", path, secure_format_exception(e))


def _cleanup_transfer_files(_tx_id: str, _status: str, _objects: list, temp_paths=None, **_kwargs) -> None:
    """ObjectDownloader transaction-done callback that removes the ``temp_paths`` it was registered with."""
    _cleanup_files(temp_paths or [])


def _cleanup_download(tx_id: str, bundle_path: str) -> None:
    """Delete download transaction *tx_id* (if any) and remove its bundle file (if any)."""
    if tx_id:
        DownloadService.delete_transaction(tx_id)
    if bundle_path:
        _cleanup_files([bundle_path])


def _make_error(message: str, rc: str = ReturnCode.INVALID_REQUEST) -> Message:
    """Log *message* at error level and return a reply with return code *rc* carrying it as the error."""
    logger.error(message)
    return make_reply(rc, error=message)
class WorkspaceTransferManager:
    """Manage per-job workspace transfer over an existing CellNet cell.

    Two request handlers are registered on the owner's cell under
    ``WORKSPACE_TRANSFER_CHANNEL``:

    * ``TOPIC_PREPARE_DOWNLOAD``: zips the job's ``local/`` and ``<job_id>/``
      directories and exposes the zip through an ``ObjectDownloader`` so the
      requester can pull it with ``download_file``.
    * ``TOPIC_PUBLISH_RESULTS``: pulls the requester's results zip, validates it,
      extracts it into the job's workspace root, then forgets the job.

    Every request payload must carry ``job_id`` and the ``transfer_token``
    returned by :meth:`add_job`.
    """

    def __init__(
        self,
        cell: Cell,
        download_timeout: float = DOWNLOAD_TIMEOUT,
        per_request_timeout: float = PER_REQUEST_TIMEOUT,
    ):
        """Bind to *cell* and register the transfer request handlers on it.

        Args:
            cell: the owner's CellNet cell.
            download_timeout: timeout given to each workspace-download ObjectDownloader.
            per_request_timeout: per-request timeout used when pulling a results bundle.
        """
        self.cell = cell
        self.owner_fqcn = cell.get_fqcn()
        self.download_timeout = download_timeout
        self.per_request_timeout = per_request_timeout
        self.jobs: dict[str, _JobTransferRecord] = {}
        self._lock = threading.Lock()
        self.cell.register_request_cb(
            channel=WORKSPACE_TRANSFER_CHANNEL,
            topic=TOPIC_PREPARE_DOWNLOAD,
            cb=self._handle_prepare_download,
        )
        self.cell.register_request_cb(
            channel=WORKSPACE_TRANSFER_CHANNEL,
            topic=TOPIC_PUBLISH_RESULTS,
            cb=self._handle_publish_results,
        )

    @classmethod
    def get_or_create(cls, cell: Cell) -> "WorkspaceTransferManager":
        """Return the per-cell manager, constructing one on first use.

        The manager and the lock guarding its creation are stored in the cell's
        ``__dict__``, so each cell gets at most one manager (and one set of
        handler registrations).
        """
        lock = cell.__dict__.setdefault("_workspace_transfer_lock", threading.Lock())
        with lock:
            manager = cell.__dict__.get("_workspace_transfer_manager")
            if manager is None:
                manager = cls(cell)
                cell.__dict__["_workspace_transfer_manager"] = manager
            return manager

    def add_job(self, job_id: str, workspace_root: str) -> str:
        """Register *job_id* and return a freshly generated transfer token.

        Re-adding a job replaces its record, so the previous token stops working
        and any download prepared for the old record is released.
        """
        record = _JobTransferRecord(
            job_id=job_id,
            workspace_root=workspace_root,
            transfer_token=secrets.token_urlsafe(24),
        )
        with self._lock:
            old = self.jobs.get(job_id)
            self.jobs[job_id] = record
        # Release the superseded download outside self._lock (as remove_job does):
        # deleting the transaction may invoke _download_transaction_done, which
        # takes self._lock, and threading.Lock is not reentrant.
        if old is not None:
            _cleanup_download(old.download_tx_id, old.download_bundle_path)
        return record.transfer_token

    def remove_job(self, job_id: str) -> None:
        """Forget *job_id* and release any workspace download prepared for it."""
        with self._lock:
            record = self.jobs.pop(job_id, None)
        if record is not None:
            _cleanup_download(record.download_tx_id, record.download_bundle_path)

    def _get_record(self, job_id: str) -> _JobTransferRecord | None:
        """Return the record for *job_id*, or None if the job is not registered."""
        with self._lock:
            return self.jobs.get(job_id)

    def _authorize(self, request: Message, action: str) -> tuple[_JobTransferRecord | None, str, Message | None]:
        """Validate request payload and token. Returns (record, origin, error_reply).

        The payload comes from the network, so ``job_id`` and ``transfer_token``
        are type-checked before use. An unhashable ``job_id`` or a non-ASCII or
        non-str token is rejected with an error reply instead of raising inside
        the handler.
        """
        origin = request.get_header(MessageHeaderKey.ORIGIN) or ""
        payload = request.payload
        if not isinstance(payload, dict):
            return None, origin, _make_error(f"{action} payload must be dict")
        job_id = payload.get("job_id")
        if not job_id:
            return None, origin, _make_error(f"{action} missing job_id")
        if not isinstance(job_id, str):
            return None, origin, _make_error(f"{action} job_id must be str")
        transfer_token = payload.get("transfer_token")
        if not transfer_token:
            return None, origin, _make_error(f"{action} missing transfer_token")
        if not isinstance(transfer_token, str):
            return None, origin, _make_error(f"{action} transfer_token must be str")
        record = self._get_record(job_id)
        if not record:
            return None, origin, _make_error(f"unknown job_id for {action}: {job_id}", rc=ReturnCode.INVALID_TARGET)
        # compare_digest() raises TypeError for str arguments containing non-ASCII
        # characters, so compare the encoded bytes instead.
        presented = transfer_token.encode("utf-8", "surrogatepass")
        expected = record.transfer_token.encode("utf-8", "surrogatepass")
        if not secrets.compare_digest(presented, expected):
            return None, origin, _make_error(f"{action} token mismatch for {job_id}", rc=ReturnCode.UNAUTHENTICATED)
        return record, origin, None

    def _handle_prepare_download(self, request: Message) -> Message:
        """Zip the job's workspace and publish it for download by the requester.

        Replies OK with ``job_id``, ``ref_id`` (for ``download_file``), and the
        bundle's ``sha256`` and ``size``. Any download previously prepared for the
        job is released.
        """
        record, _origin, err = self._authorize(request, "prepare_download")
        if err is not None:
            return err
        job_id = record.job_id
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
        tmp.close()
        downloader = None
        try:
            _zip_workspace_to_file(record.workspace_root, record.job_id, tmp.name)
            downloader = ObjectDownloader(
                cell=self.cell,
                timeout=self.download_timeout,
                num_receivers=1,
                transaction_done_cb=self._download_transaction_done,
                job_id=job_id,
            )
            ref_id = add_file(downloader, tmp.name)
            bundle_sha = _hash_file(tmp.name)
            bundle_size = os.path.getsize(tmp.name)
        except Exception as e:
            _cleanup_download(downloader.tx_id if downloader else "", tmp.name)
            return _make_error(f"failed to prepare workspace download for {job_id}: {e}")

        with self._lock:
            current = self.jobs.get(job_id)
            if current is None:
                # The job was removed while zipping: the new transfer is orphaned.
                stale_tx_id, stale_bundle_path = downloader.tx_id, tmp.name
            else:
                stale_tx_id, stale_bundle_path = current.download_tx_id, current.download_bundle_path
                current.download_tx_id = downloader.tx_id
                current.download_bundle_path = tmp.name
        # Release superseded/orphaned transfer state outside self._lock: deleting a
        # transaction may invoke _download_transaction_done, which takes self._lock.
        _cleanup_download(stale_tx_id, stale_bundle_path)
        if current is None:
            return _make_error(f"job removed while preparing workspace download: {job_id}")
        return make_reply(
            ReturnCode.OK,
            body={
                "job_id": job_id,
                "ref_id": ref_id,
                "sha256": bundle_sha,
                "size": bundle_size,
            },
        )

    def _download_transaction_done(self, tx_id: str, _status: str, objects: list, job_id: str) -> None:
        """Transaction-done callback for workspace downloads.

        Removes the transaction's objects via ``_cleanup_files``. If *tx_id* is
        still the job's current download, clears the record's download fields.
        """
        _cleanup_files(objects)
        with self._lock:
            record = self.jobs.get(job_id)
            if not record or record.download_tx_id != tx_id:
                return
            record.download_tx_id = ""
            record.download_bundle_path = ""

    def _handle_publish_results(self, request: Message) -> Message:
        """Pull the requester's results zip and extract it into the job's workspace root.

        The bundle is downloaded from the request's origin cell, optionally
        checksum-verified, and validated by ``_validate_job_zip_members``. On
        success the job is removed from the manager.
        """
        logger.info("[ws-transfer] publish_results handler entered on parent")
        record, origin, err = self._authorize(request, "publish_results")
        if err is not None:
            logger.info("[ws-transfer] publish_results rejected by authorize")
            return err
        job_id = record.job_id
        if not origin:
            return _make_error("publish_results missing request origin")
        payload = request.payload
        ref_id = payload.get("ref_id")
        if not ref_id:
            return _make_error("publish_results missing ref_id")
        logger.info("[ws-transfer] publish_results accepted for job=%s origin=%s ref=%s", job_id, origin, ref_id)
        expected_sha = payload.get("sha256")
        temp_dir = tempfile.mkdtemp(prefix="workspace-results-")
        try:
            err, file_path = download_file(
                from_fqcn=origin,
                ref_id=ref_id,
                per_request_timeout=self.per_request_timeout,
                cell=self.cell,
                location=temp_dir,
            )
            if err:
                logger.info("[ws-transfer] publish_results download_file failed for %s: %s", job_id, err)
                return _make_error(f"failed to download results for {job_id}: {err}", rc=ReturnCode.COMM_ERROR)
            if expected_sha and _hash_file(file_path) != expected_sha:
                return _make_error(f"results checksum mismatch for {job_id}")
            os.makedirs(record.workspace_root, exist_ok=True)
            with zipfile.ZipFile(file_path) as zf:
                _validate_job_zip_members(zf, record.job_id)
                zf.extractall(record.workspace_root)
            logger.info("[ws-transfer] publish_results extracted job=%s into %s", job_id, record.workspace_root)
        except ValueError as e:
            return _make_error(str(e))
        except zipfile.BadZipFile as e:
            return _make_error(f"invalid results bundle for {job_id}: {e}")
        except Exception as e:
            return _make_error(f"unexpected error processing results for {job_id}: {e}")
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)
        self.remove_job(job_id)
        return make_reply(ReturnCode.OK, body={"job_id": job_id})
# Process-level singleton. download_workspace runs at startup and # upload_results runs at shutdown inside the SAME process, so they share # one bootstrap cell for the life of the job. Creating a second cell with # the same FQCN crashes CellNet's registry ("there is already a cell # named ..."), and nothing in CellNet unregisters the name reliably after # cell.stop(), so keeping one alive is the simplest contract. _bootstrap_cell: Cell | None = None _bootstrap_net_agent: NetAgent | None = None _bootstrap_lock = threading.Lock() def _get_bootstrap_cell(args, owner_fqcn: str, secure_mode: bool) -> Cell: global _bootstrap_cell, _bootstrap_net_agent with _bootstrap_lock: if _bootstrap_cell is None: _bootstrap_cell, _bootstrap_net_agent = _create_bootstrap_cell( args=args, owner_fqcn=owner_fqcn, secure_mode=secure_mode ) return _bootstrap_cell def _close_bootstrap_cell() -> None: global _bootstrap_cell, _bootstrap_net_agent with _bootstrap_lock: if _bootstrap_net_agent is not None: try: _bootstrap_net_agent.close() except Exception: pass _bootstrap_net_agent = None if _bootstrap_cell is not None: try: _bootstrap_cell.stop() except Exception: pass _bootstrap_cell = None def _get_root_url(args) -> str: root_url = getattr(args, "root_url", "") if root_url: return root_url scheme = getattr(args, "sp_scheme", "") target = getattr(args, "sp_target", "") if scheme and target: return f"{scheme}://{target}" raise RuntimeError("unable to determine root_url for workspace transfer bootstrap cell") def _get_bootstrap_tls_pair(startup_dir: str, owner_fqcn: str) -> tuple[str, str, str, str]: prefer_server = FQCN.get_root(owner_fqcn) == FQCN.ROOT_SERVER if prefer_server: candidates = [ ("server.crt", "server.key", DriverParams.SERVER_CERT.value, DriverParams.SERVER_KEY.value), ("client.crt", "client.key", DriverParams.CLIENT_CERT.value, DriverParams.CLIENT_KEY.value), ] else: candidates = [ ("client.crt", "client.key", DriverParams.CLIENT_CERT.value, DriverParams.CLIENT_KEY.value), 
("server.crt", "server.key", DriverParams.SERVER_CERT.value, DriverParams.SERVER_KEY.value), ] for cert_name, key_name, cert_key, key_key in candidates: cert_path = os.path.join(startup_dir, cert_name) key_path = os.path.join(startup_dir, key_name) if os.path.exists(cert_path) and os.path.exists(key_path): return cert_path, key_path, cert_key, key_key raise RuntimeError(f"workspace transfer requires cert/key files in startup dir: {startup_dir}") def _create_bootstrap_cell(args, owner_fqcn: str, secure_mode: bool) -> tuple[Cell, NetAgent]: startup_dir = os.path.join(args.workspace, "startup") credentials = {} if secure_mode: root_ca = os.path.join(startup_dir, "rootCA.pem") if not os.path.exists(root_ca): raise RuntimeError(f"workspace transfer requires rootCA.pem in startup dir: {startup_dir}") cert_path, key_path, cert_key, key_key = _get_bootstrap_tls_pair(startup_dir, owner_fqcn) credentials = { DriverParams.CA_CERT.value: root_ca, cert_key: cert_path, key_key: key_path, } parent_resources = {} parent_conn_sec = getattr(args, "parent_conn_sec", "") if parent_conn_sec: parent_resources[DriverParams.CONNECTION_SECURITY.value] = parent_conn_sec fqcn = make_workspace_transfer_fqcn(owner_fqcn, args.job_id) cell = Cell( fqcn=fqcn, root_url=_get_root_url(args), secure=secure_mode, credentials=credentials, create_internal_listener=False, parent_url=args.parent_url, parent_resources=parent_resources or None, ) # Install auth headers BEFORE cell.start(): the cell's initial # cellnet.channel registration handshake fires during start(), and the # parent's authenticator drops any unsigned message, preventing the cell # from ever registering with its parent. 
if FQCN.get_root(owner_fqcn) == FQCN.ROOT_SERVER: client_name = AUTH_CLIENT_NAME_FOR_SJ auth_token = args.job_id else: client_name = getattr(args, "client_name", "") or "" auth_token = getattr(args, "token", "") or "" set_add_auth_headers_filters( cell, client_name=client_name, auth_token=auth_token, token_signature=getattr(args, "token_signature", "") or "", ssid=getattr(args, "ssid", "") or None, ) cell.start() net_agent = NetAgent(cell) return cell, net_agent def _wait_for_bootstrap_ready(cell: Cell, owner_fqcn: str, timeout: float = BOOTSTRAP_CONNECT_TIMEOUT) -> None: deadline = time.monotonic() + timeout while time.monotonic() < deadline: if cell.is_backbone_ready() and cell.is_cell_connected(owner_fqcn): return time.sleep(BOOTSTRAP_CONNECT_POLL_INTERVAL) raise RuntimeError( f"workspace transfer bootstrap cell failed to connect to parent {owner_fqcn} within {timeout} seconds" ) def _request_workspace_bundle(cell: Cell, owner_fqcn: str, job_id: str, transfer_token: str) -> dict: _wait_for_bootstrap_ready(cell, owner_fqcn) request = new_cell_message({}, {"job_id": job_id, "transfer_token": transfer_token}) reply = cell.send_request( channel=WORKSPACE_TRANSFER_CHANNEL, target=owner_fqcn, topic=TOPIC_PREPARE_DOWNLOAD, request=request, timeout=PER_REQUEST_TIMEOUT, ) rc = reply.get_header(MessageHeaderKey.RETURN_CODE) if rc != ReturnCode.OK: raise RuntimeError(f"workspace download preparation failed for {job_id}: {rc}") payload = reply.payload if not isinstance(payload, dict): raise RuntimeError(f"invalid workspace download reply payload for {job_id}: {type(payload)}") return payload
def download_workspace(args, secure_mode: bool) -> None:
    """Fetch this job's workspace bundle from its owner cell and extract it into ``args.workspace``.

    This is a no-op unless ``ENV_WORKSPACE_OWNER_FQCN`` is set. The bootstrap
    cell it creates is intentionally left running for ``upload_results``.

    Raises:
        RuntimeError: if the transfer token env var is missing, the owner's
            reply lacks ``ref_id``, the download fails, or the checksum
            does not match.
        ValueError: if the bundle contains absolute or ``..`` member paths.
    """
    owner_fqcn = os.environ.get(ENV_WORKSPACE_OWNER_FQCN, "")
    if not owner_fqcn:
        return
    transfer_token = os.environ.get(ENV_WORKSPACE_TRANSFER_TOKEN, "")
    if not transfer_token:
        raise RuntimeError(f"workspace transfer requires env var {ENV_WORKSPACE_TRANSFER_TOKEN}")
    os.makedirs(args.workspace, exist_ok=True)
    temp_dir = tempfile.mkdtemp(prefix="workspace-download-")
    try:
        cell = _get_bootstrap_cell(args, owner_fqcn, secure_mode)
        payload = _request_workspace_bundle(cell, owner_fqcn, args.job_id, transfer_token)
        ref_id = payload.get("ref_id")
        if not ref_id:
            # Fail with a clear message instead of handing None to download_file.
            raise RuntimeError(f"workspace download reply for {args.job_id} is missing ref_id")
        expected_sha = payload.get("sha256")
        err, bundle_path = download_file(
            from_fqcn=owner_fqcn,
            ref_id=ref_id,
            per_request_timeout=PER_REQUEST_TIMEOUT,
            cell=cell,
            location=temp_dir,
        )
        if err:
            raise RuntimeError(f"failed to download workspace for {args.job_id}: {err}")
        if expected_sha and _hash_file(bundle_path) != expected_sha:
            raise RuntimeError(f"workspace bundle checksum mismatch for {args.job_id}")
        with zipfile.ZipFile(bundle_path) as zf:
            _validate_relative_zip_members(zf)
            zf.extractall(args.workspace)
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
def _publish_results_bundle(args, secure_mode: bool, owner_fqcn: str, transfer_token: str) -> None:
    """Zip ``<workspace>/<job_id>``, offer it via an ObjectDownloader, and ask the owner to pull it.

    The temporary zip and the download transaction are always released. A
    failure while releasing the transaction is logged rather than raised, so
    it cannot mask the upload's own outcome.

    Raises:
        RuntimeError: if the owner replies with a non-OK return code.
    """
    temp_bundle = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
    temp_bundle.close()
    downloader = None
    try:
        _zip_results_to_file(args.workspace, args.job_id, temp_bundle.name)
        bundle_sha = _hash_file(temp_bundle.name)
        bundle_size = os.path.getsize(temp_bundle.name)
        logger.info(
            "[ws-transfer] upload_results start job=%s bundle_size=%d target=%s",
            args.job_id,
            bundle_size,
            owner_fqcn,
        )
        cell = _get_bootstrap_cell(args, owner_fqcn, secure_mode)
        _wait_for_bootstrap_ready(cell, owner_fqcn)
        downloader = ObjectDownloader(
            cell=cell,
            timeout=DOWNLOAD_TIMEOUT,
            num_receivers=1,
            transaction_done_cb=_cleanup_transfer_files,
            temp_paths=[temp_bundle.name],
        )
        ref_id = add_file(downloader, temp_bundle.name)
        logger.info(
            "[ws-transfer] upload_results registered ref=%s tx=%s, sending publish_results",
            ref_id,
            getattr(downloader, "tx_id", "<unknown>"),
        )
        request = new_cell_message(
            {},
            {
                "job_id": args.job_id,
                "ref_id": ref_id,
                "transfer_token": transfer_token,
                "sha256": bundle_sha,
                "size": bundle_size,
            },
        )
        reply = cell.send_request(
            channel=WORKSPACE_TRANSFER_CHANNEL,
            target=owner_fqcn,
            topic=TOPIC_PUBLISH_RESULTS,
            request=request,
            timeout=DOWNLOAD_TIMEOUT,
        )
        rc = reply.get_header(MessageHeaderKey.RETURN_CODE)
        reply_origin = reply.get_header(MessageHeaderKey.ORIGIN)
        reply_err = reply.get_header(MessageHeaderKey.ERROR, "")
        logger.info(
            "[ws-transfer] upload_results reply rc=%s origin=%s err=%s payload=%r",
            rc,
            reply_origin,
            reply_err,
            getattr(reply, "payload", None),
        )
        if rc != ReturnCode.OK:
            raise RuntimeError(f"results upload failed for {args.job_id}: rc={rc} err={reply_err}")
        downloader.delete_transaction()
        downloader = None
        logger.info("[ws-transfer] upload_results SUCCESS job=%s", args.job_id)
    finally:
        if downloader is not None:
            try:
                downloader.delete_transaction()
            except Exception as e:
                logger.warning(
                    "failed to delete results upload transaction for %s: %s",
                    args.job_id,
                    secure_format_exception(e),
                )
        _cleanup_files([temp_bundle.name])


def upload_results(args, secure_mode: bool) -> None:
    """Upload ``<args.workspace>/<args.job_id>`` back to the owner cell.

    This is a no-op unless ``ENV_WORKSPACE_OWNER_FQCN`` is set. Nothing is
    uploaded when the job directory does not exist. In every case past the
    token check, the shared bootstrap cell is stopped before returning.

    Raises:
        RuntimeError: if the transfer token env var is missing or the owner
            rejects the upload.
    """
    owner_fqcn = os.environ.get(ENV_WORKSPACE_OWNER_FQCN, "")
    if not owner_fqcn:
        return
    transfer_token = os.environ.get(ENV_WORKSPACE_TRANSFER_TOKEN, "")
    if not transfer_token:
        raise RuntimeError(f"workspace transfer requires env var {ENV_WORKSPACE_TRANSFER_TOKEN}")
    try:
        run_dir = os.path.join(args.workspace, args.job_id)
        if not os.path.isdir(run_dir):
            return
        _publish_results_bundle(args, secure_mode, owner_fqcn, transfer_token)
    finally:
        # The cell shared with download_workspace is no longer needed after this
        # point; free it on every exit path, including "nothing to upload".
        _close_bootstrap_cell()
def upload_results_safely(args, secure_mode: bool, log=None) -> None:
    """Run ``upload_results`` and downgrade any exception to a warning.

    *log* (when truthy) replaces the module logger for that warning.
    """
    try:
        upload_results(args, secure_mode)
    except Exception as exc:
        target = log if log else logger
        target.warning(f"failed to upload job results for {args.job_id}: {secure_format_exception(exc)}")