# Source code for nvflare.app_opt.flower.recipe

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as get_package_version
from typing import Optional

from packaging.version import InvalidVersion, Version

from nvflare.app_common.tie.defs import Constant
from nvflare.client.api import ClientAPIType
from nvflare.client.api_spec import CLIENT_API_TYPE_KEY
from nvflare.fuel.utils.validation_utils import check_object_type
from nvflare.recipe.spec import Recipe

# Minimum Flower release supported by FlowerRecipe. The human-readable requirement
# string below is derived from this value so the two can never drift apart.
SUPPORTED_FLWR_MIN_VERSION = Version("1.26")
SUPPORTED_FLWR_SPEC = f"flwr>={SUPPORTED_FLWR_MIN_VERSION}"


def _validate_flwr_version():
    """Ensure a supported Flower (``flwr``) release is installed.

    Looks up the installed ``flwr`` distribution version, parses it with
    ``packaging.version.Version`` and compares it against
    ``SUPPORTED_FLWR_MIN_VERSION``.

    Raises:
        RuntimeError: if ``flwr`` is not installed, if its version string cannot
            be parsed, or if the installed version is older than
            ``SUPPORTED_FLWR_MIN_VERSION``. The underlying lookup/parse error is
            chained as the cause where one exists.
    """
    # Every failure message ends with the same requirement sentence.
    requirement = f"FlowerRecipe requires '{SUPPORTED_FLWR_SPEC}'."

    try:
        flwr_version_str = get_package_version("flwr")
    except PackageNotFoundError as err:
        raise RuntimeError(f"Flower package 'flwr' is not installed. {requirement}") from err

    try:
        flwr_version = Version(flwr_version_str)
    except InvalidVersion as err:
        raise RuntimeError(f"unable to parse installed flwr version '{flwr_version_str}'. {requirement}") from err

    if flwr_version < SUPPORTED_FLWR_MIN_VERSION:
        raise RuntimeError(f"incompatible flwr version '{flwr_version_str}'. {requirement}")


def _create_flower_job(**kwargs):
    """Construct a FlowerJob, forwarding all keyword arguments unchanged.

    FlowerJob is imported inside this function rather than at module level, so
    the module that defines it is only loaded when a job is actually created.

    Returns:
        The newly constructed FlowerJob instance.
    """
    from nvflare.app_opt.flower.flower_job import FlowerJob as _FlowerJob

    flower_job = _FlowerJob(**kwargs)
    return flower_job


class FlowerRecipe(Recipe):
    """Recipe class for Flower federated learning using NVFlare.

    This class provides a high-level interface for configuring Flower federated
    learning jobs. It wraps the FlowerJob and provides a recipe-based interface
    for easier job configuration and execution.

    The client API is always enabled, so that metrics can be streamed. The client
    API type passed to the Flower client environment is always the
    external-process client API, and there is no option to disable this.

    Flower CLI compatibility:
        This recipe requires ``flwr>=1.26``. The integration uses Flower
        Configuration under ``$FLWR_HOME/config.toml`` and the newer
        SuperLink-based CLI workflow.

    Example usage:

    ```python
    # BYOC mode: ship local Flower app code with the job.
    recipe = FlowerRecipe(
        name="my_flower_job",
        flower_content="/path/to/flower/content",
        min_clients=2,
    )

    # Pre-deployed mode (no BYOC needed):
    recipe = FlowerRecipe(
        name="my_flower_job",
        flower_app_path="local/custom/my_app",
        min_clients=2,
    )
    ```

    Args:
        flower_content (str, optional): Local directory path containing Flower
            app code (BYOC mode).
        flower_app_path (str, optional): Relative path to pre-deployed Flower app
            under workspace's local/custom/ directory (pre-deployed mode, no BYOC
            needed). The server distributes the app to clients via Flower's FAB
            mechanism.
        name (str): Name of the job. Defaults to "flower_job".
        min_clients (int, optional): The minimum number of clients for the job.
            Defaults to 1.
        mandatory_clients (List[str], optional): List of mandatory clients for
            the job. Defaults to None.
        database (str, optional): Database string. Defaults to "".
        superlink_ready_timeout (float, optional): Timeout for the superlink to
            be ready. Defaults to 10.0 seconds.
        configure_task_timeout (float, optional): Timeout for configuring the
            task. Defaults to Constant.CONFIG_TASK_TIMEOUT.
        start_task_timeout (float, optional): Timeout for starting the task.
            Defaults to Constant.START_TASK_TIMEOUT.
        max_client_op_interval (float, optional): Maximum interval between client
            operations. Defaults to Constant.MAX_CLIENT_OP_INTERVAL.
        progress_timeout (float, optional): Timeout for workflow progress.
            Defaults to Constant.WORKFLOW_PROGRESS_TIMEOUT.
        per_msg_timeout (float, optional): Timeout for receiving individual
            messages. Defaults to 10.0 seconds.
        tx_timeout (float, optional): Timeout for transmitting data. Defaults to
            100.0 seconds.
        client_shutdown_timeout (float, optional): Timeout for client shutdown.
            Defaults to 5.0 seconds.
        extra_env (dict, optional): Optional extra env variables to be passed to
            the Flower client. The dict is copied, so the caller's object is not
            modified. ``CLIENT_API_TYPE_KEY`` is always set to the
            external-process client API value. Supplying any other value for
            that key raises ValueError.
        run_config (dict, optional): Optional dict for ``flwr run --run-config``
            arguments.
        allow_runtime_dependency_installation (bool, optional): Whether to allow
            dynamic dependency installation (only flwr>=1.29). Defaults to False.

    Raises:
        RuntimeError: If ``flwr`` is not installed, its version cannot be parsed,
            or it is older than the supported minimum version.
        ValueError: If ``extra_env`` sets ``CLIENT_API_TYPE_KEY`` to a value
            other than the external-process client API value.
        The error raised by ``check_object_type``: if ``run_config`` or
            ``extra_env`` is given but is not a dict.
    """

    def __init__(
        self,
        flower_content: Optional[str] = None,
        flower_app_path: Optional[str] = None,
        name: str = "flower_job",
        min_clients: int = 1,
        mandatory_clients: Optional[list[str]] = None,
        database: str = "",
        superlink_ready_timeout: float = 10.0,
        configure_task_timeout: float = Constant.CONFIG_TASK_TIMEOUT,
        start_task_timeout: float = Constant.START_TASK_TIMEOUT,
        max_client_op_interval: float = Constant.MAX_CLIENT_OP_INTERVAL,
        progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
        per_msg_timeout: float = 10.0,
        tx_timeout: float = 100.0,
        client_shutdown_timeout: float = 5.0,
        extra_env: Optional[dict] = None,
        run_config: Optional[dict] = None,
        allow_runtime_dependency_installation: bool = False,
    ):
        """Initialize the FlowerRecipe.

        Validates the installed Flower version and the dict-typed arguments,
        forces the external-process client API in the client environment, then
        creates a FlowerJob and wraps it in the Recipe interface.
        """
        _validate_flwr_version()

        if run_config is not None:
            check_object_type("run_config", run_config, dict)
        if extra_env is not None:
            check_object_type("extra_env", extra_env, dict)

        # The client API must be initialized to stream metrics, and only the
        # external-process client API works with the current Flower integration.
        # Work on a copy so the caller's extra_env is never mutated.
        env = extra_env.copy() if extra_env is not None else {}
        if CLIENT_API_TYPE_KEY in env and env[CLIENT_API_TYPE_KEY] != ClientAPIType.EX_PROCESS_API.value:
            raise ValueError(
                f"'extra_env[{CLIENT_API_TYPE_KEY}]' must be "
                f"{ClientAPIType.EX_PROCESS_API.value!r} for the Flower integration; "
                f"got {env[CLIENT_API_TYPE_KEY]!r}."
            )
        env[CLIENT_API_TYPE_KEY] = ClientAPIType.EX_PROCESS_API.value

        job = _create_flower_job(
            name=name,
            flower_content=flower_content,
            flower_app_path=flower_app_path,
            min_clients=min_clients,
            mandatory_clients=mandatory_clients,
            database=database,
            superlink_ready_timeout=superlink_ready_timeout,
            configure_task_timeout=configure_task_timeout,
            start_task_timeout=start_task_timeout,
            max_client_op_interval=max_client_op_interval,
            progress_timeout=progress_timeout,
            per_msg_timeout=per_msg_timeout,
            tx_timeout=tx_timeout,
            client_shutdown_timeout=client_shutdown_timeout,
            extra_env=env,
            run_config=run_config,
            allow_runtime_dependency_installation=allow_runtime_dependency_installation,
        )
        super().__init__(job)