# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.fuel.utils.memory_utils import cleanup_memory
from nvflare.fuel.utils.validation_utils import check_non_negative_int
from .model_controller import ModelController
[docs]
class Cyclic(ModelController):
    """ModelController that runs the Cyclic Weight Transfer (CWT) algorithm.

    In every round a set of clients is sampled and the model is handed to
    them one at a time; what each client returns becomes the model that is
    sent to the next client. The model is saved once per round.
    """

    def __init__(
        self,
        *args,
        num_clients: int = 2,
        num_rounds: int = 5,
        start_round: int = 0,
        memory_gc_rounds: int = 0,
        **kwargs,
    ):
        """The Cyclic ModelController to implement the Cyclic Weight Transfer (CWT) algorithm.

        Args:
            *args: Positional arguments forwarded unchanged to ``ModelController``.
            num_clients (int, optional): The number of clients. Defaults to 2.
            num_rounds (int, optional): The total number of training rounds. Defaults to 5.
            start_round (int, optional): The starting round number. Defaults to 0.
            memory_gc_rounds (int, optional): Run memory cleanup (gc.collect + malloc_trim)
                every N rounds. Set to 0 to disable. Defaults to 0 (disabled).
            **kwargs: Keyword arguments forwarded unchanged to ``ModelController``.
        """
        super().__init__(*args, **kwargs)

        # Validated via check_non_negative_int only after the base class has been initialised.
        check_non_negative_int("memory_gc_rounds", memory_gc_rounds)

        self.num_clients = num_clients
        self.num_rounds = num_rounds
        self.start_round = start_round
        self.memory_gc_rounds = memory_gc_rounds

        # Stays None until run() enters its first round.
        self.current_round = None

    def _maybe_cleanup_memory(self):
        """Trigger ``cleanup_memory()`` when the current round lands on the configured interval.

        Nothing happens when no round has started yet, when ``memory_gc_rounds``
        is 0, or when ``current_round + 1`` is not a multiple of
        ``memory_gc_rounds``.
        """
        round_idx = self.current_round
        if round_idx is None:
            return

        interval = self.memory_gc_rounds
        # The +1 turns the round number into a count, so interval N fires on
        # round numbers N-1, 2N-1, ...
        is_due = interval > 0 and (round_idx + 1) % interval == 0
        if is_due:
            cleanup_memory()

    def run(self) -> None:
        """Execute the cyclic workflow for ``num_rounds`` rounds beginning at ``start_round``.

        The model is loaded once up front. Each round samples ``num_clients``
        clients, sends the model to them sequentially (adopting the returned
        params and meta after every client), saves the model, and then gives
        the optional memory cleanup a chance to run.
        """
        self.info("Start Cyclic.")

        model = self.load_model()
        model.start_round = self.start_round
        model.total_rounds = self.num_rounds

        first_round = self.start_round
        stop_round = self.start_round + self.num_rounds

        for round_number in range(first_round, stop_round):
            # Keep the attribute in sync; it is read by the log line below and
            # by _maybe_cleanup_memory().
            self.current_round = round_number
            self.info(f"Round {self.current_round} started.")
            model.current_round = round_number

            selected = self.sample_clients(self.num_clients)
            for target in selected:
                # One client at a time: its reply is what the next client receives.
                replies = self.send_model_and_wait(targets=[target], data=model)
                reply = replies[0]
                model.params, model.meta = reply.params, reply.meta

            self.save_model(model)

            # End-of-round memory cleanup (no-op unless configured).
            self._maybe_cleanup_memory()

        self.info("Finished Cyclic.")