# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from nvflare.fuel.f3.cellnet.fqcn import FQCN
# Channel name shared by the agent topics below.
# NOTE(review): presumably the messaging channel between the agent and the FLARE site - confirm against callers.
CHANNEL = "flare_agent"
# Topic names, one per kind of agent message.
TOPIC_GET_TASK = "get_task"
TOPIC_SUBMIT_RESULT = "submit_result"
TOPIC_HEARTBEAT = "heartbeat"
TOPIC_HELLO = "hello"
TOPIC_BYE = "bye"
TOPIC_ABORT = "abort"
# Key named for the agent id entry of a job's meta; where it is read or written is not visible in this module.
JOB_META_KEY_AGENT_ID = "agent_id"
class RC:
    """String return codes accepted as the ``return_code`` of a ``TaskResult``."""

    # Default return code of a TaskResult.
    OK = "OK"
    # Named for the case where the received task data cannot be used.
    BAD_TASK_DATA = "BAD_TASK_DATA"
    # Named for the case where processing the task raised an exception.
    EXECUTION_EXCEPTION = "EXECUTION_EXCEPTION"
class PayloadKey:
    """String keys named for the two parts of a payload: data and meta.

    How the keys are used is not visible in this module.
    """

    DATA = "data"
    META = "meta"
class Task:
    """A task handed to the application, as returned by ``FlareAgent.get_task``."""

    def __init__(self, task_name: str, task_id: str, meta: dict, data) -> None:
        """Store the task's identity and content without validation.

        Args:
            task_name: name of the task
            task_id: id of the task
            meta: meta info of the task
            data: data of the task
        """
        self.task_name = task_name
        self.task_id = task_id
        self.meta = meta
        self.data = data

    def __str__(self) -> str:
        """Return the task name and id, space-separated and wrapped in single quotes."""
        return f"'{self.task_name} {self.task_id}'"
class TaskResult:
    """Result of a processed task, to be passed to ``FlareAgent.submit_result``.

    Instances expose ``return_code``, ``meta`` and ``data`` attributes.
    """

    def __init__(self, meta: dict, data, return_code=RC.OK):
        """Validate and store the parts of a task result.

        Args:
            meta: meta info of the result. Any falsy value (e.g. None) is replaced by an empty dict.
            data: data of the result. Any falsy value (None, 0, "", an empty container)
                is replaced by an empty dict.
            return_code: return code of the result; defaults to RC.OK.

        Raises:
            TypeError: if meta is truthy but not a dict, or if return_code is not a str.
        """
        # Normalization happens before the type check, so falsy non-dict values are accepted.
        meta = meta or {}
        if not isinstance(meta, dict):
            raise TypeError(f"meta must be dict but got {type(meta)}")

        if not isinstance(return_code, str):
            raise TypeError(f"return_code must be str but got {type(return_code)}")

        self.return_code = return_code
        self.meta = meta
        self.data = data or {}
class AgentClosed(Exception):
    """Raised by ``FlareAgent.get_task`` when the agent is closed before the timeout."""
class CallStateError(Exception):
    """Raised when a ``FlareAgent`` call (get_task or submit_result) is not made properly."""
class FlareAgent(ABC):
    """Abstract interface of a FLARE agent: start/stop it, get tasks and submit their results."""

    @abstractmethod
    def start(self):
        """Start the agent. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def stop(self):
        """Stop the agent. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def get_task(self, timeout=None):
        """Get a task from FLARE. This is a blocking call.

        If timeout is specified, this call is blocked only for the specified amount of time.
        If timeout is not specified, this call is blocked forever until a task is received or agent is closed.

        Args:
            timeout: amount of time to block

        Returns: None if no task is available before timeout; or a Task object if task is available.

        Raises:
            AgentClosed exception if the agent is closed before timeout.
            CallStateError exception if the call is not made properly.

        Note: the application must make the call only when it is just started or after a previous task's result
        has been submitted.

        """
        pass

    def submit_result(self, result: TaskResult) -> bool:
        """Submit the result of the current task.

        This is a blocking call. The agent will try to send the result to flare site until it is successfully sent or
        the task is aborted or the agent is closed.

        Args:
            result: result to be submitted

        Returns: whether the result is submitted successfully

        Raises: the CallStateError exception if the submit_result call is not made properly.

        Notes: the application must only make this call after the received task is processed. The call can only be
        made a single time regardless whether the submission is successful.

        """
        # NOTE(review): unlike start/stop/get_task this method is not marked @abstractmethod, and this
        # default body returns None rather than a bool - confirm whether that is intended.
        pass
def agent_site_fqcn(site_name: str, agent_id: str, job_id=None):
    """Build the name of an agent's site from the site name, agent id and optional job id.

    Args:
        site_name: name of the site
        agent_id: id of the agent
        job_id: id of the job; any falsy value (None or "") means "no job"

    Returns: "<site_name>--<agent_id>" when there is no job id; otherwise the result of
    FQCN.join([site_name, job_id, agent_id]).
    """
    if not job_id:
        return f"{site_name}--{agent_id}"
    return FQCN.join([site_name, job_id, agent_id])