nvflare.app_opt.pt.utils module
- class ModelParamMatchReport(external_key_count: int, local_key_count: int, external_key_sample: tuple[str, ...], local_key_sample: tuple[str, ...], matched_keys: tuple[str, ...], unexpected_keys: tuple[str, ...], shape_mismatches: tuple[ParamShapeMismatch, ...], prefix_hint: str | None = None)
Bases:
object
Summary of how an incoming parameter payload matches a local keyspace.
The report is intentionally descriptive rather than prescriptive: callers can decide whether to warn, fail fast, or filter to matched_keys.
- external_key_count
Number of keys present in the incoming payload.
- Type:
int
- local_key_count
Number of keys present in the local model/checkpoint.
- Type:
int
- external_key_sample
Up to the first five keys of the incoming payload, in sorted order. This is only for diagnostics; the sorted order does not reflect the order of keys in the payload.
- Type:
tuple[str, …]
- local_key_sample
Up to the first five sorted keys from the local model/checkpoint keyspace. This is only for diagnostics.
- Type:
tuple[str, …]
- matched_keys
Incoming keys that exist locally and have compatible shapes. Partial overlap is allowed.
- Type:
tuple[str, …]
- unexpected_keys
Incoming keys that do not exist locally.
- Type:
tuple[str, …]
- shape_mismatches
Keys that exist in both places but whose shapes differ.
- Type:
tuple[ParamShapeMismatch, …]
- prefix_hint
Optional hint for common wrapper drift, such as a "model." prefix on all incoming keys.
- Type:
str | None
- external_key_count: int
- external_key_sample: tuple[str, ...]
- local_key_count: int
- local_key_sample: tuple[str, ...]
- matched_keys: tuple[str, ...]
- prefix_hint: str | None = None
- shape_mismatches: tuple[ParamShapeMismatch, ...]
- unexpected_keys: tuple[str, ...]
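Because the report only describes the match, the decision to warn, fail fast, or filter to matched_keys stays with the caller. The sketch below shows one possible caller policy: raise when strict, otherwise log a warning and keep only matched_keys. It defines local stand-in dataclasses that mirror the documented fields instead of importing NVFlare, and the select_keys helper and its strict flag are hypothetical names, not part of this module.

```python
import logging
from dataclasses import dataclass
from typing import Optional, Tuple

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ParamShapeMismatch:
    # Local stand-in mirroring the documented fields (not the NVFlare class).
    key: str
    expected_shape: tuple
    received_shape: tuple


@dataclass(frozen=True)
class ModelParamMatchReport:
    # Local stand-in mirroring the documented constructor signature.
    external_key_count: int
    local_key_count: int
    external_key_sample: Tuple[str, ...]
    local_key_sample: Tuple[str, ...]
    matched_keys: Tuple[str, ...]
    unexpected_keys: Tuple[str, ...]
    shape_mismatches: Tuple[ParamShapeMismatch, ...]
    prefix_hint: Optional[str] = None


def select_keys(report: ModelParamMatchReport, strict: bool = False) -> Tuple[str, ...]:
    """Hypothetical policy: fail fast when strict, otherwise warn and filter."""
    problems = []
    if report.shape_mismatches:
        problems.append(f"{len(report.shape_mismatches)} shape mismatch(es)")
    if report.unexpected_keys:
        problems.append(f"{len(report.unexpected_keys)} unexpected key(s)")
    if report.external_key_count and not report.matched_keys:
        problems.append("no compatible keys")
    if problems:
        message = "; ".join(problems)
        if report.prefix_hint:
            message += f" (hint: {report.prefix_hint})"
        if strict:
            raise ValueError(message)
        logger.warning(message)
    # Filtering: only keys that exist locally with compatible shapes survive.
    return report.matched_keys
```

An empty report (external_key_count of zero) passes through without warnings in this sketch, matching the documented treatment of empty payloads as valid.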
- class ParamShapeMismatch(key: 'str', expected_shape: 'tuple', received_shape: 'tuple')
Bases:
object
- expected_shape: tuple
- key: str
- received_shape: tuple
- feed_vars(model: Module, model_params)
Feed variable values from model_params into the PyTorch state_dict.
- Parameters:
model (nn.Module) – the local PyTorch model
model_params – incoming parameter mapping keyed by state-dict name.
- Returns:
a list of the assigned params and a dictionary mapping variable names to params
- Raises:
RuntimeError – if a matching key has a shape mismatch, or if a non-empty incoming payload has zero compatible matches with the local model.
Notes
Empty payloads are treated as a no-op. Partial payloads are accepted as long as at least one key matches; unknown keys are ignored with a warning instead of being applied to the local state dict. This function is intended for loading a received model into a local PyTorch module. Server-side validation of learned client updates is handled by PTModelPersistenceFormatManager, which rejects keys outside the server checkpoint schema.
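The rules in the notes above (empty payload is a no-op, unknown keys are skipped with a warning, a shape mismatch or zero matches raises RuntimeError) can be restated as a small pure-Python sketch. This is not the NVFlare implementation: feed_vars_sketch and shape_of are hypothetical helpers, values are plain nested lists instead of tensors, and no torch conversion is performed.

```python
import logging
from typing import Any, Dict, List, Mapping, Tuple

logger = logging.getLogger(__name__)


def shape_of(value: Any) -> tuple:
    """Shape of a tensor/array (via .shape) or of a nested list; scalars give ()."""
    if hasattr(value, "shape"):
        return tuple(value.shape)
    if isinstance(value, (list, tuple)):
        return (len(value),) + (shape_of(value[0]) if value else ())
    return ()


def feed_vars_sketch(
    local_state: Mapping[str, Any], model_params: Mapping[str, Any]
) -> Tuple[List[Any], Dict[str, Any]]:
    """Hypothetical restatement of the documented feed_vars rules (no torch)."""
    updated = dict(local_state)  # work on a copy; the input mapping is untouched
    if not model_params:
        return [], updated  # empty payload: no-op
    assigned = []
    for key, value in model_params.items():
        if key not in local_state:
            logger.warning("Ignoring unknown key %r", key)  # partial payloads are OK
            continue
        expected, received = shape_of(local_state[key]), shape_of(value)
        if expected != received:
            raise RuntimeError(
                f"Shape mismatch for {key!r}: expected {expected}, got {received}"
            )
        updated[key] = value
        assigned.append(value)
    if not assigned:
        raise RuntimeError("Non-empty payload has no keys matching the local model")
    return assigned, updated
```

In this sketch a single shape mismatch aborts the whole call, even if other keys matched, which is how the Raises entry reads.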
- inspect_model_params(local_var_dict: Mapping[str, object], model_params: Mapping[str, object] | None) → ModelParamMatchReport
Compare incoming model parameters against a local model/checkpoint keyspace.
This helper does not mutate either input. It only classifies the incoming keys into matches, unexpected keys, and shape mismatches, and captures small key samples to make diagnostics readable.
Partial payloads are valid: callers can choose to accept a subset of keys as long as there is at least one compatible match. Empty or missing payloads are also valid and return an all-empty report.
- Parameters:
local_var_dict – Local model or checkpoint mapping keyed by parameter name.
model_params – Incoming parameter mapping to validate.
- Returns:
A ModelParamMatchReport describing how the incoming payload relates to the local keyspace.
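The classification described for inspect_model_params can be sketched in pure Python as follows. This is an illustration under stated assumptions, not the library code: inspect_sketch and shape_of are hypothetical helpers, the result is a plain dict keyed by the documented field names rather than a ModelParamMatchReport, shape mismatches are (key, expected, received) tuples, and the candidate wrapper prefixes and hint wording are assumptions.

```python
from typing import Any, Dict, Mapping, Optional

SAMPLE_SIZE = 5  # the attribute docs describe samples of up to five sorted keys
WRAPPER_PREFIXES = ("model.", "module.")  # hypothetical candidate prefixes


def shape_of(value: Any) -> tuple:
    """Shape of a tensor/array (via .shape) or of a nested list; scalars give ()."""
    if hasattr(value, "shape"):
        return tuple(value.shape)
    if isinstance(value, (list, tuple)):
        return (len(value),) + (shape_of(value[0]) if value else ())
    return ()


def inspect_sketch(
    local_var_dict: Mapping[str, Any], model_params: Optional[Mapping[str, Any]]
) -> Dict[str, Any]:
    """Classify incoming keys without mutating either mapping."""
    incoming = model_params or {}  # None or empty gives an empty classification
    matched, unexpected, mismatches = [], [], []
    for key in sorted(incoming):
        if key not in local_var_dict:
            unexpected.append(key)
            continue
        expected, received = shape_of(local_var_dict[key]), shape_of(incoming[key])
        if expected == received:
            matched.append(key)
        else:
            mismatches.append((key, expected, received))
    # Hint only when every incoming key carries the prefix and stripping it
    # would land on at least one local key.
    prefix_hint = None
    for prefix in WRAPPER_PREFIXES:
        if incoming and all(k.startswith(prefix) for k in incoming) and any(
            k[len(prefix):] in local_var_dict for k in incoming
        ):
            prefix_hint = f"all incoming keys start with {prefix!r}"
            break
    return {
        "external_key_count": len(incoming),
        "local_key_count": len(local_var_dict),
        "external_key_sample": tuple(sorted(incoming)[:SAMPLE_SIZE]),
        "local_key_sample": tuple(sorted(local_var_dict)[:SAMPLE_SIZE]),
        "matched_keys": tuple(matched),
        "unexpected_keys": tuple(unexpected),
        "shape_mismatches": tuple(mismatches),
        "prefix_hint": prefix_hint,
    }
```

Keeping the classification separate from any loading step is what lets callers pick their own policy, as described for ModelParamMatchReport.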