# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
from nvflare.apis.shareable import Shareable
from nvflare.widgets.widget import Widget
from .client import Client
from .engine_spec import EngineSpec
from .fl_context import FLContext
from .fl_snapshot import RunSnapshot
from .job_def import Job
from .workspace import Workspace
[docs]
class ServerEngineSpec(EngineSpec, ABC):
    """Abstract interface that a server-side engine must implement.

    Extends ``EngineSpec`` with the server-specific operations declared below:
    client bookkeeping, component management, aux-channel messaging,
    persisting/restoring components and client resource checks.
    """

    @abstractmethod
    def fire_event(self, event_type: str, fl_ctx: FLContext):
        """Fire an event.

        Args:
            event_type: type of the event to fire
            fl_ctx: FL context
        """
        pass

    @abstractmethod
    def get_clients(self) -> List[Client]:
        """Get the clients.

        Returns:
            A list of Client objects
        """
        pass

    @abstractmethod
    def sync_clients_from_main_process(self):
        """Fetch the participating clients from the main parent process.

        Returns:
            clients
        """
        pass

    @abstractmethod
    def update_job_run_status(self):
        """Update the job run status to the parent process."""
        pass

    @abstractmethod
    def new_context(self) -> FLContext:
        """Create a new FL context.

        Implementation Note:
            The engine must use FLContextManager to create a new context!

        Returns:
            A new FLContext
        """
        pass

    @abstractmethod
    def get_workspace(self) -> Workspace:
        """Get the workspace.

        Returns:
            A Workspace object
        """
        pass

    @abstractmethod
    def add_component(self, component_id: str, component):
        """Add a component into the system.

        Args:
            component_id: component ID
            component: component object
        """
        pass

    @abstractmethod
    def get_component(self, component_id: str) -> object:
        """Retrieve the system component from the engine.

        Args:
            component_id: component ID

        Returns:
            component object
        """
        pass

    @abstractmethod
    def register_aux_message_handler(self, topic: str, message_handle_func):
        """Register aux message handling function with specified topics.

        Exception is raised when:
            a handler is already registered for the topic;
            bad topic - must be a non-empty string
            bad message_handle_func - must be callable

        Implementation Note:
            This method should simply call the ServerAuxRunner's register_aux_message_handler method.

        Args:
            topic: the topic to be handled by the func
            message_handle_func: the func to handle the message. Must follow aux_message_handle_func_signature.
        """
        pass

    @abstractmethod
    def send_aux_request(
        self,
        targets: Optional[List],
        topic: str,
        request: Shareable,
        timeout: float,
        fl_ctx: FLContext,
        optional: bool = False,
        secure: bool = False,
    ) -> dict:
        """Send a request to specified clients via the aux channel.

        Implementation: simply calls the AuxRunner's send_aux_request method.

        Args:
            targets: target clients. None or empty list means all clients.
            topic: topic of the request.
            request: request to be sent
            timeout: number of secs to wait for replies. 0 means fire-and-forget.
            fl_ctx: FL context
            optional: whether this message is optional
            secure: send the aux request in a secure way

        Returns:
            a dict of replies (client name => reply Shareable)
        """
        pass

    @abstractmethod
    def multicast_aux_requests(
        self,
        topic: str,
        target_requests: Dict[str, Shareable],
        timeout: float,
        fl_ctx: FLContext,
        optional: bool = False,
        secure: bool = False,
    ) -> dict:
        """Send requests to specified clients via the aux channel.

        Implementation: simply calls the AuxRunner's multicast_aux_requests method.

        Args:
            topic: topic of the request
            target_requests: requests of the target clients. Different target can have different request.
            timeout: amount of time to wait for responses. 0 means fire and forget.
            fl_ctx: FL context
            optional: whether this request is optional
            secure: whether to send the aux request in P2P secure

        Returns:
            a dict of replies (client name => reply Shareable)
        """
        pass

    def fire_and_forget_aux_request(
        self,
        targets: Optional[List],
        topic: str,
        request: Shareable,
        fl_ctx: FLContext,
        optional: bool = False,
        secure: bool = False,
    ) -> dict:
        """Send an aux request without waiting for replies.

        Delegates to ``send_aux_request`` with a timeout of 0.0, which that
        method documents as fire-and-forget.

        Args:
            targets: target clients. None or empty list means all clients.
            topic: topic of the request.
            request: request to be sent
            fl_ctx: FL context
            optional: whether this message is optional
            secure: send the aux request in a secure way

        Returns:
            whatever ``send_aux_request`` returns for a zero timeout
        """
        return self.send_aux_request(targets, topic, request, 0.0, fl_ctx, optional, secure=secure)

    @abstractmethod
    def persist_components(self, fl_ctx: FLContext, completed: bool):
        """Persist the FL running components.

        Args:
            fl_ctx: FLContext
            completed: flag to indicate whether the run is complete
        """
        pass

    @abstractmethod
    def restore_components(self, snapshot: RunSnapshot, fl_ctx: FLContext):
        """Restore the FL components from the saved snapshot.

        Args:
            snapshot: RunSnapshot
            fl_ctx: FLContext
        """
        pass

    @abstractmethod
    def start_client_job(self, job_id, client_sites, fl_ctx: FLContext):
        """Send the start client run commands to the clients.

        Args:
            job_id: job_id
            client_sites: client sites
            fl_ctx: FLContext
        """
        pass

    @abstractmethod
    def check_client_resources(
        self, job: Job, resource_reqs: Dict[str, dict], fl_ctx: FLContext
    ) -> Dict[str, Tuple[bool, Optional[str]]]:
        """Sends the check_client_resources requests to the clients.

        Args:
            job: job object
            resource_reqs: A dict of {client_name: resource requirements dict}
            fl_ctx: FLContext

        Returns:
            A dict of {client_name: client_check_result}.
            client_check_result is a tuple of (is_resource_enough, token);
            is_resource_enough is a bool indicates whether there is enough resources;
            token is for resource reservation / cancellation for this check request.
        """
        pass

    @abstractmethod
    def cancel_client_resources(
        self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict], fl_ctx: FLContext
    ):
        """Cancels the request resources for the job.

        Args:
            resource_check_results: A dict of {client_name: client_check_result}
                where client_check_result is a tuple of (is_resource_enough, resource reserve token if any)
            resource_reqs: A dict of {client_name: resource requirements dict}
            fl_ctx: FLContext
        """
        pass

    @abstractmethod
    def get_client_name_from_token(self, token: str) -> str:
        """Gets the client name from client login token.

        Args:
            token: client login token

        Returns:
            Client name
        """
        pass

    def register_app_command(self, topic: str, cmd_func, *args, **kwargs):
        """Register an app command function for a topic.

        This base implementation is a no-op; it is not abstract, so
        subclasses may override it but are not required to.

        Args:
            topic: topic of the command
            cmd_func: the function to register for the topic
            *args: extra positional arguments (unused here)
            **kwargs: extra keyword arguments (unused here)
        """
        pass