nvflare.app_opt.pt.recipes.fedopt module
- class FedOptRecipe(*, name: str = 'fedopt', model: Any | dict[str, Any] | None = None, initial_ckpt: str | None = None, min_clients: int, num_rounds: int = 2, train_script: str, train_args: str = '', aggregator: Aggregator | None = None, launch_external_process: bool = False, command: str = 'python3 -u', server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY, device: str | None = None, source_model: str = 'model', optimizer_args: dict | None = None, lr_scheduler_args: dict | None = None, server_memory_gc_rounds: int = 1, client_memory_gc_rounds: int = 0, cuda_empty_cache: bool = False)
Bases: Recipe
A recipe for implementing Federated Optimization (FedOpt) in NVFlare.
FedOpt is a federated learning algorithm that updates the global model on the server with a server-side optimizer and an optional learning rate scheduler. After each round, the aggregated client update is applied to the global model through the specified optimizer and, if configured, the learning rate scheduler. The algorithm was proposed in Reddi et al., “Adaptive Federated Optimization,” arXiv preprint arXiv:2003.00295 (2020).
Note: FedOpt is only implemented for params_transfer_type == TransferType.DIFF, i.e. clients send weight differences rather than full weights, and the aggregator must expect DataKind.WEIGHT_DIFF.
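The server-side update can be sketched in plain Python, assuming (as in Reddi et al.) that the weighted average of the client weight differences is treated as a negative pseudo-gradient and passed to a server optimizer. The names `weighted_average_diff` and `ServerSGD` are hypothetical, not NVFlare APIs; `ServerSGD` only mirrors the update rule of `torch.optim.SGD` with momentum (no dampening, no Nesterov), and the real recipe operates on PyTorch tensors rather than lists.

```python
from typing import Dict, List, Tuple

# Hypothetical representation: parameter name -> flat list of floats.
Weights = Dict[str, List[float]]


def weighted_average_diff(client_diffs: List[Tuple[Weights, float]]) -> Weights:
    """Weighted mean of client weight differences (WEIGHT_DIFF-style payloads)."""
    total = sum(weight for _, weight in client_diffs)
    first_diff = client_diffs[0][0]
    return {
        name: [
            sum(diff[name][i] * weight for diff, weight in client_diffs) / total
            for i in range(len(values))
        ]
        for name, values in first_diff.items()
    }


class ServerSGD:
    """Hypothetical server optimizer: SGD with momentum on the pseudo-gradient -avg_diff."""

    def __init__(self, lr: float = 1.0, momentum: float = 0.0) -> None:
        self.lr = lr
        self.momentum = momentum
        self.buffers: Weights = {}  # momentum buffer per parameter

    def step(self, global_weights: Weights, avg_diff: Weights) -> Weights:
        new_weights: Weights = {}
        for name, params in global_weights.items():
            grad = [-d for d in avg_diff[name]]  # pseudo-gradient: negative mean update
            prev = self.buffers.get(name, [0.0] * len(grad))
            buf = [self.momentum * b + g for b, g in zip(prev, grad)]
            self.buffers[name] = buf
            new_weights[name] = [p - self.lr * b for p, b in zip(params, buf)]
        return new_weights
```

With lr=1.0 and momentum=0.0 the step simply adds the averaged difference to the global weights, which is plain FedAvg; a nonzero momentum such as 0.6 makes the server carry over part of earlier rounds' updates.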
- Parameters:
name – Name of the federated learning job. Defaults to “fedopt”.
model – Initial model to start federated training with (REQUIRED). Can be either an nn.Module instance or a dict config of the form {“class_path”: “module.ClassName”, “args”: {“param”: value}}. Note: FedOpt requires a model for the server-side optimizer to work.
initial_ckpt – Absolute path to a pre-trained checkpoint file used to load the initial weights. The file need not exist locally, since it may reside on the server. Note: with PyTorch, model must also be provided when using initial_ckpt, because the architecture is taken from model.
min_clients – Minimum number of clients required to start a training round.
num_rounds – Number of federated training rounds to execute. Defaults to 2.
train_script – Path to the training script that will be executed on each client.
train_args – Command line arguments to pass to the training script.
aggregator – Aggregator for combining client updates. If None, uses InTimeAccumulateWeightedAggregator with expected_data_kind=DataKind.WEIGHT_DIFF.
launch_external_process (bool) – Whether to launch the script in an external process. Defaults to False.
command (str) – Command used to run the script when launch_external_process=True (prepended to the script). Defaults to “python3 -u”.
server_expected_format (ExchangeFormat) – Format used to exchange parameters between server and client. Defaults to ExchangeFormat.NUMPY.
source_model (str) – ID of the source model component. Defaults to “model”.
optimizer_args (dict) – Configuration for the server-side optimizer, with keys:
  - path: fully qualified optimizer class (e.g., “torch.optim.SGD”); “class_path” is also accepted.
  - args: dictionary of optimizer arguments (e.g., {“lr”: 1.0, “momentum”: 0.6}).
  - config_type: optional; if omitted, it is set to “dict” so that the config is not instantiated at load time.
lr_scheduler_args (dict) – Optional configuration for the learning rate scheduler, with keys:
  - path: fully qualified scheduler class (e.g., “torch.optim.lr_scheduler.CosineAnnealingLR”); “class_path” is also accepted.
  - args: dictionary of scheduler arguments (e.g., {“T_max”: 100, “eta_min”: 0.9}).
  - config_type: optional; if omitted, it is set to “dict” so that the config is not instantiated at load time.
device (str) – Device to use for server-side optimization, e.g. “cpu” or “cuda:0”. Defaults to None, in which case CUDA is used if available.
server_memory_gc_rounds – Run memory cleanup (gc.collect + malloc_trim) every N rounds on server. Set to 0 to disable. Defaults to 1 (every round).
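The dict forms accepted by model, optimizer_args, and lr_scheduler_args all follow a "class path plus keyword arguments" pattern. The sketch below shows one way such a config could be normalized and resolved with importlib. The helpers `normalize_component_config` and `instantiate` are hypothetical, not NVFlare functions, and NVFlare's own loader may behave differently (for example, it defers instantiation when config_type is “dict”).

```python
import importlib
from typing import Any, Dict


def normalize_component_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical: accept "class_path" as an alias for "path"; default args and config_type."""
    cfg = dict(cfg)  # copy so the caller's dict is not mutated
    if "path" not in cfg and "class_path" in cfg:
        cfg["path"] = cfg.pop("class_path")
    cfg.setdefault("args", {})
    cfg.setdefault("config_type", "dict")
    return cfg


def instantiate(cfg: Dict[str, Any], *positional: Any) -> Any:
    """Hypothetical: import the class named by cfg["path"] and call it with cfg["args"]."""
    module_name, _, class_name = cfg["path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(*positional, **cfg["args"])
```

For a server optimizer the positional argument would be the model's parameters, e.g. `instantiate(normalize_component_config(optimizer_args), model.parameters())`; any importable class works the same way, such as `{"class_path": "fractions.Fraction", "args": {"numerator": 3, "denominator": 4}}`.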
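To see what a cosine scheduler does to the server learning rate, the closed form documented for `torch.optim.lr_scheduler.CosineAnnealingLR` can be evaluated directly: eta_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2. This sketch assumes the server steps the scheduler once per round and uses the example values above (lr 1.0 from optimizer_args, eta_min 0.9); the function names `cosine_annealing_lr` and `server_lr_schedule` are hypothetical.

```python
import math
from typing import List


def cosine_annealing_lr(base_lr: float, eta_min: float, t_max: int, step: int) -> float:
    """Closed-form cosine annealing: base_lr at step 0, decaying to eta_min at step t_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


def server_lr_schedule(base_lr: float, eta_min: float, t_max: int, num_rounds: int) -> List[float]:
    """Learning rate for each round, assuming one scheduler step per round."""
    return [cosine_annealing_lr(base_lr, eta_min, t_max, r) for r in range(num_rounds)]
```

With T_max=100 and only 10 rounds, the rate stays above 0.99 for the whole job; tying T_max to the number of rounds (which the “{num_rounds}” value in the example appears intended to do) spreads the full decay from 1.0 to 0.9 across training.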
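The cadence described for server_memory_gc_rounds (cleanup every N rounds, 0 disables it) can be sketched as follows. The helpers `should_cleanup` and `maybe_cleanup`, the 1-based round counting, and calling glibc's `malloc_trim` through ctypes are all assumptions for illustration; NVFlare may count rounds differently.

```python
import ctypes
import gc


def should_cleanup(round_num: int, every_n_rounds: int) -> bool:
    """True when cleanup is due; round_num is assumed 1-based, and 0 disables cleanup."""
    return every_n_rounds > 0 and round_num % every_n_rounds == 0


def maybe_cleanup(round_num: int, every_n_rounds: int) -> bool:
    """Run gc.collect() and, on glibc systems, malloc_trim(0) when cleanup is due."""
    if not should_cleanup(round_num, every_n_rounds):
        return False
    gc.collect()
    try:
        # Ask glibc to release free heap memory back to the OS.
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except (OSError, AttributeError):
        pass  # not glibc (e.g. macOS, Windows, musl): skip the trim
    return True
```

With the default of 1 the cleanup runs every round; larger values trade peak memory for less time spent in garbage collection.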
Example
```python
from nvflare.app_opt.pt.recipes.fedopt import FedOptRecipe

# pretrained_model: an nn.Module instance defined elsewhere (not shown)
recipe = FedOptRecipe(
    name="my_fedopt_job", model=pretrained_model, min_clients=2, num_rounds=10,
    train_script="client.py", train_args="--epochs 5 --batch_size 32",
    device="cpu", source_model="model",
    optimizer_args={
        "path": "torch.optim.SGD", "args": {"lr": 1.0, "momentum": 0.6}, "config_type": "dict",
    },
    lr_scheduler_args={
        "path": "torch.optim.lr_scheduler.CosineAnnealingLR",
        "args": {"T_max": "{num_rounds}", "eta_min": 0.9}, "config_type": "dict",
    },
)
```
Inherited from the Recipe base class: Recipe is the base class of all recipes. Recipes are implemented by jobs; a concrete recipe must provide the job that implements it.
- param job:
the job that implements the recipe.