# Source code for nvflare.app_opt.tensor_stream.utils

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterator, Optional, Union

import numpy as np
import torch

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def import SERVER_SITE_NAME
from nvflare.apis.shareable import Shareable

from .types import TensorTopics


def clean_task_data(fl_ctx: FLContext):
    """Strip tensor payloads from the task data held in the FLContext.

    Tensors are delivered over the tensor stream, so only the non-tensor
    entries of the DXO data are kept. The Shareable is modified in place and
    then stored back under ``FLContextKey.TASK_DATA`` as a private,
    non-sticky property.

    Args:
        fl_ctx (FLContext): The FLContext whose task data is cleaned.
    """
    shareable: Shareable = fl_ctx.get_prop(FLContextKey.TASK_DATA)
    dxo_section = shareable["DXO"]
    # tensors travel separately; keep everything else
    dxo_section["data"] = copy_non_tensor_params(dxo_section["data"])
    fl_ctx.set_prop(FLContextKey.TASK_DATA, value=shareable, private=True, sticky=False)
def clean_task_result(fl_ctx: FLContext):
    """Strip tensor payloads from the task result held in the FLContext.

    Tensors are delivered over the tensor stream, so only the non-tensor
    entries of the DXO data are kept. The Shareable is modified in place and
    then stored back under ``FLContextKey.TASK_RESULT`` as a private,
    non-sticky property.

    Args:
        fl_ctx (FLContext): The FLContext whose task result is cleaned.
    """
    shareable: Shareable = fl_ctx.get_prop(FLContextKey.TASK_RESULT)
    dxo_section = shareable["DXO"]
    # tensors travel separately; keep everything else
    dxo_section["data"] = copy_non_tensor_params(dxo_section["data"])
    fl_ctx.set_prop(FLContextKey.TASK_RESULT, value=shareable, private=True, sticky=False)
def get_topic_for_ctx_prop_key(ctx_prop_key: str) -> str:
    """Map a context property key to the tensor-stream topic used for it.

    Args:
        ctx_prop_key (str): Either ``FLContextKey.TASK_DATA`` or
            ``FLContextKey.TASK_RESULT``.

    Returns:
        str: The matching ``TensorTopics`` value.

    Raises:
        ValueError: If ``ctx_prop_key`` is not one of the supported keys.
    """
    known_pairs = (
        (FLContextKey.TASK_DATA, TensorTopics.TASK_DATA),
        (FLContextKey.TASK_RESULT, TensorTopics.TASK_RESULT),
    )
    for prop_key, topic in known_pairs:
        if ctx_prop_key == prop_key:
            return topic
    raise ValueError(f"Unsupported context property key: {ctx_prop_key}")
def get_targets_from_ctx_and_prop_key(fl_ctx: FLContext, ctx_prop_key: str) -> list[str]:
    """Get the site name(s) that tensors for the given context key are sent to.

    Args:
        fl_ctx (FLContext): The FLContext for the current operation.
        ctx_prop_key (str): Either ``FLContextKey.TASK_DATA`` or
            ``FLContextKey.TASK_RESULT``.

    Returns:
        list[str]: For ``TASK_DATA``, a one-element list with the identity
        name from the peer context. For ``TASK_RESULT``,
        ``[SERVER_SITE_NAME]``.

    Raises:
        ValueError: If ``ctx_prop_key`` is unsupported. Also raised if the
            key is ``TASK_DATA`` but ``fl_ctx`` carries no peer context.
    """
    if ctx_prop_key == FLContextKey.TASK_DATA:
        peer_ctx = fl_ctx.get_peer_context()
        if peer_ctx is None:
            # fail with a clear message instead of an AttributeError on None
            raise ValueError("No peer context found in FLContext; cannot determine target for task data.")
        return [peer_ctx.get_identity_name()]
    elif ctx_prop_key == FLContextKey.TASK_RESULT:
        return [SERVER_SITE_NAME]
    else:
        raise ValueError(f"Unsupported context property key: {ctx_prop_key}")
def to_numpy_recursive(
    obj: Union[torch.Tensor, np.ndarray, dict[str, torch.Tensor]],
) -> Union[dict[str, np.ndarray], np.ndarray]:
    """Recursively convert torch tensors to numpy arrays with minimal memory duplication.

    Note:
        For CPU tensors, the returned array shares memory with the original
        tensor; ``detach()``, ``cpu()`` and ``numpy()`` are all zero-copy there.
        Tensors on any other device (CUDA, MPS, ...) must be moved to CPU
        first, which creates a copy. Tensors that require grad are detached
        first, because ``Tensor.numpy()`` refuses to convert them. numpy
        arrays are returned unchanged. Only the dictionary structure is
        duplicated, not the underlying CPU tensor data.

    Args:
        obj: A torch.Tensor, a numpy array, or a dict (possibly nested)
            containing them.

    Returns:
        A numpy array, or a dict with the same structure containing numpy
        arrays. Tensor data is shared where possible (CPU tensors).

    Raises:
        ValueError: If ``obj`` or any nested value is not a tensor, ndarray or dict.
    """
    if isinstance(obj, torch.Tensor):
        # detach(): zero-copy, and required for tensors with requires_grad=True.
        # cpu(): returns self for CPU tensors, copies from any other device.
        return obj.detach().cpu().numpy()
    if isinstance(obj, np.ndarray):
        return obj
    if isinstance(obj, dict):
        # new dict structure; leaves share memory with the originals where possible
        return {k: to_numpy_recursive(v) for k, v in obj.items()}
    raise ValueError(f"Unsupported object type: {type(obj)}")
def get_dxo_from_ctx(fl_ctx: FLContext, ctx_prop_key: str, tasks: list[str]) -> DXO:
    """Extract the DXO for the current task from the FLContext.

    Args:
        fl_ctx (FLContext): The FLContext containing the data.
        ctx_prop_key (str): The context property key holding the Shareable,
            e.g. ``FLContextKey.TASK_DATA`` or ``FLContextKey.TASK_RESULT``.
        tasks (list[str]): Names of the tasks for which extraction is allowed.

    Returns:
        DXO: The DXO parsed from the Shareable. Its data kind is
        ``DataKind.WEIGHTS`` or ``DataKind.WEIGHT_DIFF``.

    Raises:
        ValueError: If any of the following holds:
            - the FLContext has no task name;
            - the task name is not in ``tasks``;
            - no Shareable is stored under ``ctx_prop_key``;
            - the DXO data kind is neither WEIGHTS nor WEIGHT_DIFF.
    """
    task_name = fl_ctx.get_prop(FLContextKey.TASK_NAME)
    if not task_name:
        raise ValueError("No task name found in FLContext.")
    if task_name not in tasks:
        raise ValueError(f"Task name '{task_name}' not part of configured tasks: {tasks}")

    task: Shareable = fl_ctx.get_prop(ctx_prop_key)
    if task is None:
        raise ValueError(f"No task found in FLContext. Looked for for shareable in '{ctx_prop_key}'.")

    dxo = from_shareable(task)
    # only model weights / weight diffs are handled by tensor streaming
    if dxo.data_kind not in (DataKind.WEIGHTS, DataKind.WEIGHT_DIFF):
        raise ValueError(f"Skipping task, data kind is not WEIGHTS or WEIGHT_DIFF: {dxo.data_kind}")
    return dxo
def chunk_tensors_from_params(
    params: dict[str, Union[torch.Tensor, dict]],
    parent_keys: Optional[list[str]] = None,
    chunk_size: Optional[int] = 10,
) -> Iterator[tuple[tuple[str, ...], dict[str, torch.Tensor]]]:
    """
    Generator that yields tensors grouped by their immediate parent dictionary keys.

    Nested dicts are traversed depth-first. Their groups are yielded before
    the tensors that sit directly in ``params``. numpy arrays are wrapped with
    ``torch.from_numpy`` (shares memory). Values that are neither tensors,
    arrays nor dicts are skipped.

    Args:
        params: Dictionary with string keys and values that are torch.Tensor,
            np.ndarray, nested dicts, or other (ignored) values.
        parent_keys: List of keys representing the current path (internal use,
            defaults to empty).
        chunk_size: Maximum number of tensors yielded at once per parent, or
            None to yield all tensors of a parent together.

    Yields:
        A tuple containing:
            - Tuple of parent keys (excluding the tensor key itself); empty at
              the top level.
            - Dictionary mapping tensor key names to torch.Tensor instances.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer or None. As
            this is a generator, the error surfaces on the first iteration.
    """
    if chunk_size is not None and chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer or None")

    path = list(parent_keys) if parent_keys else []

    tensors = {}
    for key, value in params.items():
        if isinstance(value, torch.Tensor):
            tensors[key] = value
        elif isinstance(value, np.ndarray):
            tensors[key] = torch.from_numpy(value)
        elif isinstance(value, dict):
            yield from chunk_tensors_from_params(value, path + [key], chunk_size)

    if not tensors:
        return

    prefix = tuple(path)
    if chunk_size is None or chunk_size >= len(tensors):
        yield prefix, tensors
        return

    items = list(tensors.items())
    for start in range(0, len(items), chunk_size):
        yield prefix, dict(items[start : start + chunk_size])
def update_params_with_tensors(
    params: dict, parents: list[str], tensors: dict[str, torch.Tensor], to_ndarray: bool = False
) -> None:
    """
    Updates the nested dictionary `params` at the location specified by `parents`
    with the provided tensor values from `tensors`.

    Missing intermediate dictionaries along `parents` are created. If `to_ndarray`
    is True, tensors are converted to numpy ndarrays before insertion. For CPU
    tensors this conversion is zero-copy; tensors on other devices are copied to
    CPU. Tensors that require grad are detached first.

    Args:
        params: The dictionary to update (possibly nested). Modified in place.
        parents: List of keys that specify the nested path within `params`.
        tensors: Dictionary mapping keys to torch.Tensor instances.
        to_ndarray: Whether to convert tensors to numpy arrays before updating.

    Raises:
        ValueError: If a key along `parents` exists but does not hold a dict.
    """
    cur = params
    for key in parents:
        if key not in cur:
            cur[key] = {}
        elif not isinstance(cur[key], dict):
            raise ValueError(f"Expected dict at key '{key}', but found {type(cur[key])}")
        cur = cur[key]

    for k, tensor in tensors.items():
        # detach(): numpy() refuses tensors with requires_grad=True (zero-copy).
        # cpu(): no-op for CPU tensors, copies from CUDA or any other device.
        cur[k] = tensor.detach().cpu().numpy() if to_ndarray else tensor
def merge_params_dicts(
    base_params: dict[str, dict],
    new_params: dict[str, dict],
    to_ndarray: bool = False,
) -> dict[str, dict]:
    """
    Merges two nested dictionaries of parameters.

    Dict values are merged recursively. Any other value in `new_params`
    overwrites the entry in `base_params`. The dictionary structure taken from
    `new_params` is always copied into `base_params`, never aliased, so later
    merges into `base_params` cannot mutate `new_params`. Leaf values such as
    tensors are still shared.

    Args:
        base_params: The base dictionary to merge into (modified in place).
        new_params: The new dictionary whose values will overwrite those in base_params.
        to_ndarray: If True, converts torch tensors to numpy arrays during merge.
            The conversion is zero-copy for CPU tensors and copies from other
            devices. Tensors that require grad are detached first.

    Returns:
        The merged dictionary (``base_params`` itself) with values from
        new_params overwriting those in base_params.
    """
    for key, value in new_params.items():
        if isinstance(value, dict):
            target = base_params.get(key)
            if not isinstance(target, dict):
                # New (or previously non-dict) entry: start from a fresh dict
                # instead of aliasing value, then fill it recursively.
                target = base_params[key] = {}
            merge_params_dicts(target, value, to_ndarray)
        elif to_ndarray and isinstance(value, torch.Tensor):
            # detach(): needed for requires_grad tensors; cpu(): no-op on CPU
            base_params[key] = value.detach().cpu().numpy()
        else:
            base_params[key] = value
    return base_params
def copy_non_tensor_params(params: dict[str, dict]) -> dict[str, dict]:
    """Build a copy of ``params`` with every tensor and ndarray removed.

    Nested dictionaries are filtered recursively. A nested dictionary that
    ends up empty, because it only held tensors/arrays or was empty to begin
    with, is dropped entirely.

    Args:
        params: The (possibly nested) parameter dictionary to filter.

    Returns:
        A new dictionary containing only the non-tensor parameters.
    """
    kept = {}
    for name, item in params.items():
        if isinstance(item, (torch.Tensor, np.ndarray)):
            continue
        if not isinstance(item, dict):
            kept[name] = item
            continue
        filtered_child = copy_non_tensor_params(item)
        if filtered_child:
            kept[name] = filtered_child
    return kept