nvflare.app_opt.pt.utils module

class ModelParamMatchReport(external_key_count: int, local_key_count: int, external_key_sample: tuple[str, ...], local_key_sample: tuple[str, ...], matched_keys: tuple[str, ...], unexpected_keys: tuple[str, ...], shape_mismatches: tuple[ParamShapeMismatch, ...], prefix_hint: str | None = None)

Bases: object

Summary of how an incoming parameter payload matches a local keyspace.

The report is intentionally descriptive rather than prescriptive: callers can decide whether to warn, fail fast, or filter to matched_keys.
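The three caller policies named above (warn, fail fast, filter to matched_keys) can be sketched as follows. This is a minimal, hypothetical illustration: apply_policy and the policy names are not NVFlare APIs, treating shape mismatches and zero matches as fatal is an assumption borrowed from the feed_vars rules below, and a types.SimpleNamespace stands in for a real ModelParamMatchReport so the example runs without NVFlare installed.

```python
import warnings
from types import SimpleNamespace
from typing import Any, Mapping

POLICIES = ("fail_fast", "warn", "filter")


def apply_policy(report: Any, payload: Mapping[str, Any], policy: str = "warn") -> dict:
    """Hypothetical caller policy over a ModelParamMatchReport-like object.

    fail_fast: raise if the payload contains any unexpected key
    warn:      warn about unexpected keys, then keep only matched keys
    filter:    silently keep only matched keys
    """
    if policy not in POLICIES:
        raise ValueError(f"unknown policy: {policy!r}")
    # This sketch treats shape mismatches as fatal under every policy.
    if report.shape_mismatches:
        raise RuntimeError(f"{len(report.shape_mismatches)} key(s) have mismatched shapes")
    # A non-empty payload with no compatible keys is treated as fatal too.
    if report.external_key_count > 0 and not report.matched_keys:
        raise RuntimeError("payload shares no compatible keys with the local keyspace")
    if report.unexpected_keys:
        if policy == "fail_fast":
            raise RuntimeError(f"unexpected keys: {list(report.unexpected_keys)}")
        if policy == "warn":
            warnings.warn(f"ignoring {len(report.unexpected_keys)} unexpected key(s)")
    # Only matched keys are forwarded; the input payload is left untouched.
    return {key: payload[key] for key in report.matched_keys}


# SimpleNamespace stands in for a real ModelParamMatchReport.
report = SimpleNamespace(
    external_key_count=2,
    matched_keys=("fc.weight",),
    unexpected_keys=("extra.bias",),
    shape_mismatches=(),
)
payload = {"fc.weight": [0.1, 0.2], "extra.bias": [0.0]}
print(apply_policy(report, payload, policy="filter"))  # {'fc.weight': [0.1, 0.2]}
```

Building a new dict instead of deleting keys from the payload keeps the report-driven decision separate from the data, in the same descriptive spirit as the report itself.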

external_key_count

Number of keys present in the incoming payload.

Type:

int

local_key_count

Number of keys present in the local model/checkpoint.

Type:

int

external_key_sample

Up to five keys from the start of the sorted incoming key list. Used only for diagnostics; the sorting is for readability and says nothing about the order of the payload itself.

Type:

tuple[str, …]

local_key_sample

Up to the first five sorted keys from the local model/checkpoint keyspace. This is only for diagnostics.

Type:

tuple[str, …]
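Both sample fields are described as up to the first five keys in sorted order. A minimal sketch of that rule, assuming plain lexicographic string sorting (the helper name sample_keys is hypothetical, not an NVFlare function):

```python
from typing import Iterable


def sample_keys(keys: Iterable[str], limit: int = 5) -> tuple[str, ...]:
    """Return up to `limit` keys from the start of the sorted key list (diagnostics only)."""
    return tuple(sorted(keys)[:limit])


incoming = {"layer2.bias": 0, "layer1.weight": 0, "layer1.bias": 0}
print(sample_keys(incoming))  # ('layer1.bias', 'layer1.weight', 'layer2.bias')
```

Because the sample is taken from a sorted list, two payloads with the same keys yield the same sample regardless of insertion order.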

matched_keys

Incoming keys that exist locally and have compatible shapes. Partial overlap is allowed.

Type:

tuple[str, …]

unexpected_keys

Incoming keys that do not exist locally.

Type:

tuple[str, …]

shape_mismatches

Keys present in both the incoming payload and the local keyspace but whose shapes differ.

Type:

tuple[nvflare.app_opt.pt.utils.ParamShapeMismatch, …]

prefix_hint

Optional hint about common wrapper drift, such as every incoming key carrying a "model." prefix that the local keys lack.

Type:

str | None
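The documentation does not say how prefix_hint is computed. One plausible heuristic, shown here purely as a hypothetical sketch (guess_prefix_hint and its message text are not NVFlare's): report a prefix only if every incoming key shares the same leading dotted component, the local keys do not use it, and stripping it recovers at least one local key.

```python
from typing import Iterable, Optional


def guess_prefix_hint(external_keys: Iterable[str], local_keys: Iterable[str]) -> Optional[str]:
    """Hypothetical wrapper-drift heuristic; not NVFlare's implementation."""
    ext = list(external_keys)
    local = set(local_keys)
    # A shared prefix needs at least one key, and every key must contain a dot.
    if not ext or not all("." in key for key in ext):
        return None
    heads = {key.split(".", 1)[0] for key in ext}
    if len(heads) != 1:
        return None
    prefix = heads.pop() + "."
    # If the local keyspace already uses the prefix, it is not drift.
    if any(key.startswith(prefix) for key in local):
        return None
    # Only report the prefix if stripping it would recover at least one match.
    stripped = {key[len(prefix):] for key in ext}
    if stripped & local:
        return f"all incoming keys start with {prefix!r}; local keys do not"
    return None


print(guess_prefix_hint(["model.fc.weight", "model.fc.bias"], ["fc.weight", "fc.bias"]))
# all incoming keys start with 'model.'; local keys do not
```

Requiring at least one recovered match keeps the hint from firing on payloads that merely happen to share a top-level module name.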

external_key_count: int
external_key_sample: tuple[str, ...]
format_context() str
format_shape_mismatch_error() str
format_unexpected_keys_error() str
format_unexpected_keys_warning() str
format_zero_match_error() str
local_key_count: int
local_key_sample: tuple[str, ...]
matched_keys: tuple[str, ...]
prefix_hint: str | None = None
shape_mismatches: tuple[ParamShapeMismatch, ...]
unexpected_keys: tuple[str, ...]
class ParamShapeMismatch(key: str, expected_shape: tuple, received_shape: tuple)

Bases: object

expected_shape: tuple
key: str
received_shape: tuple
feed_vars(model: Module, model_params)

Copy variable values from model_params into the local PyTorch model's state_dict.

Parameters:
  • model (nn.Module) – the local PyTorch model.

  • model_params – incoming parameter mapping keyed by state-dict name.

Returns:

A list of params and a dictionary mapping variable names to params.

Raises:

RuntimeError – if a matching key has a shape mismatch, or if a non-empty incoming payload has zero compatible matches with the local model.

Notes

Empty payloads are treated as a no-op. Partial payloads are accepted as long as at least one key matches; unknown keys are ignored with a warning rather than applied to the local state dict. This function is intended for loading a received model into a local PyTorch module. Server-side validation of learned client updates is handled separately by PTModelPersistenceFormatManager, which rejects keys outside the server checkpoint schema.
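The loading rules in the Notes can be restated over plain dictionaries. This is a hypothetical sketch, not NVFlare code: apply_received_params and Param are illustrative names, a small frozen dataclass with a .shape stands in for torch tensors so the example runs without PyTorch, and the returned pair only loosely mirrors feed_vars's documented return value.

```python
import warnings
from dataclasses import dataclass
from typing import Any, Mapping


@dataclass(frozen=True)
class Param:
    """Stand-in for a tensor: only the shape matters for these checks."""
    shape: tuple
    value: Any = None


def apply_received_params(local: Mapping[str, Param], incoming: Mapping[str, Param]):
    """Hypothetical restatement of the documented feed_vars rules."""
    updated = dict(local)
    if not incoming:
        # Empty payload: no-op.
        return [], updated
    mismatched = [k for k in incoming if k in local and tuple(incoming[k].shape) != tuple(local[k].shape)]
    if mismatched:
        raise RuntimeError(f"shape mismatch for key(s): {mismatched}")
    assigned = [k for k in incoming if k in local]
    if not assigned:
        # Non-empty payload with zero compatible matches.
        raise RuntimeError("no incoming key matches the local model")
    unknown = sorted(k for k in incoming if k not in local)
    if unknown:
        # Unknown keys are reported, never written into the local mapping.
        warnings.warn(f"ignoring unknown key(s): {unknown}")
    for k in assigned:
        updated[k] = incoming[k]
    return assigned, updated


local = {"fc.weight": Param((2, 3)), "fc.bias": Param((2,))}
assigned, new_state = apply_received_params(local, {"fc.bias": Param((2,), "b"), "extra": Param((1,))})
print(assigned)  # ['fc.bias']
```

Checking shape mismatches before counting matches means a payload whose only overlapping key has the wrong shape fails with the more specific error.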

inspect_model_params(local_var_dict: Mapping[str, object], model_params: Mapping[str, object] | None) ModelParamMatchReport

Compare incoming model parameters against a local model/checkpoint keyspace.

This helper does not mutate either input. It only classifies the incoming keys into matches, unexpected keys, and shape mismatches, and captures small key samples to make diagnostics readable.

Partial payloads are valid: callers can choose to accept a subset of keys as long as there is at least one compatible match. Empty or missing (None) payloads are also valid and return a report with no matched, unexpected, or mismatched keys.

Parameters:
  • local_var_dict – Local model or checkpoint mapping keyed by parameter name.

  • model_params – Incoming parameter mapping to validate.

Returns:

A ModelParamMatchReport describing how the incoming payload relates to the local keyspace.
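The classification described above can be sketched in plain Python. This is a hypothetical re-implementation for illustration only (classify_params and KeyReport are not NVFlare names): it assumes every value exposes a .shape attribute, treats "compatible" as exact shape equality, stores mismatches as (key, expected_shape, received_shape) triples instead of ParamShapeMismatch objects, and does not compute prefix_hint.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Mapping, Optional


@dataclass(frozen=True)
class KeyReport:
    external_key_count: int
    local_key_count: int
    external_key_sample: tuple
    local_key_sample: tuple
    matched_keys: tuple
    unexpected_keys: tuple
    shape_mismatches: tuple  # (key, expected_shape, received_shape) triples


def classify_params(local_var_dict: Mapping[str, object], model_params: Optional[Mapping[str, object]]) -> KeyReport:
    """Classify incoming keys without mutating either mapping (hypothetical sketch)."""
    incoming = model_params or {}
    matched, unexpected, mismatches = [], [], []
    for key in sorted(incoming):
        if key not in local_var_dict:
            unexpected.append(key)
            continue
        expected = tuple(local_var_dict[key].shape)
        received = tuple(incoming[key].shape)
        if expected == received:
            matched.append(key)
        else:
            mismatches.append((key, expected, received))
    return KeyReport(
        external_key_count=len(incoming),
        local_key_count=len(local_var_dict),
        external_key_sample=tuple(sorted(incoming)[:5]),
        local_key_sample=tuple(sorted(local_var_dict)[:5]),
        matched_keys=tuple(matched),
        unexpected_keys=tuple(unexpected),
        shape_mismatches=tuple(mismatches),
    )


# SimpleNamespace objects stand in for tensors; only .shape is read.
local = {"fc.weight": SimpleNamespace(shape=(10, 4)), "fc.bias": SimpleNamespace(shape=(10,))}
incoming = {
    "fc.weight": SimpleNamespace(shape=(10, 4)),
    "fc.bias": SimpleNamespace(shape=(5,)),
    "head.bias": SimpleNamespace(shape=(2,)),
}
report = classify_params(local, incoming)
print(report.matched_keys)      # ('fc.weight',)
print(report.unexpected_keys)   # ('head.bias',)
print(report.shape_mismatches)  # (('fc.bias', (10,), (5,)),)
```

Passing None for the payload yields zero external keys and empty matched, unexpected, and mismatch tuples, which corresponds to the empty-payload case described above.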