Source code for nvflare.edge.simulation.devices.tp

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

from nvflare.edge.constants import EdgeApiStatus
from nvflare.edge.simulation.config import ConfigParser
from nvflare.edge.simulation.device_task_processor import DeviceTaskProcessor
from nvflare.edge.simulation.simulated_device import DeviceFactory, SimulatedDevice
from nvflare.edge.web.models.device_info import DeviceInfo
from nvflare.edge.web.models.job_response import JobResponse
from nvflare.edge.web.models.task_response import TaskResponse
from nvflare.edge.web.models.user_info import UserInfo
from nvflare.fuel.utils.log_utils import get_obj_logger


class TaskProcessingDevice(SimulatedDevice):
    """Simulated device that hands its work to a ``DeviceTaskProcessor``.

    Job setup, task execution and shutdown are all forwarded to the processor
    supplied at construction time; the device itself stores the endpoint URL,
    the device info and the user info.
    """

    def __init__(
        self,
        device_id: str,
        endpoint_url: str,
        device_info: DeviceInfo,
        user_info: UserInfo,
        processor: DeviceTaskProcessor,
    ):
        """Create the device and attach it to its processor.

        Args:
            device_id: identifier passed on to ``SimulatedDevice.__init__``.
            endpoint_url: endpoint URL stored on the device.
            device_info: value later returned by ``get_device_info``.
            user_info: value later returned by ``get_user_info``.
            processor: task processor that does the actual work; its
                ``device`` attribute is set to this device.
        """
        SimulatedDevice.__init__(self, device_id)

        self.processor = processor
        self.user_info = user_info
        self.device_info = device_info
        self.endpoint_url = endpoint_url

        # Give the processor a back-reference to this device.
        processor.device = self

    def get_device_info(self):
        """Return the ``DeviceInfo`` supplied at construction time."""
        return self.device_info

    def get_user_info(self):
        """Return the ``UserInfo`` supplied at construction time."""
        return self.user_info

    def set_job(
        self,
        job_id: str,
        job_name: str,
        method: str,
        job_data: dict,
    ):
        """Forward the job to the base class, then set up the processor for it.

        The processor receives the job wrapped in a ``JobResponse`` whose
        status is ``EdgeApiStatus.OK``.

        Args:
            job_id: ID of the job.
            job_name: name of the job.
            method: method string of the job.
            job_data: job data dictionary.
        """
        super().set_job(job_id, job_name, method, job_data)
        self.processor.setup(
            JobResponse(
                status=EdgeApiStatus.OK,
                job_id=job_id,
                job_name=job_name,
                method=method,
                job_data=job_data,
            )
        )

    def shutdown(self):
        """Shut down the underlying processor."""
        self.processor.shutdown()

    def do_task(self, task: TaskResponse) -> dict:
        """Process ``task`` with the processor and return its result."""
        outcome = self.processor.process_task(task)
        return outcome
class TPDeviceFactory(DeviceFactory):
    """Factory producing ``TaskProcessingDevice`` objects from a ``ConfigParser``.

    The parser supplies both the endpoint URL (read once at construction) and
    a processor for every device that is made.
    """

    def __init__(self, parser: ConfigParser):
        """Store the parser and read the endpoint URL from it.

        Args:
            parser: configuration parser providing ``get_endpoint`` and
                ``get_processor``.
        """
        self.logger = get_obj_logger(self)
        self.parser = parser
        self.endpoint_url = parser.get_endpoint()

    def make_device(self, device_id: str) -> SimulatedDevice:
        """Build a ``TaskProcessingDevice`` for ``device_id``.

        The processor is requested from the parser with the device ID and the
        user ID passed as variables.

        Args:
            device_id: ID of the device to create.

        Returns:
            The newly created device.
        """
        demo_user = UserInfo("demo_id", "demo_user")
        info = DeviceInfo(f"{device_id}", "flare_mobile", "1.0")

        task_processor = self.parser.get_processor(
            {"device_id": device_id, "user_id": demo_user.user_id},
        )

        device = TaskProcessingDevice(
            device_id=device_id,
            endpoint_url=self.endpoint_url,
            device_info=info,
            user_info=demo_user,
            processor=task_processor,
        )
        return device
class TPODeviceFactory(DeviceFactory):
    """Factory producing ``TaskProcessingDevice`` objects from a processor object.

    Every device gets its own deep copy of the ``DeviceTaskProcessor`` that was
    given to the factory.
    """

    def __init__(self, tpo: DeviceTaskProcessor):
        """Store the processor object that is copied for each device.

        Args:
            tpo: the task processor object to copy from.
        """
        self.logger = get_obj_logger(self)
        self.tpo = tpo

    def make_device(self, device_id: str) -> SimulatedDevice:
        """Build a ``TaskProcessingDevice`` for ``device_id``.

        Args:
            device_id: ID of the device to create.

        Returns:
            The newly created device, owning a deep copy of the factory's
            processor object.
        """
        processor_copy = copy.deepcopy(self.tpo)
        demo_user = UserInfo("demo_id", "demo_user")
        info = DeviceInfo(f"{device_id}", "flare_mobile", "1.0")

        device = TaskProcessingDevice(
            device_id=device_id,
            endpoint_url="",  # not used
            device_info=info,
            user_info=demo_user,
            processor=processor_copy,
        )
        return device