nvflare.app_opt.lightning.api module

class FLCallback(rank: int = 0, load_state_dict_strict: bool = True, update_fit_loop: bool = True)[source]

Bases: Callback

FL callback for lightning API.

Parameters:
  • rank – global rank of the PyTorch Lightning trainer.

  • load_state_dict_strict – exposes the strict argument of torch.nn.Module.load_state_dict(), which is used to load the received model. Defaults to True. See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict for details. NVFlare still validates incoming keys and shapes before calling load_state_dict(). With True, unexpected incoming keys are treated as contract drift and loading fails fast. With False, unexpected keys are logged and ignored, while compatible keys are still loaded.

  • update_fit_loop – whether to increase trainer.fit_loop.max_epochs and trainer.fit_loop.epoch_loop.max_steps each FL round. Defaults to True, which is suitable for most PyTorch Lightning applications.
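
The key handling described for load_state_dict_strict can be sketched in plain Python. This is only an illustration of the documented behavior, not NVFlare's implementation: filter_incoming_keys is a hypothetical helper, the state dicts are stand-in dictionaries rather than tensors, and the shape validation mentioned above is omitted.

```python
import logging
from typing import Any, Dict, Iterable

logger = logging.getLogger(__name__)


def filter_incoming_keys(
    local_keys: Iterable[str],
    incoming: Dict[str, Any],
    strict: bool = True,
) -> Dict[str, Any]:
    """Hypothetical helper that mimics the documented strict / non-strict key handling."""
    local = set(local_keys)
    unexpected = sorted(k for k in incoming if k not in local)
    if unexpected and strict:
        # strict=True: unexpected keys are rejected before anything is loaded
        raise KeyError(f"unexpected keys in received model: {unexpected}")
    if unexpected:
        # strict=False: unexpected keys are logged and ignored
        logger.warning("ignoring unexpected keys: %s", unexpected)
    # only keys that exist in the local module are kept
    return {k: v for k, v in incoming.items() if k in local}


# Stand-in values; a real state dict would map names to tensors.
local_keys = ["encoder.weight", "encoder.bias"]
incoming = {"encoder.weight": [0.1], "encoder.bias": [0.0], "head.weight": [0.5]}
filtered = filter_incoming_keys(local_keys, incoming, strict=False)
```

With strict=False the call keeps the two encoder entries and drops head.weight; with strict=True the same payload raises an error instead.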

on_train_end(trainer, pl_module)[source]

Called when training ends.

on_train_start(trainer, pl_module)[source]

Called when training begins.

on_validation_end(trainer, pl_module)[source]

Called when the validation loop ends.

on_validation_start(trainer, pl_module)[source]

Called when the validation loop begins.

reset_state(trainer)[source]

Resets the state.

If the next round of federated training reuses the same callback instance, reset_state() must be called first. It not only resets the state but also sets up the state for the next round.

patch(trainer: Trainer, restore_state: bool = True, load_state_dict_strict: bool = True, update_fit_loop: bool = True)[source]

Patches the PyTorch Lightning Trainer for usage with NVFlare.

Parameters:
  • trainer – the PyTorch Lightning trainer.

  • restore_state – whether to restore optimizer and learning rate scheduler states. Defaults to True.

  • load_state_dict_strict – exposes the strict argument of torch.nn.Module.load_state_dict(), which is used to load the received model. Defaults to True. See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict for details. NVFlare still validates incoming keys and shapes before calling load_state_dict(). With True, any incoming key that does not exist in the local Lightning module is rejected before loading. With False, NVFlare warns and filters the payload down to matching keys, which is useful for partial model updates where the client keeps only part of the server keyspace.

  • update_fit_loop – whether to increase trainer.fit_loop.max_epochs and trainer.fit_loop.epoch_loop.max_steps each FL round. Defaults to True, which is suitable for most PyTorch Lightning applications.
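
The effect of update_fit_loop can be illustrated with a small simulation. It rests on an assumption not stated above: that the trainer's epoch counter carries over between successive fit() calls, so a fixed max_epochs would leave nothing to train after the first round. The simulate_rounds function and its arithmetic are hypothetical and only approximate the idea of raising the limit each round.

```python
from typing import List


def simulate_rounds(
    epochs_per_round: int,
    num_rounds: int,
    update_fit_loop: bool = True,
) -> List[int]:
    """Return how many epochs would run in each FL round (illustrative model only)."""
    current_epoch = 0              # assumed to persist across fit() calls
    max_epochs = epochs_per_round  # the limit the trainer was created with
    epochs_run = []
    for fl_round in range(num_rounds):
        if update_fit_loop:
            # raise the limit so this round gets its own budget of epochs
            max_epochs = epochs_per_round * (fl_round + 1)
        ran = max(0, max_epochs - current_epoch)
        current_epoch += ran
        epochs_run.append(ran)
    return epochs_run


with_update = simulate_rounds(epochs_per_round=1, num_rounds=3, update_fit_loop=True)
without_update = simulate_rounds(epochs_per_round=1, num_rounds=3, update_fit_loop=False)
```

Under this model, with_update is [1, 1, 1] while without_update is [1, 0, 0], which is one way to read why the default of True suits most applications.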

Example

Normal usage:

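# "flare" is assumed to be the NVFlare Lightning client API alias
import nvflare.client.lightning as flare
from pytorch_lightning import Trainer
trainer = Trainer(max_epochs=1)
flare.patch(trainer)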

Advanced usage:

To pass additional information to the FLARE server side via the Lightning API, set it in an attribute named __fl_meta__ on the LightningModule.

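from pytorch_lightning import LightningModule
from torchmetrics import Accuracy
# Net and NUM_CLASSES are user-defined (the wrapped torch model and the number of classes)
class LitNet(LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = Net()
        self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        # extra metadata sent to the FLARE server side along with the model
        self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"}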