nvflare.app_opt.lightning.api module
- class FLCallback(rank: int = 0, load_state_dict_strict: bool = True)
Bases: Callback
FL callback for the Lightning API.
- Parameters:
rank – global rank of the PyTorch Lightning trainer.
load_state_dict_strict – exposes the strict argument of torch.nn.Module.load_state_dict(), which is used to load the received model. Defaults to True. See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict for details.
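To make the effect of load_state_dict_strict concrete, here is a stdlib-only sketch of the key-matching check that a strict flag controls when a received model is loaded. The helper load_params and the LoadResult tuple are hypothetical and are not part of NVFlare or PyTorch. Parameters are modelled as a plain dict of floats, and real tensors, shape checks, and nested modules are ignored.

```python
from typing import Dict, List, NamedTuple


class LoadResult(NamedTuple):
    """Hypothetical result type, shaped like the (missing_keys, unexpected_keys)
    pair that torch.nn.Module.load_state_dict reports."""
    missing_keys: List[str]
    unexpected_keys: List[str]


def load_params(local: Dict[str, float], received: Dict[str, float], strict: bool = True) -> LoadResult:
    """Hypothetical stand-in for loading a received global model into a local one.

    Keys present on both sides are copied from `received` into `local`.
    With strict=True, any key mismatch raises RuntimeError; this sketch checks
    first and raises before copying anything.
    """
    # Keys the local model expects but the received model lacks.
    missing = sorted(k for k in local if k not in received)
    # Keys the received model carries that the local model does not know.
    unexpected = sorted(k for k in received if k not in local)
    if strict and (missing or unexpected):
        raise RuntimeError(
            f"Error(s) in loading state_dict: missing={missing}, unexpected={unexpected}"
        )
    # Non-strict (or perfectly matching) case: copy only the shared keys.
    for key in local.keys() & received.keys():
        local[key] = received[key]
    return LoadResult(missing, unexpected)


if __name__ == "__main__":
    local = {"fc.weight": 0.0, "fc.bias": 0.0}
    received = {"fc.weight": 1.5, "fc.bias": 0.5, "head.weight": 2.0}
    print(load_params(local, received, strict=False))
    # → LoadResult(missing_keys=[], unexpected_keys=['head.weight'])
```

With strict=False the mismatch is reported rather than raised, which is the contract the linked PyTorch documentation describes for missing_keys and unexpected_keys. Keeping the default of True surfaces architecture mismatches between the global and local models early.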
- patch(trainer: Trainer, restore_state: bool = True, load_state_dict_strict: bool = True)
Patches the PyTorch Lightning Trainer for usage with NVFlare.
- Parameters:
trainer – the PyTorch Lightning trainer.
restore_state – whether to restore optimizer and learning rate scheduler states. Defaults to True.
load_state_dict_strict – exposes the strict argument of torch.nn.Module.load_state_dict(), which is used to load the received model. Defaults to True. See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict for details.
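One way to picture what "patching" a trainer means is callback injection: the trainer keeps its existing callbacks, and an FL callback (plus, optionally, a state-restoring one) is appended. The sketch below is stdlib-only and rests on that assumption; ToyTrainer, ToyFLCallback, ToyRestoreState, and toy_patch are hypothetical stand-ins, not the NVFlare or Lightning classes, and the real patch() may differ in detail.

```python
from dataclasses import dataclass, field
from typing import List


class Callback:
    """Hypothetical stand-in for a Lightning Callback base class."""


@dataclass
class ToyTrainer:
    """Hypothetical stand-in for a Lightning Trainer: only the fields this sketch uses."""
    global_rank: int = 0
    callbacks: List[Callback] = field(default_factory=list)


class ToyFLCallback(Callback):
    """Hypothetical FL callback taking the same constructor arguments as FLCallback."""

    def __init__(self, rank: int = 0, load_state_dict_strict: bool = True):
        self.rank = rank
        self.load_state_dict_strict = load_state_dict_strict


class ToyRestoreState(Callback):
    """Hypothetical callback that would restore optimizer and LR-scheduler state."""


def toy_patch(trainer: ToyTrainer, restore_state: bool = True, load_state_dict_strict: bool = True) -> None:
    """Assumed behaviour: append an FL callback bound to the trainer's global rank,
    then a state-restoring callback when restore_state is True.
    Existing user callbacks are left in place."""
    trainer.callbacks.append(
        ToyFLCallback(rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict)
    )
    if restore_state:
        trainer.callbacks.append(ToyRestoreState())


if __name__ == "__main__":
    trainer = ToyTrainer(global_rank=1, callbacks=[Callback()])
    toy_patch(trainer)
    print([type(cb).__name__ for cb in trainer.callbacks])
    # → ['Callback', 'ToyFLCallback', 'ToyRestoreState']
```

In this sketch patching is purely additive, so a user's own callbacks keep running alongside the FL behaviour, and the rank passed to the FL callback comes from the trainer rather than from the caller.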
Example
Normal usage:

    import nvflare.client.lightning as flare
    from pytorch_lightning import Trainer

    trainer = Trainer(max_epochs=1)
    flare.patch(trainer)
Advanced usage:
If users want to pass additional information to the FLARE server side via the Lightning API, they need to store it as a dictionary in an attribute named __fl_meta__ on their LightningModule:

    from pytorch_lightning import LightningModule
    from torchmetrics import Accuracy

    # Net and NUM_CLASSES are user-defined elsewhere.
    class LitNet(LightningModule):
        def __init__(self):
            super().__init__()
            self.save_hyperparameters()
            self.model = Net()
            self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
            self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
            # Extra key/value pairs to pass to the FLARE server side
            self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"}
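One way to picture how __fl_meta__ could be used is a callback that reads the attribute off the LightningModule and merges it into the metadata sent along with the model. The sketch below is stdlib-only and hypothetical: ToyLitNet stands in for the LitNet example, collect_meta is not an NVFlare function, and the current_round key is an invented placeholder for whatever metadata the real callback already sends.

```python
from typing import Any, Dict


class ToyLitNet:
    """Hypothetical stand-in for a LightningModule that sets __fl_meta__."""

    def __init__(self) -> None:
        # Names that end in two underscores are not name-mangled, so the
        # attribute is stored literally as "__fl_meta__".
        self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"}


def collect_meta(module: Any, base_meta: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: merge the module's __fl_meta__ (if present) into a
    copy of the metadata that would accompany the model sent to the server.
    On key collisions the module's values win in this sketch."""
    meta = dict(base_meta)  # do not mutate the caller's dict
    meta.update(getattr(module, "__fl_meta__", None) or {})
    return meta


if __name__ == "__main__":
    print(collect_meta(ToyLitNet(), {"current_round": 2}))
    # → {'current_round': 2, 'CUSTOM_VAR': 'VALUE_OF_THE_VAR'}
```

Using getattr with a default means modules that never define __fl_meta__ keep working unchanged, which matches the opt-in framing of the advanced usage above.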