nvflare.app_opt.lightning.api module

class FLCallback(rank: int = 0, load_state_dict_strict: bool = True, update_fit_loop: bool = True)

Bases: Callback

Federated learning (FL) callback for the NVFlare Lightning API.

Parameters:
  • rank – global rank of the PyTorch Lightning trainer. Defaults to 0.

  • load_state_dict_strict – exposes the strict argument of torch.nn.Module.load_state_dict(), which is used to load the received model. Defaults to True. See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict for details.

  • update_fit_loop – whether to increase trainer.fit_loop.max_epochs and trainer.fit_loop.epoch_loop.max_steps in each FL round. Defaults to True, which is suitable for most PyTorch Lightning applications.
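The effect of load_state_dict_strict can be illustrated without PyTorch. The sketch below is a simplified, hypothetical mimic (load_state_dict_sketch is not an NVFlare or PyTorch function) of the key-matching rule in torch.nn.Module.load_state_dict(): with strict=True a missing or unexpected key is an error, while with strict=False the mismatched keys are reported and only the shared keys are loaded. It ignores tensor shapes and the other details of the real method.

```python
from typing import Dict, List, Tuple


def load_state_dict_sketch(
    target: Dict[str, float], received: Dict[str, float], strict: bool = True
) -> Tuple[List[str], List[str]]:
    """Hypothetical mimic of the key-matching rule of torch.nn.Module.load_state_dict()."""
    missing = sorted(k for k in target if k not in received)
    unexpected = sorted(k for k in received if k not in target)
    if strict and (missing or unexpected):
        # strict=True: treat any key mismatch as an error (this sketch raises before copying)
        raise RuntimeError(f"missing keys: {missing}, unexpected keys: {unexpected}")
    # Copy over only the keys present in both dicts
    for key in target.keys() & received.keys():
        target[key] = received[key]
    return missing, unexpected


local = {"fc.weight": 0.0, "fc.bias": 0.0, "head.weight": 0.0}
global_model = {"fc.weight": 1.5, "fc.bias": -0.5}  # the received model lacks "head.weight"

missing, unexpected = load_state_dict_sketch(local, global_model, strict=False)
print(missing, unexpected)  # ['head.weight'] []
print(local)  # {'fc.weight': 1.5, 'fc.bias': -0.5, 'head.weight': 0.0}
```

A non-strict load is what you would want if, for example, the received model deliberately omits some layers that stay local.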

on_train_end(trainer, pl_module)

Called when training ends.

on_train_start(trainer, pl_module)

Called when training begins.

on_validation_end(trainer, pl_module)

Called when the validation loop ends.

on_validation_start(trainer, pl_module)

Called when the validation loop begins.

reset_state(trainer)

Resets the state.

If the next round of federated training reuses the same callback instance, reset_state() must be called first. Besides resetting the current state, it also sets up the state for the next round.
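The reuse rule can be sketched with a toy callback. ToyRoundCallback and its attributes are hypothetical and only model the idea stated here (per-round state must be cleared, and next-round state set up, before the same instance handles another round); they are not the internals of FLCallback.

```python
from typing import Dict, List, Optional


class ToyRoundCallback:
    """Hypothetical stand-in for a per-round FL callback (not NVFlare's FLCallback)."""

    def __init__(self) -> None:
        self.outbox: List[Dict[str, float]] = []  # results "sent" so far, kept across rounds
        self.rounds_started = 0
        self.reset_state()

    def reset_state(self) -> None:
        # Clear per-round state left over from the previous round ...
        self.metrics: Optional[Dict[str, float]] = None
        self.result_sent = False
        # ... and set up state for the next round
        self.rounds_started += 1

    def on_train_end(self, metrics: Dict[str, float]) -> None:
        if self.result_sent:
            raise RuntimeError("result for this round was already sent; call reset_state() first")
        self.metrics = metrics
        self.outbox.append(metrics)
        self.result_sent = True


cb = ToyRoundCallback()
cb.on_train_end({"acc": 0.71})  # round 1
cb.reset_state()                # required before the instance is reused
cb.on_train_end({"acc": 0.78})  # round 2
print(cb.rounds_started, cb.outbox)  # 2 [{'acc': 0.71}, {'acc': 0.78}]
```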

patch(trainer: Trainer, restore_state: bool = True, load_state_dict_strict: bool = True, update_fit_loop: bool = True)

Patches the PyTorch Lightning Trainer for usage with NVFlare.

Parameters:
  • trainer – the PyTorch Lightning trainer.

  • restore_state – whether to restore optimizer and learning rate scheduler states. Defaults to True.

  • load_state_dict_strict – exposes the strict argument of torch.nn.Module.load_state_dict(), which is used to load the received model. Defaults to True. See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict for details.

  • update_fit_loop – whether to increase trainer.fit_loop.max_epochs and trainer.fit_loop.epoch_loop.max_steps in each FL round. Defaults to True, which is suitable for most PyTorch Lightning applications.
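The exact update rule behind update_fit_loop is not spelled out here. The sketch below assumes the limits are raised by one round's worth of epochs or steps in each FL round, so the trainer's cumulative limits keep growing as rounds accumulate, and that a max_steps of -1 (PyTorch Lightning's "no step limit" value) is left unchanged. FitBudget and extend_budget are hypothetical stand-ins for trainer.fit_loop.max_epochs and trainer.fit_loop.epoch_loop.max_steps.

```python
from dataclasses import dataclass


@dataclass
class FitBudget:
    """Hypothetical stand-in for trainer.fit_loop.max_epochs / epoch_loop.max_steps."""
    max_epochs: int
    max_steps: int  # -1 means "no step limit", as in PyTorch Lightning


def extend_budget(budget: FitBudget, epochs_per_round: int, steps_per_round: int) -> None:
    """Raise the cumulative limits by one round's worth (assumed update rule)."""
    budget.max_epochs += epochs_per_round
    if budget.max_steps != -1:  # leave an unlimited step budget untouched
        budget.max_steps += steps_per_round


# Trainer(max_epochs=1) with no step limit; the initial budget covers round 1
budget = FitBudget(max_epochs=1, max_steps=-1)
history = [budget.max_epochs]
for _ in range(2):  # two more FL rounds
    extend_budget(budget, epochs_per_round=1, steps_per_round=100)
    history.append(budget.max_epochs)
print(history, budget.max_steps)  # [1, 2, 3] -1
```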

Example

Normal usage:

from pytorch_lightning import Trainer
import nvflare.client.lightning as flare
trainer = Trainer(max_epochs=1)
flare.patch(trainer)

Advanced usage:

If users want to pass additional information to the FLARE server side via the Lightning API, they need to set that information in an attribute named __fl_meta__ on their LightningModule.

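from pytorch_lightning import LightningModule
from torchmetrics import Accuracy

NUM_CLASSES = 10  # example value: set to the number of classes in your dataset

class LitNet(LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = Net()  # Net is your own torch.nn.Module, defined elsewhere
        self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        # Entries in __fl_meta__ are passed to the FLARE server side
        self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"}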
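How the Lightning API reads __fl_meta__ is not described here. The sketch below assumes the callback looks the attribute up with getattr(), falls back to an empty dict when it is absent, and attaches a copy as metadata to the outgoing model. DummyModule and build_payload are hypothetical; DummyModule stands in for a LightningModule such as LitNet. Because __fl_meta__ ends with two underscores, Python does not apply private-name mangling to it, so getattr(module, "__fl_meta__") finds it.

```python
from typing import Any, Dict


class DummyModule:
    """Hypothetical stand-in for a LightningModule that sets __fl_meta__."""

    def __init__(self) -> None:
        self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"}


def build_payload(module: Any, params: Dict[str, float]) -> Dict[str, Any]:
    """Assumed sender-side logic: attach the module's __fl_meta__ (if any) as metadata."""
    meta = dict(getattr(module, "__fl_meta__", {}))  # copy so the module's dict is not shared
    return {"params": params, "meta": meta}


payload = build_payload(DummyModule(), {"fc.weight": 0.25})
print(payload["meta"])  # {'CUSTOM_VAR': 'VALUE_OF_THE_VAR'}
print(build_payload(object(), {})["meta"])  # {}
```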