nvflare.app_opt.pt.job_config.model module
- class PTModel(model: Module | Dict[str, Any], persistor: ModelPersistor | None = None, locator: ModelLocator | None = None, allow_numpy_conversion: bool = True, initial_ckpt: str | None = None, best_model_filename: str | None = None)
Bases: object
PyTorch model wrapper.
Supports two input modes:
1. nn.Module instance (existing behavior)
2. Dict config {"path": "module.Class", "args": {…}} (new)
Note: PyTorch requires the model to supply the architecture, because .pt/.pth files only contain the state_dict (weights), not the model architecture.
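For the dict config mode, one plausible way to turn {"path": ..., "args": ...} into an instance is to import the module, look up the class by name, and call it with the args as keyword arguments. The sketch below assumes that approach. instantiate_from_config is a hypothetical helper, not NVFlare's implementation, and a standard-library class stands in for an nn.Module subclass so the example runs without PyTorch.

```python
import importlib
from typing import Any, Dict


def instantiate_from_config(config: Dict[str, Any]) -> Any:
    """Build an object from a {"path": "pkg.module.Class", "args": {...}} dict.

    Hypothetical helper for illustration only; not part of NVFlare.
    """
    path = config["path"]
    # Split "pkg.module.Class" into the module path and the class name.
    module_name, _, class_name = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a fully qualified class path, got {path!r}")
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    # Assumption: a missing "args" entry means "no constructor arguments".
    return cls(**config.get("args", {}))


# A standard-library class stands in for a user-defined nn.Module subclass.
frac = instantiate_from_config(
    {"path": "fractions.Fraction", "args": {"numerator": 3, "denominator": 4}}
)
print(frac)  # 3/4
```

Resolving the class lazily from a string keeps the config serializable, which is why a path plus keyword arguments is a convenient shape for job configs.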
- Parameters:
model – Model input (required). Can be:
  - nn.Module: model instance (existing behavior)
  - dict: {"path": "fully.qualified.Class", "args": {…}}
persistor (optional, ModelPersistor) – Custom persistor. If None, creates default.
locator (optional, ModelLocator) – Custom locator. If None, creates default.
allow_numpy_conversion (bool) – If True, enables conversion between PyTorch tensors and NumPy arrays. Defaults to True.
initial_ckpt (str, optional) – Absolute path to a checkpoint file, used to load pre-trained weights. The file may not exist locally, since it can be a server-side path.
best_model_filename (str, optional) – Filename for saving the best global model.
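The initial_ckpt contract (absolute path, but not necessarily present on the local machine) can be expressed as a check that validates the path's form without testing for the file's existence. This is a minimal sketch under that reading of the parameter description; validate_initial_ckpt is a hypothetical helper, not an NVFlare function.

```python
import os
from typing import Optional


def validate_initial_ckpt(initial_ckpt: Optional[str]) -> Optional[str]:
    """Hypothetical check mirroring the documented initial_ckpt contract.

    The path must be absolute, but it is NOT required to exist locally,
    because it may refer to a file that only exists on the server.
    """
    if initial_ckpt is None:
        return None
    if not os.path.isabs(initial_ckpt):
        raise ValueError(f"initial_ckpt must be an absolute path, got {initial_ckpt!r}")
    # Deliberately no os.path.exists() check: the file may live only on the server.
    return initial_ckpt


print(validate_initial_ckpt("/srv/models/pretrained.pt"))
```

Skipping the existence check is the key design point: a job can be configured on a machine that never sees the checkpoint, as long as the path is valid where the job runs.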