# Source code for nvflare.app_opt.pt.job_config.base_fed_job

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from torch import nn as nn

from nvflare.apis.fl_component import FLComponent
from nvflare.app_common.abstract.model_locator import ModelLocator
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent
from nvflare.app_common.widgets.streaming import AnalyticsReceiver
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
from nvflare.job_config.base_fed_job import BaseFedJob as UnifiedBaseFedJob


class BaseFedJob(UnifiedBaseFedJob):
    """PyTorch BaseFedJob.

    This is a backward-compatible wrapper around the unified BaseFedJob. It forwards
    the generic job settings to the unified implementation and, when an initial model
    is supplied, additionally registers a PyTorch model (with its persistor/locator)
    on the server.

    Configures ValidationJsonGenerator, model selector, AnalyticsReceiver, ConvertToFedEvent.

    User must add controllers and executors.

    Args:
        initial_model (nn.Module, optional): Initial PyTorch model. Defaults to None,
            in which case no PyTorch model component is added to the job.
        name (str, optional): Name of the job. Defaults to "fed_job".
        min_clients (int, optional): The minimum number of clients for the job. Defaults to 1.
        mandatory_clients (list[str] | None, optional): Mandatory clients to run the job.
            Defaults to None.
        key_metric (str, optional): Metric used to determine if the model is globally best.
            If metrics are a `dict`, `key_metric` can select the metric used for global
            model selection. Defaults to "accuracy".
        validation_json_generator (ValidationJsonGenerator | None, optional): A component for
            generating validation results. If not provided, a ValidationJsonGenerator will
            be configured.
        model_selector (FLComponent | None, optional): A component for selecting the best model
            during training. This event-driven component evaluates and tracks model performance
            across training rounds, handling workflow events such as BEFORE_AGGREGATION and
            BEFORE_CONTRIBUTION_ACCEPT. If not provided, an IntimeModelSelector will be
            configured based on key_metric.
        convert_to_fed_event (ConvertToFedEvent | None, optional): A component to convert certain
            events to fed events. If not provided, a ConvertToFedEvent object will be created.
        analytics_receiver (AnalyticsReceiver | None, optional): Component for receiving analytics
            data. If not provided, no analytics tracking will be enabled. For experiment tracking
            (e.g., TensorBoard), explicitly pass a TBAnalyticsReceiver instance.
        model_persistor (ModelPersistor | None, optional): How to persist the model. Only used
            when `initial_model` is provided.
        model_locator (ModelLocator | None, optional): How to locate the model. Only used
            when `initial_model` is provided.

    Raises:
        TypeError: If `initial_model` is given but is not an instance of `nn.Module`.
    """

    def __init__(
        self,
        initial_model: Optional[nn.Module] = None,
        name: str = "fed_job",
        min_clients: int = 1,
        mandatory_clients: Optional[List[str]] = None,
        key_metric: str = "accuracy",
        validation_json_generator: Optional[ValidationJsonGenerator] = None,
        model_selector: Optional[FLComponent] = None,
        convert_to_fed_event: Optional[ConvertToFedEvent] = None,
        analytics_receiver: Optional[AnalyticsReceiver] = None,
        model_persistor: Optional[ModelPersistor] = None,
        model_locator: Optional[ModelLocator] = None,
    ):
        # Validate the PyTorch-specific argument before the parent constructor runs,
        # so that an invalid model is rejected before any job configuration is done.
        if initial_model is not None and not isinstance(initial_model, nn.Module):
            raise TypeError(
                f"initial_model must be an instance of nn.Module, but got {type(initial_model).__name__}"
            )

        # Delegate all framework-agnostic configuration to the unified BaseFedJob.
        super().__init__(
            name=name,
            min_clients=min_clients,
            mandatory_clients=mandatory_clients,
            key_metric=key_metric,
            validation_json_generator=validation_json_generator,
            model_selector=model_selector,
            convert_to_fed_event=convert_to_fed_event,
            analytics_receiver=analytics_receiver,
        )

        # PyTorch-specific model setup; skipped entirely when no model is given
        # (model_persistor / model_locator are not used in that case).
        if initial_model is not None:
            self._setup_pytorch_model(initial_model, model_persistor, model_locator)

    def _setup_pytorch_model(
        self,
        initial_model: nn.Module,
        persistor: Optional[ModelPersistor] = None,
        locator: Optional[ModelLocator] = None,
    ):
        """Setup PyTorch model with persistor and locator.

        Wraps the model in a PTModel, sends it to the server via `to_server`, and
        merges the result of that call into `self.comp_ids`.

        Args:
            initial_model (nn.Module): The PyTorch model to register.
            persistor (ModelPersistor | None, optional): Persistor passed through to PTModel.
            locator (ModelLocator | None, optional): Locator passed through to PTModel.
        """
        # Imported at call time rather than at module level
        # (presumably to avoid a circular import -- TODO confirm).
        from nvflare.app_opt.pt.job_config.model import PTModel

        pt_model = PTModel(model=initial_model, persistor=persistor, locator=locator)
        self.comp_ids.update(self.to_server(pt_model))