Source code for nvflare.edge.device.pt.trainer

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn

from nvflare.apis.dxo import DXO
from nvflare.edge.device.defs import Context, ContextKey, DataSource, EventType, Executor, Signal, Transform


class PTTrainer(Executor):
    """Executor that trains a PyTorch ``nn.Module`` on the device's local dataset.

    The trainer is configured with default hyper-parameters; each task may
    override the learning rate, batch size and epoch count through the
    ``"params"`` entry of the incoming task data.
    """

    # Batch size used when the task does not provide a (truthy) "batch_size".
    DEFAULT_BATCH_SIZE = 10

    def __init__(self, epoch: int, lr, loss_fn, optimizer, transforms):
        """Store the default training configuration.

        Args:
            epoch: default number of epochs, used when the task params do not
                supply ``"num_epochs"``.
            lr: default learning rate, used when the task params do not
                supply ``"learning_rate"``.
            loss_fn: callable invoked as ``loss_fn(y_pred, y_batch)``; its
                result must support ``float(...)`` and ``.backward()``.
            optimizer: factory object whose ``get(parameters, lr=...)`` method
                returns the optimizer used for training.
            transforms: optional iterable of ``Transform`` objects applied, in
                order, to every batch before the forward pass. May be falsy.
        """
        self.epoch = epoch
        self.lr = lr
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.transforms = transforms

    def execute(self, task_data: DXO, ctx: Context, abort_signal: Signal) -> DXO:
        """Train the model carried by ``task_data`` and return its weights.

        Args:
            task_data: DXO whose ``data`` holds ``"model"`` (an ``nn.Module``)
                and optionally ``"params"`` with ``"learning_rate"``,
                ``"batch_size"`` and ``"num_epochs"`` overrides.
            ctx: execution context; must provide a ``DataSource`` under
                ``ContextKey.DATA_SOURCE``. Also used to fire
                ``EventType.LOSS_GENERATED`` after every forward pass.
            abort_signal: signal forwarded to transforms and fired events.

        Returns:
            A DXO of kind ``"model"`` whose data contains the trainable
            parameters under ``"weights"`` and whose meta records the dataset
            size and the effective batch size, epoch count and learning rate.

        Raises:
            AssertionError: if the model is not an ``nn.Module``, the data
                source is not a ``DataSource``, or a configured transform is
                not a ``Transform``.
        """
        # the model must have been converted to nn.Module by some filters
        model = task_data.data.get("model")
        assert isinstance(model, nn.Module), "task data must carry an nn.Module under 'model'"

        # A task without "params" simply falls back to the configured defaults
        # instead of failing on ``None.get``.
        params = task_data.data.get("params") or {}

        data_source = ctx.get(ContextKey.DATA_SOURCE)
        assert isinstance(data_source, DataSource), "context must provide a DataSource"

        # load the dataset
        train_dataset = data_source.get_dataset(dataset_type="train", ctx=ctx)

        # Effective hyper-parameters: task overrides win, falsy values
        # (missing, None, 0) fall back to the defaults.
        lr = params.get("learning_rate") or self.lr
        batch_size = params.get("batch_size") or self.DEFAULT_BATCH_SIZE
        n_epochs = params.get("num_epochs") or self.epoch

        optimizer = self.optimizer.get(model.parameters(), lr=lr)

        # Ceiling division. This must be integer (floor) division: true
        # division yields a float, which ``range`` rejects with a TypeError.
        batches_per_epoch = (train_dataset.size() + batch_size - 1) // batch_size

        for epoch in range(n_epochs):
            for i in range(batches_per_epoch):
                batch = train_dataset.get_next_batch(batch_size)

                if self.transforms:
                    for t in self.transforms:
                        assert isinstance(t, Transform), "transforms must be Transform instances"
                        batch = t.transform(batch, ctx, abort_signal)

                x_batch = batch.get_input()
                y_batch = batch.get_label()

                # forward pass
                y_pred = model(x_batch)
                loss = self.loss_fn(y_pred, y_batch)

                # Report the loss before the weights are updated.
                ctx.fire_event(
                    EventType.LOSS_GENERATED,
                    data={
                        "loss": float(loss),
                        "epoch": epoch,
                        "iter": i,
                    },
                    abort_signal=abort_signal,
                )

                # backward pass
                optimizer.zero_grad()
                loss.backward()

                # update weights
                optimizer.step()

            # reset dataset for next epoch
            train_dataset.reset()

        # Only trainable parameters are reported back.
        weights = {name: param.data for name, param in model.named_parameters() if param.requires_grad}

        return DXO(
            data_kind="model",
            data={
                "weights": weights,
            },
            meta={
                "dataset_size": train_dataset.size(),
                "batch_size": batch_size,
                "num_epochs": n_epochs,
                "lr": lr,
            },
        )