# Source code for nvflare.app_opt.tracking.wandb.wandb_receiver

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
from multiprocessing import Process, Queue
from typing import NamedTuple, Optional

import wandb

from nvflare.apis.analytix import AnalyticsData, AnalyticsDataType
from nvflare.apis.dxo import from_shareable
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.tracking.tracker_types import LogWriterName
from nvflare.app_common.widgets.streaming import AnalyticsReceiver


class WandBTask(NamedTuple):
    """A unit of work sent from the receiver to a per-site wandb worker process.

    Fields:
        task_owner: name of the site (record origin) the task belongs to.
        task_type: one of "init", "log" or "stop".
        task_data: ``wandb.init`` kwargs for "init", the metrics dict for "log",
            and an empty dict for "stop".
        step: training step for "log" tasks. May be None when the analytics
            record carries no step (hence Optional); "init"/"stop" tasks use 0.
    """

    task_owner: str
    task_type: str
    task_data: dict
    step: Optional[int]
class WandBReceiver(AnalyticsReceiver):
    """Analytics receiver that forwards per-site analytics records to Weights & Biases.

    ``initialize`` starts one child process (fed by its own queue) per client site,
    and each process owns a separate wandb run. ``save`` routes incoming analytics
    records to the originating site's queue, and ``finalize`` asks every process to
    stop and waits for it.
    """

    def __init__(self, kwargs: dict, mode: str = "offline", events=None, process_timeout=10):
        """Create the receiver.

        Args:
            kwargs: keyword arguments used as the base for ``wandb.init``. ``name`` is
                read in ``initialize``; ``project`` and ``job_type`` are required by
                ``check_kwargs``. ``name``, ``group``, ``mode`` and ``config`` are
                derived per site in ``initialize``.
            mode: wandb mode (e.g. "offline" or "online"), exported as ``WANDB_MODE``.
            events: event types handled by this receiver; defaults to
                ``["fed.analytix_log_stats"]``.
            process_timeout: seconds ``finalize`` waits for each site process before
                terminating it.
        """
        if events is None:
            events = ["fed.analytix_log_stats"]
        super().__init__(events=events)
        self.fl_ctx = None
        self.mode = mode
        self.kwargs = kwargs
        self.queues = {}
        self.processes = {}
        self.process_timeout = process_timeout
        # The API key is not set here; presumably it is provided via the
        # WANDB_API_KEY environment variable (or a prior `wandb login`) - confirm.
        os.environ["WANDB_MODE"] = self.mode

    def job(self, queue):
        """Worker loop run in a child process: consume WandBTask items from ``queue``.

        "init" starts a wandb run using the task data as ``wandb.init`` kwargs,
        "log" forwards the task data to ``wandb.log``, and "stop" ends the loop.
        Other task types are ignored. The run (if one was started) is finished on
        exit, including when an exception escapes the loop.
        """
        cnt = 0
        run = None
        try:
            while True:
                wandb_task: WandBTask = queue.get()
                cnt += 1
                if wandb_task.task_type == "stop":
                    self.log_info(self.fl_ctx, f"received request to stop at {wandb_task.task_owner} for run {run}")
                    break
                elif wandb_task.task_type == "init":
                    self.log_info(self.fl_ctx, f"received request to init at {wandb_task.task_owner}")
                    run = wandb.init(**wandb_task.task_data)
                elif wandb_task.task_type == "log":
                    # Throttle progress logging to one line every 500 tasks.
                    if cnt % 500 == 0:
                        self.log_info(self.fl_ctx, f"process task : {wandb_task}, cnt = {cnt}")
                    # Compare with None rather than truthiness: step 0 is a real step.
                    # Dropping it lets wandb auto-increment its own step counter, which
                    # mislabels the record and can make later explicit steps look
                    # non-monotonic to wandb.
                    if wandb_task.step is not None:
                        wandb.log(wandb_task.task_data, step=wandb_task.step)
                    else:
                        wandb.log(wandb_task.task_data)
        finally:
            if run:
                run.finish()

    def initialize(self, fl_ctx: FLContext):
        """Start one wandb worker process (and queue) per client site.

        Each site gets its own copy of the wandb.init kwargs. The copy sets the
        run name to ``<site>-<job tag[:6]>-<name>``, the group to
        ``<name>-<job tag>`` and the current mode. Its ``config`` is the user's
        config (if any) extended with ``job_id``, ``client`` and ``run_name``.
        If "online" login fails, the receiver falls back to "offline".

        Raises:
            KeyError: if ``name`` is missing from the kwargs.
            ValueError: from ``check_kwargs`` if a required key is missing.
        """
        self.fl_ctx = fl_ctx
        sites = fl_ctx.get_engine().get_clients()
        run_group_id = str(int(time.time()))
        run_name = self.kwargs["name"]
        job_id_tag = self.get_job_id_tag(run_group_id)
        # Copy so the caller's config dict is never mutated. The original code used
        # .get("config", {}), which silently dropped the metadata when no config was given.
        base_config = dict(self.kwargs.get("config") or {})

        if self.mode == "online":
            try:
                wandb.login(timeout=1, verify=True)
            except Exception as e:
                self.log_error(self.fl_ctx, f"Unsuccessful login: {e}. Using wandb offline mode.")
                self.mode = "offline"
                # Keep the environment consistent with the mode passed to wandb.init.
                os.environ["WANDB_MODE"] = self.mode

        for site in sites:
            self.log_info(self.fl_ctx, f"initialize WandB run for site {site.name}")
            # Build a fresh dict per site. multiprocessing.Queue pickles items on a
            # background feeder thread, so mutating a shared dict after put() could
            # leak the next site's values into this site's init payload.
            site_kwargs = dict(self.kwargs)
            site_kwargs["name"] = f"{site.name}-{job_id_tag[:6]}-{run_name}"
            site_kwargs["group"] = f"{run_name}-{job_id_tag}"
            site_kwargs["mode"] = self.mode
            site_kwargs["config"] = {
                **base_config,
                "job_id": job_id_tag,
                "client": site.name,
                "run_name": run_name,
            }
            self.check_kwargs(site_kwargs)

            q = Queue()
            q.put(WandBTask(task_owner=site.name, task_type="init", task_data=site_kwargs, step=0))
            self.queues[site.name] = q
            p = Process(target=self.job, args=(q,))
            self.processes[site.name] = p
            p.start()
            # Retained from the original; presumably staggers wandb start-up across
            # sites - TODO confirm whether it is still needed.
            time.sleep(0.2)

    def get_job_id_tag(self, group_id: str) -> str:
        """Return the tag used to identify this job in run names and groups.

        Uses the job ID from the stored FLContext. For the simulator, whose job ID is
        always "simulate_job", ``group_id`` (a time string) is used instead so runs
        from different simulations remain distinguishable.
        """
        job_id = self.fl_ctx.get_job_id()
        if job_id == "simulate_job":
            job_id = group_id
        return job_id

    def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin: str):
        """Forward an analytics record to the wandb worker of ``record_origin``.

        Single-value records (PARAMETER / METRIC) are logged as ``{tag: value}``.
        Multi-value records (PARAMETERS / METRICS) are logged as their value dict.
        Records that are not addressed to the wandb writer, come from an unknown
        origin, or have another data type are dropped.
        """
        dxo = from_shareable(shareable)
        data = AnalyticsData.from_dxo(dxo, receiver=LogWriterName.WANDB)
        if not data:
            return

        q: Optional[Queue] = self.get_job_queue(record_origin)
        if q is None:
            # No worker was started for this origin in initialize().
            return

        if data.data_type in (AnalyticsDataType.PARAMETER, AnalyticsDataType.METRIC):
            log_data = {data.tag: data.value}
        elif data.data_type in (AnalyticsDataType.PARAMETERS, AnalyticsDataType.METRICS):
            log_data = data.value
        else:
            return
        q.put(WandBTask(task_owner=record_origin, task_type="log", task_data=log_data, step=data.step))

    def finalize(self, fl_ctx: FLContext):
        """Called at EventType.END_RUN.

        Sends a "stop" task to every site worker. It then waits up to
        ``process_timeout`` seconds for each one and terminates any worker
        still alive after the wait.

        Args:
            fl_ctx (FLContext): the FLContext
        """
        for site in self.processes:
            self.log_info(fl_ctx, f"inform {site} to stop")
            q: Optional[Queue] = self.get_job_queue(site)
            if q is not None:
                q.put(WandBTask(task_owner=site, task_type="stop", task_data={}, step=0))

        for site, p in self.processes.items():
            p.join(self.process_timeout)
            if p.is_alive():
                self.log_warning(fl_ctx, f"wandb process for {site} did not stop within {self.process_timeout}s; terminating")
                p.terminate()

    def get_job_queue(self, record_origin):
        """Return the task queue for ``record_origin``, or None if no worker exists."""
        return self.queues.get(record_origin)

    def check_kwargs(self, kwargs):
        """Validate that the wandb.init kwargs contain the required keys.

        Raises:
            ValueError: if ``project``, ``group`` or ``job_type`` is missing.
        """
        for key in ("project", "group", "job_type"):
            if key not in kwargs:
                raise ValueError(f"must provide `{key}' value")