# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from enum import Enum
from typing import Dict, Optional
from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey
from nvflare.fuel.utils.config_factory import ConfigFactory
from nvflare.fuel.utils.log_utils import get_obj_logger
class TransferType(str, Enum):
    """Allowed values of the ``transfer_type`` config key.

    Members are also ``str`` instances, so they compare equal to their raw
    string values (``TransferType.FULL == "FULL"``).
    """

    FULL = "FULL"  # transfer the full payload
    DIFF = "DIFF"  # transfer only the difference
class ConfigKey:
    """String constants naming the keys of the client API config dictionary."""

    # Keys found inside an exchange section (see the TASK_EXCHANGE example in ClientConfig).
    EXCHANGE_FORMAT = "exchange_format"
    SERVER_EXPECTED_FORMAT = "server_expected_format"
    TRANSFER_TYPE = "transfer_type"
    TRAIN_WITH_EVAL = "train_with_eval"
    TRAIN_TASK_NAME = "train_task_name"
    EVAL_TASK_NAME = "eval_task_name"
    SUBMIT_MODEL_TASK_NAME = "submit_model_task_name"
    PIPE_CHANNEL_NAME = "pipe_channel_name"
    PIPE = "pipe"

    # Keys found inside a "pipe" sub-section.
    CLASS_NAME = "CLASS_NAME"
    ARG = "ARG"

    TASK_NAME = "TASK_NAME"

    # Top-level section names.
    TASK_EXCHANGE = "TASK_EXCHANGE"
    METRICS_EXCHANGE = "METRICS_EXCHANGE"

    # Optional tuning keys.
    HEARTBEAT_TIMEOUT = "HEARTBEAT_TIMEOUT"
    MEMORY_GC_ROUNDS = "memory_gc_rounds"
    CUDA_EMPTY_CACHE = "cuda_empty_cache"
    SUBMIT_RESULT_TIMEOUT = "submit_result_timeout"
    MAX_RESENDS = "max_resends"
    DOWNLOAD_COMPLETE_TIMEOUT = "download_complete_timeout"
    LAUNCH_ONCE = "launch_once"
class ClientConfig:
    """Config class used in `nvflare.client` module.

    Note:
        The config has the following keys:

        .. code-block::

            EXCHANGE_FORMAT: Format to exchange, pytorch, raw, or numpy
            TRANSFER_TYPE: Either FULL or DIFF (means difference)
            TRAIN_WITH_EVAL: Whether train task needs to also do evaluation
            TRAIN_TASK_NAME: Name of the train task
            EVAL_TASK_NAME: Name of the evaluate task
            SUBMIT_MODEL_TASK_NAME: Name of the submit_model task
            PIPE_CHANNEL_NAME: Channel name of the pipe
            PIPE: pipe section
            CLASS_NAME: Class name
            ARG: Arguments
            SITE_NAME: Site name
            JOB_ID: Job id
            TASK_EXCHANGE: TASK_EXCHANGE section
            METRICS_EXCHANGE: METRICS_EXCHANGE section

        The TASK_EXCHANGE section may also carry these optional keys; the
        corresponding getters apply the listed defaults when they are absent:

        .. code-block::

            HEARTBEAT_TIMEOUT: heartbeat timeout (falls back to METRICS_EXCHANGE's value, then 60)
            MAX_RESENDS: max pipe send retries for task results (default 3, None = unlimited)
            LAUNCH_ONCE: launch the subprocess once per job instead of per round (default False)
            DOWNLOAD_COMPLETE_TIMEOUT: seconds to wait for the server to download results (default 1800)
            SUBMIT_RESULT_TIMEOUT: seconds to wait for CJ to ACK a result message (default 300)

    Example:
        The content of config looks like:

        .. code-block:: json

            {
              "METRICS_EXCHANGE": {
                "pipe_channel_name": "metric",
                "pipe": {
                  "CLASS_NAME": "nvflare.fuel.utils.pipe.cell_pipe.CellPipe",
                  "ARG": {
                    "mode": "ACTIVE",
                    "site_name": "site-1",
                    "token": "simulate_job",
                    "root_url": "tcp://0:51893",
                    "secure_mode": false,
                    "workspace_dir": "xxx"
                  }
                }
              },
              "SITE_NAME": "site-1",
              "JOB_ID": "simulate_job",
              "TASK_EXCHANGE": {
                "train_with_eval": true,
                "exchange_format": "numpy",
                "transfer_type": "DIFF",
                "train_task_name": "train",
                "eval_task_name": "validate",
                "submit_model_task_name": "submit_model",
                "pipe_channel_name": "task",
                "pipe": {
                  "CLASS_NAME": "nvflare.fuel.utils.pipe.cell_pipe.CellPipe",
                  "ARG": {
                    "mode": "ACTIVE",
                    "site_name": "site-1",
                    "token": "simulate_job",
                    "root_url": "tcp://0:51893",
                    "secure_mode": false,
                    "workspace_dir": "xxx"
                  }
                }
              }
            }
    """

    def __init__(self, config: Optional[Dict] = None):
        """Wrap ``config`` (a new empty dict when None); the dict is stored, not copied."""
        if config is None:
            config = {}
        self.config = config
        self.logger = get_obj_logger(self)

    def _get_task_exchange_value(self, key: str, default=None):
        """Return ``key`` from the TASK_EXCHANGE section, or ``default`` when it (or the section) is absent."""
        return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(key, default)

    def get_config(self) -> Dict:
        """Return the underlying config dict (the same object, not a copy)."""
        return self.config

    def get_pipe_channel_name(self, section: str) -> str:
        """Return the pipe channel name of ``section``; raises KeyError if it is missing."""
        return self.config[section][ConfigKey.PIPE_CHANNEL_NAME]

    def get_pipe_args(self, section: str) -> dict:
        """Return the ARG dict of the pipe in ``section``; raises KeyError if it is missing."""
        return self.config[section][ConfigKey.PIPE][ConfigKey.ARG]

    def get_pipe_class(self, section: str) -> str:
        """Return the CLASS_NAME of the pipe in ``section``; raises KeyError if it is missing."""
        return self.config[section][ConfigKey.PIPE][ConfigKey.CLASS_NAME]

    def get_transfer_type(self) -> str:
        """Return the TASK_EXCHANGE transfer type, defaulting to "FULL"."""
        return self._get_task_exchange_value(ConfigKey.TRANSFER_TYPE, "FULL")

    def get_train_task(self):
        """Return the train task name, or "" when not configured."""
        return self._get_task_exchange_value(ConfigKey.TRAIN_TASK_NAME, "")

    def get_eval_task(self):
        """Return the evaluate task name, or "" when not configured."""
        return self._get_task_exchange_value(ConfigKey.EVAL_TASK_NAME, "")

    def get_submit_model_task(self):
        """Return the submit_model task name, or "" when not configured."""
        return self._get_task_exchange_value(ConfigKey.SUBMIT_MODEL_TASK_NAME, "")

    def get_heartbeat_timeout(self):
        """Return the heartbeat timeout.

        Looked up in TASK_EXCHANGE first, then METRICS_EXCHANGE, then defaults to 60.
        """
        # TODO decouple task and metric heartbeat timeouts
        return self._get_task_exchange_value(
            ConfigKey.HEARTBEAT_TIMEOUT,
            self.config.get(ConfigKey.METRICS_EXCHANGE, {}).get(ConfigKey.HEARTBEAT_TIMEOUT, 60),
        )

    def get_max_resends(self) -> Optional[int]:
        """Return the maximum number of pipe send retries for submitting task results.

        None means unlimited. The default of 3 bounds the retry window and
        prevents unbounded accumulation of ArrayDownloadable objects.
        Set via recipe.add_client_config({"max_resends": N}).

        Returns:
            None, or a non-negative int (negative values are clamped to 0 with a warning).

        Raises:
            ValueError / TypeError: if the configured value cannot be converted by ``int()``.
        """
        value = self._get_task_exchange_value(ConfigKey.MAX_RESENDS, 3)
        if value is None:
            return None
        result = int(value)
        if result < 0:
            self.logger.warning(f"max_resends={result} is negative, clamping to 0")
            return 0
        return result

    def get_launch_once(self) -> bool:
        """Return whether the subprocess is launched once for the entire job (True) or per-round (False).

        True → subprocess handles multiple rounds; must NOT os._exit() after each send.
        False → subprocess handles one round; must os._exit() after download so the deferred-stop
        poller on the CJ side unblocks (default, preserves original behaviour).

        String values (e.g. passed through ``add_client_config``) are parsed explicitly,
        because ``bool("false")`` would otherwise evaluate to True.
        Non-string values are converted with ``bool()``.
        """
        value = self._get_task_exchange_value(ConfigKey.LAUNCH_ONCE, False)
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized in ("false", "0", "no", "off", ""):
                return False
            if normalized in ("true", "1", "yes", "on"):
                return True
            # Unrecognized text keeps the historical bool(<non-empty str>) == True result.
            self.logger.warning(f"unrecognized launch_once value {value!r}, treating it as True")
        return bool(value)

    def get_download_complete_timeout(self) -> float:
        """Return timeout (seconds) for subprocess to wait for the server to finish downloading its result.

        After send_to_peer() ACKs, the server asynchronously downloads tensors from the subprocess
        DownloadService. This timeout gates subprocess exit so the process does not disappear before
        the download completes. Defaults to 1800 s (30 min) for large-model transfers.
        """
        return float(self._get_task_exchange_value(ConfigKey.DOWNLOAD_COMPLETE_TIMEOUT, 1800.0))

    def get_submit_result_timeout(self) -> float:
        """Return the timeout (seconds) for the subprocess to wait for CJ to ACK a result message.

        The value is read from the TASK_EXCHANGE section of the config, which is written by
        ClientAPILauncherExecutor.prepare_config_for_launch(). If absent, a safe default of
        300 s is returned — large enough for a single-chunk ACK with reverse PASS_THROUGH, and
        a reasonable upper bound for direct (non-PASS_THROUGH) transfers at typical throughputs.
        Changing this value via recipe.add_client_config() sets it for a specific job without
        touching any process-level defaults.
        """
        return float(self._get_task_exchange_value(ConfigKey.SUBMIT_RESULT_TIMEOUT, 300.0))

    def get_connection_security(self):
        """Return the top-level connection security setting, or None."""
        return self.config.get(ConnPropKey.CONNECTION_SECURITY)

    def get_root_conn_props(self):
        """Return the root connection properties, or None."""
        return self.config.get(ConnPropKey.ROOT_CONN_PROPS)

    def get_cp_conn_props(self):
        """Return the CP connection properties, or None."""
        return self.config.get(ConnPropKey.CP_CONN_PROPS)

    def get_relay_conn_props(self):
        """Return the relay connection properties, or None."""
        return self.config.get(ConnPropKey.RELAY_CONN_PROPS)

    def get_site_name(self):
        """Return the site name, or None."""
        return self.config.get(FLMetaKey.SITE_NAME)

    def get_auth_token(self):
        """Return the auth token, or None."""
        return self.config.get(FLMetaKey.AUTH_TOKEN)

    def get_auth_token_signature(self):
        """Return the auth token signature, or None."""
        return self.config.get(FLMetaKey.AUTH_TOKEN_SIGNATURE)

    def to_json(self, config_file: str):
        """Write the config to ``config_file`` as JSON indented by 2, overwriting any existing file."""
        with open(config_file, "w") as f:
            json.dump(self.config, f, indent=2)
def from_file(config_file: str):
    """Load a client API config file and wrap its contents in a ``ClientConfig``.

    Args:
        config_file (str): path of the config file to load via ``ConfigFactory``.

    Returns:
        ClientConfig: built from the loaded config converted to a plain dict.

    Raises:
        RuntimeError: if ``ConfigFactory.load_config`` returns None.
    """
    loaded = ConfigFactory.load_config(config_file)
    if loaded is not None:
        return ClientConfig(config=loaded.to_dict())
    raise RuntimeError(f"Load config file {config_file} failed")
def write_config_to_file(config_data: dict, config_file_path: str):
    """Writes client api config file.

    If ``config_file_path`` already exists, its current contents are loaded and
    ``config_data`` is merged into them (top-level keys in ``config_data`` win);
    otherwise a fresh config containing only ``config_data`` is written.

    Args:
        config_data (dict): data to be updated.
        config_file_path (str): filepath to write.
    """
    client_config = from_file(config_file=config_file_path) if os.path.exists(config_file_path) else ClientConfig()
    # Shallow, top-level merge into the wrapped dict before serializing it back.
    client_config.config.update(config_data)
    client_config.to_json(config_file_path)