# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
from abc import ABC
from typing import Dict, List, Tuple

from nvflare.apis.filter import Filter
from nvflare.apis.fl_component import FLComponent
class BaseAppConfig(ABC):
    """BaseAppConfig holds the base essential component data for the ServerApp and ClientApp, including the
    task_data_filters, task_result_filters, system components and used external scripts.
    """

    def __init__(self) -> None:
        super().__init__()
        # Each entry is a (tasks, filter) pair, kept in registration order.
        self.task_data_filters: List[Tuple[List[str], Filter]] = []
        self.task_result_filters: List[Tuple[List[str], Filter]] = []
        # Component ID -> component object.
        self.components: Dict[str, object] = {}
        # Paths of external scripts accepted by add_ext_script.
        self.ext_scripts: List[str] = []
        # The registered components that are FLComponent instances.
        self.handlers: List[FLComponent] = []

    def add_component(self, cid: str, component):
        """Register a component under a unique ID.

        Components that are ``FLComponent`` instances are additionally recorded in ``handlers``.

        Args:
            cid: ID of the component; must not already be registered.
            component: the component object.

        Raises:
            RuntimeError: if a component with the same ID was already added.
        """
        if cid in self.components:
            raise RuntimeError(f"Component with ID:{cid} already exist.")
        self.components[cid] = component
        if isinstance(component, FLComponent):
            self.handlers.append(component)

    def add_task_data_filter(self, tasks: List[str], filter: Filter):
        """Register a filter to apply to task data for the given tasks.

        Args:
            tasks: names of the tasks the filter applies to.
            filter: the Filter instance.

        Raises:
            RuntimeError: if ``filter`` is not a Filter, or if any of ``tasks`` already has a task data filter.
        """
        self._add_task_filter(tasks, filter, self.task_data_filters)

    def add_task_result_filter(self, tasks: List[str], filter: Filter):
        """Register a filter to apply to task results for the given tasks.

        Args:
            tasks: names of the tasks the filter applies to.
            filter: the Filter instance.

        Raises:
            RuntimeError: if ``filter`` is not a Filter, or if any of ``tasks`` already has a task result filter.
        """
        self._add_task_filter(tasks, filter, self.task_result_filters)

    def add_ext_script(self, ext_script: str):
        """Register an external Python script.

        Args:
            ext_script: path to the script; must be a str, must exist, and must end with ".py".

        Raises:
            RuntimeError: if any of the above checks fails.
        """
        if not isinstance(ext_script, str):
            raise RuntimeError(f"ext_script must be type of str, but got {ext_script.__class__}")
        if not os.path.exists(ext_script):
            raise RuntimeError(f"Could not locate external script: {ext_script}")
        if not ext_script.endswith(".py"):
            raise RuntimeError(f"External script: {ext_script} must be a '.py' file.")
        self.ext_scripts.append(ext_script)

    def _add_task_filter(self, tasks, filter, filters):
        """Validate ``filter`` and append a ``(tasks, filter)`` entry to ``filters``.

        Each task may appear in at most one entry of a given filter list.

        Raises:
            RuntimeError: if ``filter`` is not a Filter, or a task in ``tasks`` is already covered by ``filters``.
        """
        if not isinstance(filter, Filter):
            raise RuntimeError(f"filter must be type of Filter, but got {filter.__class__}")
        for task in tasks:
            # Entries are stored as (tasks, filter) tuples (see the append below), so
            # unpack them; the tuples have no ``.tasks`` attribute.
            for existing_tasks, _ in filters:
                if task in existing_tasks:
                    raise RuntimeError(f"Task {task} already defined in the task filters.")
        filters.append((tasks, filter))