Source code for nvflare.app_opt.pt.job_config.base_fed_job

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from torch import nn as nn

from nvflare.app_common.abstract.model_locator import ModelLocator
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE
from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_common.widgets.streaming import AnalyticsReceiver
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
from nvflare.app_opt.pt.job_config.model import PTModel
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
from nvflare.job_config.api import FedJob, validate_object_for_job


[docs] class BaseFedJob(FedJob): def __init__( self, initial_model: nn.Module = None, name: str = "fed_job", min_clients: int = 1, mandatory_clients: Optional[List[str]] = None, key_metric: str = "accuracy", validation_json_generator: Optional[ValidationJsonGenerator] = None, intime_model_selector: Optional[IntimeModelSelector] = None, convert_to_fed_event: Optional[ConvertToFedEvent] = None, analytics_receiver: Optional[AnalyticsReceiver] = None, model_persistor: Optional[ModelPersistor] = None, model_locator: Optional[ModelLocator] = None, ): """PyTorch BaseFedJob. Configures ValidationJsonGenerator, IntimeModelSelector, AnalyticsReceiver, ConvertToFedEvent. User must add controllers and executors. Args: initial_model (nn.Module): initial PyTorch Model. Defaults to None. name (name, optional): name of the job. Defaults to "fed_job". min_clients (int, optional): the minimum number of clients for the job. Defaults to 1. mandatory_clients (List[str], optional): mandatory clients to run the job. Default None. key_metric (str, optional): Metric used to determine if the model is globally best. if metrics are a `dict`, `key_metric` can select the metric used for global model selection. Defaults to "accuracy". validation_json_generator (ValidationJsonGenerator, optional): A component for generating validation results. if not provided, a ValidationJsonGenerator will be configured. intime_model_selector: (IntimeModelSelector, optional): A component for select the model. if not provided, an IntimeModelSelector will be configured. convert_to_fed_event: (ConvertToFedEvent, optional): A component to covert certain events to fed events. if not provided, a ConvertToFedEvent object will be created. analytics_receiver (AnlyticsReceiver, optional): Receive analytics. If not provided, a TBAnalyticsReceiver will be configured. model_persistor (optional, ModelPersistor): how to persistor the model. model_locator (optional, ModelLocator): how to locate the model. """ super().__init__( name=name, min_clients=min_clients, mandatory_clients=mandatory_clients, ) self.initial_model = initial_model self.comp_ids = {} if validation_json_generator: validate_object_for_job("validation_json_generator", validation_json_generator, ValidationJsonGenerator) else: validation_json_generator = ValidationJsonGenerator() self.to_server(id="json_generator", obj=validation_json_generator) if intime_model_selector: validate_object_for_job("intime_model_selector", intime_model_selector, IntimeModelSelector) self.to_server(id="model_selector", obj=intime_model_selector) elif key_metric: self.to_server(id="model_selector", obj=IntimeModelSelector(key_metric=key_metric)) if convert_to_fed_event: validate_object_for_job("convert_to_fed_event", convert_to_fed_event, ConvertToFedEvent) else: convert_to_fed_event = ConvertToFedEvent(events_to_convert=[ANALYTIC_EVENT_TYPE]) self.convert_to_fed_event = convert_to_fed_event if analytics_receiver: validate_object_for_job("analytics_receiver", analytics_receiver, AnalyticsReceiver) else: analytics_receiver = TBAnalyticsReceiver() self.to_server( id="receiver", obj=analytics_receiver, ) if initial_model: self.comp_ids.update( self.to_server(PTModel(model=initial_model, persistor=model_persistor, locator=model_locator)) )
[docs] def set_up_client(self, target: str): self.to(id="event_to_fed", obj=self.convert_to_fed_event, target=target)