nvflare.app_opt.pt.job_config.model module

class PTModel(model: Module | Dict[str, Any], persistor: ModelPersistor | None = None, locator: ModelLocator | None = None, allow_numpy_conversion: bool = True, initial_ckpt: str | None = None, best_model_filename: str | None = None)

Bases: object

PyTorch model wrapper.

Supports two input modes:
  1. nn.Module instance (existing behavior)
  2. Dict config {"path": "module.Class", "args": {…}} (new)
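The dict input mode names a class by its fully qualified path and supplies its constructor arguments. The sketch below illustrates how such a config could be turned into an object with the standard library. It is not NVFlare's actual implementation: instantiate_from_config is a hypothetical helper, and fractions.Fraction is only a stand-in for a model class so that the example runs without PyTorch.

```python
import importlib
from typing import Any, Dict


def instantiate_from_config(config: Dict[str, Any]) -> Any:
    """Build an object from a {"path": ..., "args": ...} dict (hypothetical helper)."""
    # Split "package.module.Class" into the module path and the class name.
    module_name, _, class_name = config["path"].rpartition(".")
    if not module_name:
        raise ValueError(f"'path' must be fully qualified, got {config['path']!r}")
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    # "args" is treated as optional keyword arguments for the constructor.
    return cls(**config.get("args", {}))


# Stand-in for a model class: a standard-library class keeps the sketch
# runnable without PyTorch installed.
config = {"path": "fractions.Fraction", "args": {"numerator": 3, "denominator": 4}}
obj = instantiate_from_config(config)
print(type(obj).__name__, obj)
```

With a real model, "path" would point at an nn.Module subclass and "args" would hold its constructor arguments.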

Note: PyTorch needs the model (an instance, or a class path with arguments) to define the architecture, because the .pt/.pth checkpoint files used here contain only the state_dict (weights), not the model architecture.

Parameters:
  • model – Model input (required). Can be:
      - nn.Module: model instance (existing behavior)
      - dict: {"path": "fully.qualified.Class", "args": {…}}

  • persistor (optional, ModelPersistor) – Custom persistor. If None, a default persistor is created.

  • locator (optional, ModelLocator) – Custom locator. If None, a default locator is created.

  • allow_numpy_conversion (bool) – If True, enables conversion between PyTorch tensors and NumPy arrays. Defaults to True.

  • initial_ckpt (str, optional) – Absolute path to a checkpoint file used to load pre-trained weights. It is treated as a server-side path, so the file need not exist locally.

  • best_model_filename (str, optional) – Filename for saving the best global model.
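As a hedged illustration of the parameters above, the following sketch constructs a PTModel from a dict config. The class path my_project.networks.SimpleNet, its num_classes argument, the checkpoint path, and the filename are all hypothetical placeholders; only the import path and keyword names are taken from this page. The import is placed inside the function so the snippet can be loaded without NVFlare installed.

```python
from typing import Any, Dict

# Hypothetical values: the class path and constructor args are placeholders
# for a model class that would exist in your own project.
MODEL_CONFIG: Dict[str, Any] = {
    "path": "my_project.networks.SimpleNet",
    "args": {"num_classes": 10},
}


def build_pt_model():
    """Create a PTModel from the dict config above (requires nvflare)."""
    # Imported lazily so this snippet can be loaded without NVFlare present.
    from nvflare.app_opt.pt.job_config.model import PTModel

    return PTModel(
        model=MODEL_CONFIG,
        allow_numpy_conversion=True,
        # Placeholder absolute path; per the docs it may exist only on the server.
        initial_ckpt="/workspace/checkpoints/pretrained.pt",
        # Placeholder filename for the best global model.
        best_model_filename="best_global_model.pt",
    )
```

Passing an nn.Module instance as model instead of the dict corresponds to the first input mode.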

add_to_fed_job(job, ctx)

This method is used by the Job API.

Parameters:
  • job – the Job object to add to

  • ctx – Job Context

Returns:

A dictionary of the IDs of the components added.