# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as get_package_version
from typing import Optional
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version
from nvflare.app_common.tie.defs import Constant
from nvflare.client.api import ClientAPIType
from nvflare.client.api_spec import CLIENT_API_TYPE_KEY
from nvflare.recipe.spec import Recipe
# Inclusive lower bound of the supported Flower ("flwr") versions, written as a version specifier.
SUPPORTED_FLWR_MIN_VERSION_SPEC = ">=1.16"
# Exclusive upper bound; _validate_flwr_version() compares against it directly as a Version.
SUPPORTED_FLWR_MAX_VERSION_EXCLUSIVE = Version("1.26")
# Full requirement string that is embedded in the error messages raised by _validate_flwr_version().
SUPPORTED_FLWR_SPEC = "flwr>=1.16,<1.26"
# Specifier set built once from the lower bound and reused for every version check.
SUPPORTED_FLWR_MIN_SPEC_SET = SpecifierSet(SUPPORTED_FLWR_MIN_VERSION_SPEC)
def _validate_flwr_version():
    """Check that the installed ``flwr`` package is within the supported version range.

    The installed distribution version is looked up, parsed, and compared against
    the module-level lower and upper bounds.

    Raises:
        RuntimeError: If ``flwr`` is not installed, if its version string cannot be
            parsed, or if the parsed version is outside the supported range.
    """
    try:
        raw_version = get_package_version("flwr")
    except PackageNotFoundError as missing:
        raise RuntimeError(
            f"Flower package 'flwr' is not installed. FlowerRecipe requires '{SUPPORTED_FLWR_SPEC}'."
        ) from missing

    try:
        flwr_version = Version(raw_version)
    except InvalidVersion as unparsable:
        raise RuntimeError(
            f"unable to parse installed flwr version '{raw_version}'. "
            f"FlowerRecipe requires '{SUPPORTED_FLWR_SPEC}'."
        ) from unparsable

    # The lower bound is checked through the SpecifierSet and the upper bound through a
    # plain Version comparison: 1.16rc0 is rejected, while 1.26.0rc0 sorts below 1.26
    # and is therefore accepted.
    meets_minimum = SUPPORTED_FLWR_MIN_SPEC_SET.contains(flwr_version, prereleases=True)
    if meets_minimum and flwr_version < SUPPORTED_FLWR_MAX_VERSION_EXCLUSIVE:
        return

    raise RuntimeError(
        f"incompatible flwr version '{raw_version}'. FlowerRecipe requires '{SUPPORTED_FLWR_SPEC}'."
    )
def _create_flower_job(**kwargs):
    """Construct a FlowerJob from the given keyword arguments.

    FlowerJob is imported inside the function body rather than at module import
    time (presumably so the Flower job module is only loaded when a job is
    actually built -- confirm with the module owners).

    Args:
        **kwargs: Forwarded unchanged to the FlowerJob constructor.

    Returns:
        The newly created FlowerJob instance.
    """
    from nvflare.app_opt.flower.flower_job import FlowerJob

    job = FlowerJob(**kwargs)
    return job
[docs]
class FlowerRecipe(Recipe):
    """Recipe class for Flower federated learning using NVFlare.

    This class provides a high-level interface for configuring Flower
    federated learning jobs. It wraps the FlowerJob and provides
    a recipe-based interface for easier job configuration and execution.
    Enables metric streaming and use of client API by default.

    Flower CLI compatibility:
        This recipe requires ``flwr>=1.16,<1.26``. The current
        integration relies on legacy federation CLI arguments that are not
        available in newer Flower CLI versions.

    Example usage:
        ```python
        recipe = FlowerRecipe(
            name="my_flower_job",
            flower_content="/path/to/flower/content",
            min_clients=2,
            extra_env={"MY_VAR": "value"},
        )
        ```

    Args:
        flower_content (str): Content for the flower job. Required.
        name (str): Name of the job. Defaults to "flower_job".
        min_clients (int, optional): The minimum number of clients for the job. Defaults to 1.
        mandatory_clients (List[str], optional): List of mandatory clients for the job. Defaults to None.
        database (str, optional): Database string. Defaults to "".
        superlink_ready_timeout (float, optional): Timeout for the superlink to be ready. Defaults to 10.0 seconds.
        configure_task_timeout (float, optional): Timeout for configuring the task.
            Defaults to Constant.CONFIG_TASK_TIMEOUT.
        start_task_timeout (float, optional): Timeout for starting the task.
            Defaults to Constant.START_TASK_TIMEOUT.
        max_client_op_interval (float, optional): Maximum interval between client operations.
            Defaults to Constant.MAX_CLIENT_OP_INTERVAL.
        progress_timeout (float, optional): Timeout for workflow progress.
            Defaults to Constant.WORKFLOW_PROGRESS_TIMEOUT.
        per_msg_timeout (float, optional): Timeout for receiving individual messages. Defaults to 10.0 seconds.
        tx_timeout (float, optional): Timeout for transmitting data. Defaults to 100.0 seconds.
        client_shutdown_timeout (float, optional): Timeout for client shutdown. Defaults to 5.0 seconds.
        extra_env (dict, optional): Extra env variables to be passed to Flower client. They are
            merged with the Client API type entry that this recipe always sets; that entry
            cannot be overridden. The supplied dict is not modified. Defaults to None.

    Raises:
        RuntimeError: If the installed ``flwr`` package is missing, has an unparsable
            version, or is outside the supported version range.
    """

    def __init__(
        self,
        flower_content: str,
        name: str = "flower_job",
        min_clients: int = 1,
        mandatory_clients: Optional[list[str]] = None,
        database: str = "",
        superlink_ready_timeout: float = 10.0,
        configure_task_timeout: float = Constant.CONFIG_TASK_TIMEOUT,
        start_task_timeout: float = Constant.START_TASK_TIMEOUT,
        max_client_op_interval: float = Constant.MAX_CLIENT_OP_INTERVAL,
        progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
        per_msg_timeout: float = 10.0,
        tx_timeout: float = 100.0,
        client_shutdown_timeout: float = 5.0,
        extra_env: Optional[dict] = None,
    ):
        """Initialize the FlowerRecipe.

        Validates the installed flwr version, creates a FlowerJob and wraps it
        in the Recipe interface.
        """
        # Fail early, before any job object is built, when flwr is unusable.
        _validate_flwr_version()

        # Start from a copy of the caller's variables so their dict is never mutated.
        # Previously the caller-supplied extra_env was dropped entirely.
        env = dict(extra_env) if extra_env else {}

        # needs to init client api to stream metrics
        # only external client api works with the current flower integration,
        # so this entry is written last and wins over any caller-supplied value.
        env[CLIENT_API_TYPE_KEY] = ClientAPIType.EX_PROCESS_API.value

        job = _create_flower_job(
            name=name,
            flower_content=flower_content,
            min_clients=min_clients,
            mandatory_clients=mandatory_clients,
            database=database,
            superlink_ready_timeout=superlink_ready_timeout,
            configure_task_timeout=configure_task_timeout,
            start_task_timeout=start_task_timeout,
            max_client_op_interval=max_client_op_interval,
            progress_timeout=progress_timeout,
            per_msg_timeout=per_msg_timeout,
            tx_timeout=tx_timeout,
            client_shutdown_timeout=client_shutdown_timeout,
            extra_env=env,
        )
        super().__init__(job)