Source code for nvflare.client.config

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from typing import Dict, Optional

from nvflare.fuel.utils.config_factory import ConfigFactory


[docs]class ExchangeFormat: RAW = "raw" PYTORCH = "pytorch" NUMPY = "numpy"
[docs]class TransferType: FULL = "FULL" DIFF = "DIFF"
[docs]class ConfigKey: EXCHANGE_FORMAT = "exchange_format" TRANSFER_TYPE = "transfer_type" TRAIN_WITH_EVAL = "train_with_eval" TRAIN_TASK_NAME = "train_task_name" EVAL_TASK_NAME = "eval_task_name" SUBMIT_MODEL_TASK_NAME = "submit_model_task_name" PIPE_CHANNEL_NAME = "pipe_channel_name" PIPE = "pipe" CLASS_NAME = "CLASS_NAME" ARG = "ARG" TASK_NAME = "TASK_NAME" TASK_EXCHANGE = "TASK_EXCHANGE" METRICS_EXCHANGE = "METRICS_EXCHANGE"
[docs]class ClientConfig: """Config class used in `nvflare.client` module. Note: The config has the following keys: .. code-block:: EXCHANGE_FORMAT: Format to exchange, pytorch, raw, or numpy TRANSFER_TYPE: Either FULL or DIFF (means difference) TRAIN_WITH_EVAL: Whether train task needs to also do evaluation TRAIN_TASK_NAME: Name of the train task EVAL_TASK_NAME: Name of the evaluate task SUBMIT_MODEL_TASK_NAME: Name of the submit_model task PIPE_CHANNEL_NAME: Channel name of the pipe PIPE: pipe section CLASS_NAME: Class name ARG: Arguments SITE_NAME: Site name JOB_ID: Job id TASK_EXCHANGE: TASK_EXCHANGE section METRICS_EXCHANGE: METRICS_EXCHANGE section Example: The content of config looks like: .. code-block:: json { "METRICS_EXCHANGE": { "pipe_channel_name": "metric", "pipe": { "CLASS_NAME": "nvflare.fuel.utils.pipe.cell_pipe.CellPipe", "ARG": { "mode": "ACTIVE", "site_name": "site-1", "token": "simulate_job", "root_url": "tcp://0:51893", "secure_mode": false, "workspace_dir": "xxx" } } }, "SITE_NAME": "site-1", "JOB_ID": "simulate_job", "TASK_EXCHANGE": { "train_with_eval": true, "exchange_format": "numpy", "transfer_type": "DIFF", "train_task_name": "train", "eval_task_name": "evaluate", "submit_model_task_name": "submit_model", "pipe_channel_name": "task", "pipe": { "CLASS_NAME": "nvflare.fuel.utils.pipe.cell_pipe.CellPipe", "ARG": { "mode": "ACTIVE", "site_name": "site-1", "token": "simulate_job", "root_url": "tcp://0:51893", "secure_mode": false, "workspace_dir": "xxx" } } } } """ def __init__(self, config: Optional[Dict] = None): if config is None: config = {} self.config = config
[docs] def get_config(self) -> Dict: return self.config
[docs] def get_pipe_channel_name(self, section: str) -> str: return self.config[section][ConfigKey.PIPE_CHANNEL_NAME]
[docs] def get_pipe_args(self, section: str) -> dict: return self.config[section][ConfigKey.PIPE][ConfigKey.ARG]
[docs] def get_pipe_class(self, section: str) -> str: return self.config[section][ConfigKey.PIPE][ConfigKey.CLASS_NAME]
[docs] def get_exchange_format(self) -> str: return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.EXCHANGE_FORMAT]
[docs] def get_transfer_type(self) -> str: return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.TRANSFER_TYPE, "FULL")
[docs] def get_train_task(self): return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.TRAIN_TASK_NAME]
[docs] def get_eval_task(self): return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.EVAL_TASK_NAME]
[docs] def get_submit_model_task(self): return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.SUBMIT_MODEL_TASK_NAME]
[docs] def to_json(self, config_file: str): with open(config_file, "w") as f: json.dump(self.config, f, indent=2)
[docs]def from_file(config_file: str): config = ConfigFactory.load_config(config_file) if config is None: raise RuntimeError(f"Load config file {config_file} failed") return ClientConfig(config=config.to_dict())
[docs]def write_config_to_file(config_data: dict, config_file_path: str): """Writes client api config file. Args: config_data (dict): data to be updated. config_file_path (str): filepath to write. """ if os.path.exists(config_file_path): client_config = from_file(config_file=config_file_path) else: client_config = ClientConfig() configuration = client_config.config configuration.update(config_data) client_config.to_json(config_file_path)