nvflare.app_opt.tensor_stream.utils module

chunk_tensors_from_params(params: dict[str, Tensor | dict], parent_keys: list[str] | None = None, chunk_size: int | None = 10) → Iterator[tuple[tuple[str], dict[str, Tensor]]]

Generator that yields tensors grouped by their immediate parent dictionary keys.

Parameters:
  • params – Dictionary with string keys and values that are either torch.Tensor or nested dicts.

  • parent_keys – List of keys representing the current path (for internal recursive use; defaults to None, meaning an empty path).

  • chunk_size – Optional maximum number of tensors to yield at once per parent.

Yields:

A tuple containing

  • Tuple of parent keys (excluding the tensor key itself).

  • Dictionary mapping tensor key names to torch.Tensor instances.
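The grouping described above can be illustrated with a minimal, dependency-free sketch. It is not the library's implementation: chunk_leaves is a hypothetical stand-in in which every non-dict value plays the role of a tensor, and both the traversal order and the handling of chunk_size=None (treated here as "no limit") are assumptions.

```python
from typing import Any, Dict, Iterator, Optional, Tuple


def chunk_leaves(
    params: Dict[str, Any],
    parent_keys: Tuple[str, ...] = (),
    chunk_size: Optional[int] = 10,
) -> Iterator[Tuple[Tuple[str, ...], Dict[str, Any]]]:
    """Hypothetical stand-in: every non-dict value plays the role of a tensor."""
    # Collect the leaves that sit directly under the current parent.
    leaves = [(k, v) for k, v in params.items() if not isinstance(v, dict)]
    # Assumption: a falsy chunk_size means "no limit" (one chunk per parent).
    step = chunk_size or len(leaves) or 1
    for start in range(0, len(leaves), step):
        yield parent_keys, dict(leaves[start:start + step])
    # Recurse into nested dicts, extending the parent path by one key.
    for key, value in params.items():
        if isinstance(value, dict):
            yield from chunk_leaves(value, parent_keys + (key,), chunk_size)


params = {"encoder": {"w": 1.0, "b": 2.0, "scale": 3.0}, "head": {"w": 4.0}}
chunks = list(chunk_leaves(params, chunk_size=2))
for parents, group in chunks:
    print(parents, group)
```

With chunk_size=2, the three leaves under "encoder" are split into two chunks that share the same parent path, while "head" yields a single chunk.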

clean_task_data(fl_ctx: FLContext)

Clean the task data in the FLContext.

Parameters:

fl_ctx (FLContext) – The FLContext to clean the task data from.

clean_task_result(fl_ctx: FLContext)

Clean the task result in the FLContext.

Parameters:

fl_ctx (FLContext) – The FLContext to clean the task result from.

copy_non_tensor_params(params: dict[str, dict]) → dict[str, dict]

Recursively copy non-tensor parameters in the given dictionary.

Parameters:

params – The dictionary of parameters to copy from.

Returns:

A new dictionary containing only non-tensor parameters.
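As a rough illustration of this filtering, the sketch below walks a nested dictionary and keeps only non-tensor leaves. FakeTensor and copy_non_tensors are hypothetical names used so the example runs without PyTorch; keeping nested dictionaries that end up empty is an assumption, not documented behavior.

```python
from typing import Any, Callable, Dict


class FakeTensor:
    """Hypothetical placeholder for torch.Tensor so the sketch runs without PyTorch."""


def copy_non_tensors(
    params: Dict[str, Any],
    is_tensor: Callable[[Any], bool] = lambda v: isinstance(v, FakeTensor),
) -> Dict[str, Any]:
    """Return a new nested dict that omits every value recognized as a tensor."""
    out: Dict[str, Any] = {}
    for key, value in params.items():
        if isinstance(value, dict):
            # Rebuild nested dicts so the result does not alias the input structure.
            out[key] = copy_non_tensors(value, is_tensor)
        elif not is_tensor(value):
            out[key] = value
    return out


params = {"model": {"w": FakeTensor(), "lr": 0.1}, "meta": {"round": 3}}
non_tensors = copy_non_tensors(params)
print(non_tensors)
```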

get_dxo_from_ctx(fl_ctx: FLContext, ctx_prop_key: str, tasks: list[str]) → DXO

Extract model parameters from the FLContext based on the provided property key.

Parameters:
  • fl_ctx (FLContext) – The FLContext containing the data.

  • ctx_prop_key (str) – The context property key to extract data from.

  • tasks (list[str]) – The list of tasks to consider.

Returns:

The DXO extracted from the FLContext.

Return type:

DXO

get_targets_from_ctx_and_prop_key(fl_ctx: FLContext, ctx_prop_key: str) → list[str]

Get the peer identity name from the FLContext.

Parameters:
  • fl_ctx (FLContext) – The FLContext for the current operation.

  • ctx_prop_key (str) – The context property key.

Returns:

The identity name(s) of the peer(s).

Return type:

list[str]

get_topic_for_ctx_prop_key(ctx_prop_key: str) → str

Get the topic based on the context property key.

Parameters:

ctx_prop_key (str) – The context property key.

Returns:

The topic associated with the context property key.

Return type:

str

merge_params_dicts(base_params: dict[str, dict], new_params: dict[str, dict], to_ndarray: bool = False) → dict[str, dict]

Merges two nested dictionaries of parameters.

Parameters:
  • base_params – The base dictionary to merge into.

  • new_params – The new dictionary whose values will overwrite those in base_params.

  • to_ndarray – If True, converts torch tensors to numpy arrays during merge.

Returns:

The merged dictionary with values from new_params overwriting those in base_params.
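The overwrite semantics can be sketched as a recursive merge. This is an illustration only: merge_nested is a hypothetical name, plain numbers stand in for tensors, the to_ndarray conversion is left out, and updating base in place (rather than copying it first) is an assumption.

```python
from typing import Any, Dict


def merge_nested(base: Dict[str, Any], new: Dict[str, Any]) -> Dict[str, Any]:
    """Merge new into base; on conflicts, values from new win."""
    for key, value in new.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            # Both sides are dicts: descend instead of replacing the whole subtree.
            merge_nested(base[key], value)
        else:
            base[key] = value
    return base


base = {"model": {"w": 1, "b": 2}, "meta": {"round": 1}}
new = {"model": {"w": 10}, "meta": {"round": 2, "site": "a"}}
merged = merge_nested(base, new)
print(merged)
```

Keys that appear only in base (here "b") survive the merge, which is what distinguishes a nested merge from a top-level dict.update.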

to_numpy_recursive(obj: Tensor | dict[str, Tensor]) → dict[str, ndarray] | ndarray

Recursively convert torch tensors to numpy arrays with minimal memory duplication.

Note: For CPU tensors, .numpy() returns a view sharing memory with the original tensor (zero-copy). For GPU tensors, data must be moved to CPU first, which creates a copy. Only the dictionary structure is duplicated, not the underlying tensor data (for CPU tensors).

Parameters:

obj – A torch.Tensor or dict containing torch.Tensors (possibly nested).

Returns:

A numpy array or dict containing numpy arrays. Tensor data is shared where possible (CPU tensors).

update_params_with_tensors(params: dict, parents: list[str], tensors: dict[str, Tensor], to_ndarray: bool = False) → None

Updates the nested dictionary params at the location specified by parents with the provided tensor values from tensors.

If to_ndarray is True, tensors are converted to numpy ndarrays before insertion.

Parameters:
  • params – The dictionary to update (possibly nested).

  • parents – List of keys that specify the nested path within params.

  • tensors – Dictionary mapping keys to torch.Tensor instances.

  • to_ndarray – Whether to convert tensors to numpy arrays before updating.
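The path-based update can be pictured with the following sketch. update_at_path is a hypothetical stand-in that uses plain values instead of tensors and skips the to_ndarray conversion; creating missing intermediate dictionaries along the path is an assumption rather than documented behavior.

```python
from typing import Any, Dict, List


def update_at_path(params: Dict[str, Any], parents: List[str], values: Dict[str, Any]) -> None:
    """Insert values into the nested dict reached by following parents."""
    node = params
    for key in parents:
        # Assumption: missing levels are created on the way down.
        node = node.setdefault(key, {})
    node.update(values)


params = {"model": {"encoder": {"w": 0}}}
update_at_path(params, ["model", "encoder"], {"w": 1, "b": 2})
update_at_path(params, ["model", "head"], {"w": 3})
print(params)
```

The function returns None and mutates params, mirroring the signature above.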