# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as get_package_version
from typing import Optional
from packaging.version import InvalidVersion, Version
from nvflare.app_common.tie.defs import Constant
from nvflare.client.api import ClientAPIType
from nvflare.client.api_spec import CLIENT_API_TYPE_KEY
from nvflare.fuel.utils.validation_utils import check_object_type
from nvflare.recipe.spec import Recipe
# Minimum supported Flower release. The requirement string below is derived
# from it so the two can never drift apart.
SUPPORTED_FLWR_MIN_VERSION = Version("1.26")
# Requirement string quoted in the version-check error messages ("flwr>=1.26").
SUPPORTED_FLWR_SPEC = f"flwr>={SUPPORTED_FLWR_MIN_VERSION}"
def _validate_flwr_version():
    """Check that the installed ``flwr`` package satisfies ``SUPPORTED_FLWR_SPEC``.

    Returns:
        None when the installed version is at least ``SUPPORTED_FLWR_MIN_VERSION``.

    Raises:
        RuntimeError: if ``flwr`` is not installed, if its installed version
            string cannot be parsed, or if it is older than
            ``SUPPORTED_FLWR_MIN_VERSION``. The first two chain the original
            exception as ``__cause__``.
    """
    # Shared tail of every error message below.
    requirement_hint = f"FlowerRecipe requires '{SUPPORTED_FLWR_SPEC}'."

    try:
        flwr_version_str = get_package_version("flwr")
    except PackageNotFoundError as err:
        raise RuntimeError("Flower package 'flwr' is not installed. " + requirement_hint) from err

    try:
        flwr_version = Version(flwr_version_str)
    except InvalidVersion as err:
        raise RuntimeError(
            f"unable to parse installed flwr version '{flwr_version_str}'. {requirement_hint}"
        ) from err

    if flwr_version < SUPPORTED_FLWR_MIN_VERSION:
        raise RuntimeError(f"incompatible flwr version '{flwr_version_str}'. {requirement_hint}")
def _create_flower_job(**kwargs):
    """Construct a ``FlowerJob`` by forwarding all keyword arguments unchanged.

    Returns:
        The newly created ``nvflare.app_opt.flower.flower_job.FlowerJob``.
    """
    # The import is deferred to call time, so importing this module does not
    # import flower_job. Presumably this avoids an import cycle or heavy Flower
    # dependencies at module import time. NOTE(review): confirm the motivation.
    from nvflare.app_opt.flower.flower_job import FlowerJob as _FlowerJob

    flower_job = _FlowerJob(**kwargs)
    return flower_job
class FlowerRecipe(Recipe):
    """Recipe class for Flower federated learning using NVFlare.

    This class provides a high-level interface for configuring Flower
    federated learning jobs. It wraps the FlowerJob and provides
    a recipe-based interface for easier job configuration and execution.
    Enables metric streaming and use of client API by default: the
    environment passed to the Flower client always selects the
    external-process Client API (``ClientAPIType.EX_PROCESS_API``).

    Flower CLI compatibility:
        This recipe requires ``flwr>=1.26``, which is checked when the recipe
        is constructed. The integration uses Flower Configuration under
        ``$FLWR_HOME/config.toml`` and the newer SuperLink-based CLI workflow.

    Example usage:

    ```python
    # BYOC mode: point at a local directory containing the Flower app code.
    recipe = FlowerRecipe(
        name="my_flower_job",
        flower_content="/path/to/flower/content",
        min_clients=2,
    )

    # Pre-deployed mode (no BYOC needed):
    recipe = FlowerRecipe(
        name="my_flower_job",
        flower_app_path="local/custom/my_app",
        min_clients=2,
    )
    ```

    Args:
        flower_content (str, optional): Local directory path containing Flower app code (BYOC mode).
        flower_app_path (str, optional): Relative path to pre-deployed Flower app under workspace's
            local/custom/ directory (pre-deployed mode, no BYOC needed). The server distributes the app
            to clients via Flower's FAB mechanism. NOTE(review): presumably exactly one of
            ``flower_content`` / ``flower_app_path`` should be given; that is not checked here and is
            left to FlowerJob.
        name (str): Name of the job. Defaults to "flower_job".
        min_clients (int, optional): The minimum number of clients for the job. Defaults to 1.
        mandatory_clients (List[str], optional): List of mandatory clients for the job. Defaults to None.
        database (str, optional): Database string. Defaults to "".
        superlink_ready_timeout (float, optional): Timeout for the superlink to be ready. Defaults to 10.0 seconds.
        configure_task_timeout (float, optional): Timeout for configuring the task. Defaults to Constant.CONFIG_TASK_TIMEOUT.
        start_task_timeout (float, optional): Timeout for starting the task. Defaults to Constant.START_TASK_TIMEOUT.
        max_client_op_interval (float, optional): Maximum interval between client operations.
            Defaults to Constant.MAX_CLIENT_OP_INTERVAL.
        progress_timeout (float, optional): Timeout for workflow progress. Defaults to Constant.WORKFLOW_PROGRESS_TIMEOUT.
        per_msg_timeout (float, optional): Timeout for receiving individual messages. Defaults to 10.0 seconds.
        tx_timeout (float, optional): Timeout for transmitting data. Defaults to 100.0 seconds.
        client_shutdown_timeout (float, optional): Timeout for client shutdown. Defaults to 5.0 seconds.
        extra_env (dict, optional): optional extra env variables to be passed to Flower client. Must be a
            dict if given. It is copied, never mutated. If it sets the Client API type key, the value must
            be the external-process API.
        run_config (dict, optional): optional dict for flwr run --run-config arguments. Must be a dict if given.
        allow_runtime_dependency_installation (bool, optional): whether to allow dynamic dependency
            installation (only flwr>=1.29; note this recipe itself only enforces flwr>=1.26).
            Defaults to False.

    Raises:
        RuntimeError: if ``flwr`` is not installed, its version cannot be parsed, or it is older
            than ``flwr>=1.26``.
        ValueError: if ``extra_env`` sets the Client API type key to anything other than the
            external-process Client API.
    """

    def __init__(
        self,
        flower_content: Optional[str] = None,
        flower_app_path: Optional[str] = None,
        name: str = "flower_job",
        min_clients: int = 1,
        mandatory_clients: Optional[list[str]] = None,
        database: str = "",
        superlink_ready_timeout: float = 10.0,
        configure_task_timeout: float = Constant.CONFIG_TASK_TIMEOUT,
        start_task_timeout: float = Constant.START_TASK_TIMEOUT,
        max_client_op_interval: float = Constant.MAX_CLIENT_OP_INTERVAL,
        progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
        per_msg_timeout: float = 10.0,
        tx_timeout: float = 100.0,
        client_shutdown_timeout: float = 5.0,
        extra_env: Optional[dict] = None,
        run_config: Optional[dict] = None,
        allow_runtime_dependency_installation: bool = False,
    ):
        """Initialize the FlowerRecipe.

        Checks the installed ``flwr`` version and the argument types, builds the
        Flower client environment, then creates a FlowerJob and wraps it in the
        Recipe interface.
        """
        # Fail fast, before any job construction, if flwr is missing or too old.
        _validate_flwr_version()

        # check_object_type raises (presumably TypeError) when the value is not a dict.
        if run_config is not None:
            check_object_type("run_config", run_config, dict)
        if extra_env is not None:
            check_object_type("extra_env", extra_env, dict)

        # needs to init client api to stream metrics
        # only external client api works with the current flower integration
        # Copy so the caller's dict is never mutated by the assignment below.
        env = extra_env.copy() if extra_env is not None else {}
        if CLIENT_API_TYPE_KEY in env and env[CLIENT_API_TYPE_KEY] != ClientAPIType.EX_PROCESS_API.value:
            raise ValueError(
                f"'extra_env[{CLIENT_API_TYPE_KEY}]' must be "
                f"{ClientAPIType.EX_PROCESS_API.value!r} for the Flower integration; "
                f"got {env[CLIENT_API_TYPE_KEY]!r}."
            )
        env[CLIENT_API_TYPE_KEY] = ClientAPIType.EX_PROCESS_API.value

        job = _create_flower_job(
            name=name,
            flower_content=flower_content,
            flower_app_path=flower_app_path,
            min_clients=min_clients,
            mandatory_clients=mandatory_clients,
            database=database,
            superlink_ready_timeout=superlink_ready_timeout,
            configure_task_timeout=configure_task_timeout,
            start_task_timeout=start_task_timeout,
            max_client_op_interval=max_client_op_interval,
            progress_timeout=progress_timeout,
            per_msg_timeout=per_msg_timeout,
            tx_timeout=tx_timeout,
            client_shutdown_timeout=client_shutdown_timeout,
            extra_env=env,
            run_config=run_config,
            allow_runtime_dependency_installation=allow_runtime_dependency_installation,
        )
        super().__init__(job)