nvflare.app_opt.lightning.api module
- class FLCallback(rank: int = 0, load_state_dict_strict: bool = True, update_fit_loop: bool = True)
Bases: Callback
FL callback for lightning API.
- Parameters:
rank – global rank of the PyTorch Lightning trainer.
load_state_dict_strict – exposes strict argument of torch.nn.Module.load_state_dict() used to load the received model. Defaults to True. See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict for details. NVFlare still validates incoming keys and shapes before calling load_state_dict(). With True, unexpected incoming keys are treated as contract drift and fail fast. With False, unexpected keys are logged and ignored, while compatible keys are still loaded.
update_fit_loop – whether to increase trainer.fit_loop.max_epochs and trainer.fit_loop.epoch_loop.max_steps each FL round. Defaults to True which is suitable for most PyTorch Lightning applications.
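The strict and non-strict behaviours described above can be pictured with a small, framework-free sketch. The helper `filter_incoming` and the use of plain dicts of shape tuples are hypothetical and are not NVFlare's implementation; they only mirror the documented contract. Checking shapes in both modes is an assumption drawn from the parameter description, and the handling of keys that are missing from the payload is not modelled here.

```python
import logging
from typing import Dict, Tuple

logger = logging.getLogger(__name__)

Shape = Tuple[int, ...]


def filter_incoming(
    local: Dict[str, Shape], incoming: Dict[str, Shape], strict: bool = True
) -> Dict[str, Shape]:
    """Hypothetical helper: return the incoming entries that are safe to load."""
    unexpected = sorted(k for k in incoming if k not in local)
    if unexpected and strict:
        # Strict mode: unknown keys are treated as contract drift, so fail fast.
        raise KeyError(f"unexpected keys in received model: {unexpected}")
    if unexpected:
        # Non-strict mode: warn and drop the unknown keys.
        logger.warning("ignoring unexpected keys: %s", unexpected)
    accepted = {k: v for k, v in incoming.items() if k in local}
    for key, shape in accepted.items():
        # Assumption: shapes are validated regardless of the strict flag.
        if shape != local[key]:
            raise ValueError(f"shape mismatch for {key}: {shape} != {local[key]}")
    return accepted


local_shapes = {"fc.weight": (10, 4), "fc.bias": (10,)}
received = {"fc.weight": (10, 4), "head.weight": (2, 10)}

# Non-strict: the unknown "head.weight" entry is dropped, "fc.weight" is kept.
partial = filter_incoming(local_shapes, received, strict=False)
```

Calling the same helper with `strict=True` on this payload raises `KeyError`, which corresponds to the fail-fast behaviour described for the default setting.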
- patch(trainer: Trainer, restore_state: bool = True, load_state_dict_strict: bool = True, update_fit_loop: bool = True)
Patches the PyTorch Lightning Trainer for usage with NVFlare.
- Parameters:
trainer – the PyTorch Lightning trainer.
restore_state – whether to restore optimizer and learning rate scheduler states. Defaults to True.
load_state_dict_strict – exposes strict argument of torch.nn.Module.load_state_dict() used to load the received model. Defaults to True. See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict for details. NVFlare still validates incoming keys and shapes before calling load_state_dict(). With True, any incoming key that does not exist in the local Lightning module is rejected before loading. With False, NVFlare warns and filters the payload down to matching keys, which is useful for partial model updates where the client only keeps part of the server keyspace.
update_fit_loop – whether to increase trainer.fit_loop.max_epochs and trainer.fit_loop.epoch_loop.max_steps each FL round. Defaults to True which is suitable for most PyTorch Lightning applications.
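The reason update_fit_loop exists is that a Lightning trainer stops fitting once its epoch counter reaches max_epochs, so a later call to trainer.fit() in the next FL round would have nothing left to do unless the limit is raised. The sketch below simulates that effect with a stand-in `FitLoop` dataclass. `FitLoop`, `extend_limits`, and the exact arithmetic are illustrative assumptions, not NVFlare or Lightning code.

```python
from dataclasses import dataclass


@dataclass
class FitLoop:
    """Stand-in for the trainer's fit loop state (hypothetical)."""

    max_epochs: int
    completed_epochs: int = 0

    def run(self) -> int:
        """Train until max_epochs is reached; return the epochs run in this call."""
        ran = 0
        while self.completed_epochs < self.max_epochs:
            self.completed_epochs += 1
            ran += 1
        return ran


def extend_limits(loop: FitLoop, epochs_per_round: int) -> None:
    """Raise the limit so the next round gets a fresh epoch budget."""
    loop.max_epochs = loop.completed_epochs + epochs_per_round


epochs_per_round = 1

# Without raising the limit, only the first round trains.
fixed = FitLoop(max_epochs=epochs_per_round)
epochs_without_update = [fixed.run() for _ in range(3)]

# With the limit raised before each round, every round trains.
updated = FitLoop(max_epochs=epochs_per_round)
epochs_with_update = []
for _ in range(3):
    extend_limits(updated, epochs_per_round)
    epochs_with_update.append(updated.run())
```

Setting update_fit_loop to False would correspond to the first case, which is only appropriate when the application manages the trainer's limits itself.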
Example
Normal usage:
    import nvflare.client.lightning as flare
    from pytorch_lightning import Trainer

    trainer = Trainer(max_epochs=1)
    flare.patch(trainer)
Advanced usage:
If users want to pass additional information to the FLARE server side via the lightning API, they need to set it in an attribute called __fl_meta__ in their LightningModule.

    from pytorch_lightning import LightningModule
    from torchmetrics import Accuracy

    # Net and NUM_CLASSES are assumed to be defined elsewhere in the application.
    class LitNet(LightningModule):
        def __init__(self):
            super().__init__()
            self.save_hyperparameters()
            self.model = Net()
            self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
            self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
            # Extra metadata to send to the FLARE server side.
            self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"}
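As a rough illustration of how such an attribute can be picked up from a module, the sketch below uses a plain Python class in place of a LightningModule. `DemoModule` and `collect_fl_meta` are hypothetical names and not part of NVFlare. The sketch also shows that `__fl_meta__` is stored under exactly that name: Python only mangles attribute names that start with two underscores and do not end with two.

```python
from typing import Any, Dict


class DemoModule:
    """Plain-Python stand-in for a LightningModule (hypothetical)."""

    def __init__(self) -> None:
        # Names that also end with two underscores are not name-mangled.
        self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"}


def collect_fl_meta(module: Any) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of the module's FL metadata, if any."""
    return dict(getattr(module, "__fl_meta__", {}))


meta = collect_fl_meta(DemoModule())
no_meta = collect_fl_meta(object())  # modules without the attribute yield {}
```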