nvflare.app_opt.lightning.api module

class FLCallback(rank: int = 0, load_state_dict_strict: bool = True)

Bases: Callback

FL callback for the Lightning API.

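Parameters:
rank (int) – global rank of the PyTorch Lightning trainer. Defaults to 0.
load_state_dict_strict (bool) – passed as the strict argument of torch.nn.Module.load_state_dict() when loading the model received from the server. Defaults to True.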
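The load_state_dict_strict name refers to the strict argument of PyTorch's torch.nn.Module.load_state_dict(): with strict=True, any key that is missing from or unexpected in the received state dict is an error; with strict=False, matching keys are loaded and the mismatches are only reported. The toy function below is a sketch of that key-matching rule alone. toy_load_state_dict is a hypothetical name, plain floats stand in for tensors, and shape checks and the rest of PyTorch's loading logic are not modeled.

```python
from typing import Dict, List, Tuple


def toy_load_state_dict(
    current: Dict[str, float], received: Dict[str, float], strict: bool = True
) -> Tuple[List[str], List[str]]:
    """Hypothetical model of strict vs. non-strict key matching.

    Returns (missing_keys, unexpected_keys). With strict=True any mismatch
    raises before anything is copied; with strict=False the matching keys
    are copied into `current` and the mismatches are returned.
    """
    missing = sorted(k for k in current if k not in received)
    unexpected = sorted(k for k in received if k not in current)
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing={missing}, unexpected={unexpected}")
    # Copy only the parameters both sides agree on.
    for key in current.keys() & received.keys():
        current[key] = received[key]
    return missing, unexpected


local = {"fc.weight": 0.0, "fc.bias": 0.0}
received = {"fc.weight": 1.0, "extra.weight": 2.0}
missing, unexpected = toy_load_state_dict(local, received, strict=False)
print(missing, unexpected, local["fc.weight"])  # ['fc.bias'] ['extra.weight'] 1.0
```

Keeping the default at strict=True surfaces architecture mismatches between the local and global models early; strict=False is the escape hatch when a client model intentionally differs.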
on_train_end(trainer, pl_module)

Called when training ends.

on_train_start(trainer, pl_module)

Called when training begins.

on_validation_end(trainer, pl_module)

Called when the validation loop ends.

on_validation_start(trainer, pl_module)

Called when the validation loop begins.
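To make the order of these four hooks concrete, the sketch below uses a simplified, hypothetical driver loop (toy_fit and RecordingCallback are illustrative names, not Lightning or NVFlare APIs) that calls them the way a one-epoch fit with a single validation pass would. The real Trainer.fit() invokes many more hooks and may run additional validation passes (for example a sanity check), so this only illustrates how the train hooks bracket the validation hooks.

```python
from typing import List


class RecordingCallback:
    """Hypothetical callback that records which hook fired, in order."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def on_train_start(self, trainer, pl_module) -> None:
        self.events.append("train_start")

    def on_train_end(self, trainer, pl_module) -> None:
        self.events.append("train_end")

    def on_validation_start(self, trainer, pl_module) -> None:
        self.events.append("validation_start")

    def on_validation_end(self, trainer, pl_module) -> None:
        self.events.append("validation_end")


def toy_fit(callbacks, max_epochs: int) -> None:
    """Simplified stand-in for Trainer.fit(): one validation pass per epoch."""
    trainer, module = object(), object()  # placeholders for the real objects
    for cb in callbacks:
        cb.on_train_start(trainer, module)
    for _ in range(max_epochs):
        # ... training batches would run here ...
        for cb in callbacks:
            cb.on_validation_start(trainer, module)
        # ... validation batches would run here ...
        for cb in callbacks:
            cb.on_validation_end(trainer, module)
    for cb in callbacks:
        cb.on_train_end(trainer, module)


cb = RecordingCallback()
toy_fit([cb], max_epochs=1)
print(cb.events)  # ['train_start', 'validation_start', 'validation_end', 'train_end']
```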

reset_state(trainer)

Resets the state.

If the next round of federated training reuses the same callback instance, reset_state() must be called first. Besides resetting the current state, it also sets up the state for the next round.
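One plausible reason reuse requires a reset (an assumption for illustration, not taken from NVFlare's source): a callback that keeps per-round flags, such as whether it has already loaded the global model, would otherwise carry a stale flag into the next round and skip that step. ToyRoundCallback below is hypothetical and is not FLCallback.

```python
from typing import List


class ToyRoundCallback:
    """Hypothetical per-round callback; NOT NVFlare's FLCallback."""

    def __init__(self) -> None:
        self.received_rounds: List[int] = []
        self.reset_state()

    def reset_state(self) -> None:
        # Clear per-round flags so the next round starts fresh.
        self._model_loaded = False

    def on_train_start(self, round_num: int) -> None:
        # Load the global model at most once per round; a stale flag blocks it.
        if not self._model_loaded:
            self.received_rounds.append(round_num)  # stands in for loading the model
            self._model_loaded = True


cb = ToyRoundCallback()
cb.on_train_start(round_num=0)
cb.on_train_start(round_num=1)  # skipped: flag left over from round 0
cb.reset_state()
cb.on_train_start(round_num=2)  # loads again after the reset
print(cb.received_rounds)  # [0, 2]
```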

patch(trainer: Trainer, restore_state: bool = True, load_state_dict_strict: bool = True)

Patches the PyTorch Lightning Trainer for usage with NVFlare.

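Parameters:
trainer (Trainer) – the PyTorch Lightning trainer to patch.
restore_state (bool) – whether to restore the local optimizer and learning rate scheduler states at each round. Defaults to True.
load_state_dict_strict (bool) – passed as the strict argument of torch.nn.Module.load_state_dict() when loading the model received from the server. Defaults to True.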

Example

Normal usage:

import nvflare.client.lightning as flare
from pytorch_lightning import Trainer
trainer = Trainer(max_epochs=1)
flare.patch(trainer)
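What patching involves is not spelled out here. A minimal sketch, assuming (not confirmed by this page) that patch() registers an FL callback on the trainer's callback list and skips registration if one is already present, looks like this. ToyTrainer, ToyFLCallback, and toy_patch are hypothetical stand-ins, not NVFlare or Lightning classes.

```python
from dataclasses import dataclass, field
from typing import List


class ToyFLCallback:
    """Hypothetical stand-in for FLCallback."""

    def __init__(self, rank: int = 0, load_state_dict_strict: bool = True) -> None:
        self.rank = rank
        self.load_state_dict_strict = load_state_dict_strict


@dataclass
class ToyTrainer:
    """Hypothetical stand-in for a Lightning Trainer."""

    global_rank: int = 0
    callbacks: List[object] = field(default_factory=list)


def toy_patch(trainer: ToyTrainer, load_state_dict_strict: bool = True) -> None:
    # Register the FL callback only once, so patching twice is harmless.
    if not any(isinstance(cb, ToyFLCallback) for cb in trainer.callbacks):
        trainer.callbacks.append(
            ToyFLCallback(rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict)
        )


trainer = ToyTrainer()
toy_patch(trainer)
toy_patch(trainer)  # no-op: already patched
print(len(trainer.callbacks))  # 1
```

Making the registration idempotent means user code can call patch() unconditionally without stacking duplicate callbacks that would each try to exchange models with the server.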

Advanced usage:

If users want to send additional information to the FLARE server through the Lightning API, they need to set it in an attribute named __fl_meta__ on their LightningModule, as in the following example:

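from pytorch_lightning import LightningModule
from torchmetrics import Accuracy

# Net and NUM_CLASSES are user-defined (model architecture and number of classes).
class LitNet(LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = Net()
        self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        # Additional information to pass to the FLARE server.
        self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"}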
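On the reading side, a callback can pick this attribute up with getattr(). The sketch below assumes, purely for illustration, that the metadata is copied into an outgoing payload next to the model parameters; ToyModule, build_payload, and the payload layout are hypothetical and do not describe NVFlare's actual message format. Because the name ends in two underscores, Python does not name-mangle __fl_meta__, so getattr() finds it under that exact name.

```python
from typing import Any, Dict


class ToyModule:
    """Hypothetical module mirroring the LitNet example above."""

    def __init__(self) -> None:
        # Trailing double underscores: Python does NOT name-mangle this attribute.
        self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"}


def build_payload(pl_module: Any, params: Dict[str, float]) -> Dict[str, Any]:
    # Modules without __fl_meta__ simply contribute empty metadata.
    meta = dict(getattr(pl_module, "__fl_meta__", {}))
    return {"params": params, "meta": meta}


payload = build_payload(ToyModule(), {"fc.weight": 0.5})
print(payload["meta"])  # {'CUSTOM_VAR': 'VALUE_OF_THE_VAR'}
```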