# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from typing import Optional, Tuple

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor
from nvflare.app_common.abstract.model import ModelLearnable, make_model_learnable
from nvflare.app_common.abstract.shareable_generator import ShareableGenerator
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_common.shareablegenerators.passthru import PassthroughShareableGenerator
from nvflare.edge.assessor import Assessment, Assessor
from nvflare.fuel.utils.validation_utils import check_str
from nvflare.security.logging import secure_format_exception
class SGAPAssessor(Assessor):
    """Assessor built from a Shareable Generator, an Aggregator, and a Persistor (SGAP).

    Lifecycle, as implemented below:

    * On ``START_RUN`` the components are looked up from the engine by ID and
      type-checked; if a persistor is configured, it loads the initial global model.
    * ``start_task`` turns the current global model into task data using the
      shareable generator.
    * ``process_child_update`` hands each child update to the aggregator.
    * ``end_task`` aggregates the collected updates, resets the aggregator, converts
      the aggregation result back into the global model, and persists it when a
      persistor is configured.
    * ``assess`` delegates to ``do_assessment``, which always returns
      ``Assessment.CONTINUE`` in this class.
    """

    def __init__(self, shareable_generator_id: str, aggregator_id: str, persistor_id: str):
        """Create an assessor that uses a Shareable Generator, an Aggregator, and a Persistor (SGAP).

        Args:
            shareable_generator_id: component ID of the Shareable Generator. If empty, a
                PassthroughShareableGenerator is used.
            aggregator_id: component ID of the Aggregator.
            persistor_id: component ID of the Persistor. If empty, no persistor is used and the
                global model starts out empty. If specified, the Persistor loads the initial
                model when the run starts and saves the global model at the end of every task.

        All three IDs are validated with ``check_str`` (presumably raising if an ID is not a str).
        Components are only resolved later, when the START_RUN event is handled.
        """
        Assessor.__init__(self)
        check_str("persistor_id", persistor_id)
        check_str("shareable_generator_id", shareable_generator_id)
        check_str("aggregator_id", aggregator_id)
        self.aggregator_id = aggregator_id
        self.shareable_generator_id = shareable_generator_id
        self.persistor_id = persistor_id

        # Current global model. Stays empty until the persistor loads one at START_RUN
        # or end_task replaces it with the converted aggregation result.
        self._global_weights = make_model_learnable({}, {})

        # Serializes aggregator access: accept() in process_child_update versus
        # aggregate()/reset() in end_task.
        self._aggr_lock = threading.Lock()

        # Resolved from the engine in _handle_start_run.
        self.shareable_gen = None
        self.aggregator = None
        self.persistor = None

        self.register_event_handler(EventType.START_RUN, self._handle_start_run)

    def _handle_start_run(self, event_type: str, fl_ctx: FLContext) -> None:
        """Resolve and validate the SGAP components when the run starts.

        A component of the wrong type, or a persistor whose ``load`` does not return a
        ``ModelLearnable``, triggers ``system_panic`` and stops further initialization.
        On success with a persistor, the loaded model becomes the global model, is stored
        on ``fl_ctx`` as ``AppConstants.GLOBAL_MODEL``, and INITIAL_MODEL_LOADED is fired.

        Args:
            event_type: the event being handled (registered for ``EventType.START_RUN``); unused.
            fl_ctx: the FL context whose engine provides the components by ID.
        """
        engine = fl_ctx.get_engine()

        self.aggregator = engine.get_component(self.aggregator_id)
        if not isinstance(self.aggregator, Aggregator):
            self.system_panic(
                f"aggregator {self.aggregator_id} must be an Aggregator type object but got {type(self.aggregator)}",
                fl_ctx,
            )
            return

        if self.shareable_generator_id:
            self.shareable_gen = engine.get_component(self.shareable_generator_id)
            if not isinstance(self.shareable_gen, ShareableGenerator):
                self.system_panic(
                    f"Shareable generator {self.shareable_generator_id} must be a ShareableGenerator type object, "
                    f"but got {type(self.shareable_gen)}",
                    fl_ctx,
                )
                return
        else:
            # No generator configured: fall back to the passthrough generator.
            self.shareable_gen = PassthroughShareableGenerator()

        if self.persistor_id:
            self.persistor = engine.get_component(self.persistor_id)
            if not isinstance(self.persistor, LearnablePersistor):
                self.system_panic(
                    f"Persistor {self.persistor_id} must be a LearnablePersistor type object, "
                    f"but got {type(self.persistor)}",
                    fl_ctx,
                )
                return

        if self.persistor:
            model = self.persistor.load(fl_ctx)
            if not isinstance(model, ModelLearnable):
                self.system_panic(
                    reason=f"Expected model loaded by persistor to be `ModelLearnable` but received {type(model)}",
                    fl_ctx=fl_ctx,
                )
                return
            self._global_weights = model
            fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, model, private=True, sticky=True)
            self.fire_event(AppEventType.INITIAL_MODEL_LOADED, fl_ctx)

    def start_task(self, fl_ctx: FLContext) -> Shareable:
        """Build the task data for the current round.

        Args:
            fl_ctx: the FL context; ``AppConstants.CURRENT_ROUND`` is read only for logging.

        Returns:
            The current global model converted to a Shareable by the shareable generator.
        """
        current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
        self.log_info(fl_ctx, f"starting round {current_round}")
        return self.shareable_gen.learnable_to_shareable(self._global_weights, fl_ctx)

    def process_child_update(self, data: Shareable, fl_ctx: FLContext) -> Tuple[bool, Optional[Shareable]]:
        """Feed an update received from a child into the aggregator.

        The call is made while holding the aggregator lock so it cannot interleave
        with aggregation in ``end_task``.

        Args:
            data: the child's update.
            fl_ctx: the FL context.

        Returns:
            A pair ``(accepted, None)``: whether the aggregator accepted the update; the
            second element is always None in this implementation.
        """
        with self._aggr_lock:
            accepted = self.aggregator.accept(data, fl_ctx)
        return accepted, None

    def end_task(self, fl_ctx: FLContext) -> None:
        """Aggregate the collected child updates into the new global model.

        Steps: aggregate and reset the aggregator (under the aggregator lock), convert the
        aggregation result into the global model with the shareable generator, store it on
        ``fl_ctx`` as ``AppConstants.GLOBAL_MODEL``, and save it with the persistor if one
        is configured. Events are fired before/after each step.

        An exception from ``aggregate``/``reset`` is logged and escalated via
        ``system_panic``; the method then returns without touching the global model.
        Exceptions from the later steps are not caught here.

        Args:
            fl_ctx: the FL context; ``AppConstants.CURRENT_ROUND`` is read only for logging.
        """
        current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
        self.log_info(fl_ctx, f"Start aggregation for round {current_round}")
        self.fire_event(AppEventType.BEFORE_AGGREGATION, fl_ctx)

        with self._aggr_lock:
            try:
                aggr_result = self.aggregator.aggregate(fl_ctx)
                # Clear aggregator state while still holding the lock so no update
                # accepted after aggregate() can leak into the next round's result.
                self.aggregator.reset(fl_ctx)
            except Exception as ex:
                self.log_exception(fl_ctx, f"aggregation error from {type(self.aggregator)}")
                self.system_panic(f"aggregation error: {secure_format_exception(ex)}", fl_ctx)
                return

        self.fire_event_with_data(AppEventType.AFTER_AGGREGATION, fl_ctx, AppConstants.AGGREGATION_RESULT, aggr_result)
        self.log_info(fl_ctx, f"End aggregation for round {current_round}.")

        self.fire_event(AppEventType.BEFORE_SHAREABLE_TO_LEARNABLE, fl_ctx)
        self._global_weights = self.shareable_gen.shareable_to_learnable(aggr_result, fl_ctx)
        fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self._global_weights, private=True, sticky=True)
        self.fire_event(AppEventType.AFTER_SHAREABLE_TO_LEARNABLE, fl_ctx)

        if self.persistor:
            self.log_info(fl_ctx, f"Start persist model on server for round {current_round}.")
            self.fire_event(AppEventType.BEFORE_LEARNABLE_PERSIST, fl_ctx)
            self.persistor.save(self._global_weights, fl_ctx)
            self.fire_event(AppEventType.AFTER_LEARNABLE_PERSIST, fl_ctx)
            self.log_info(fl_ctx, f"End persist model on server for round {current_round}")

    def do_assessment(self, fl_ctx: FLContext) -> Assessment:
        """Decide the assessment outcome; this implementation always returns ``Assessment.CONTINUE``.

        Override in a subclass to implement a different decision.
        """
        return Assessment.CONTINUE

    def assess(self, fl_ctx: FLContext) -> Assessment:
        """Return the outcome computed by ``do_assessment``."""
        return self.do_assessment(fl_ctx)