Source code for nvflare.metrics.job_metrics_collector

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from nvflare.apis.event_type import EventType
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.app_event_type import AppEventType
from nvflare.metrics.metrics_collector import MetricsCollector
from nvflare.metrics.metrics_publisher import collect_metrics


[docs] class JobMetricsCollector(MetricsCollector): def __init__(self, tags: dict, streaming_to_server: bool = False): """ Args: tags: comma separated static tags. used to specify server, client, production, test etc. streaming_to_server: boolean to specify if metrics should be streamed to server """ super().__init__(tags=tags, streaming_to_server=streaming_to_server) # Job events self.single_events = [ EventType.SUBMIT_JOB, EventType.DEPLOY_JOB_TO_SERVER, EventType.DEPLOY_JOB_TO_CLIENT, EventType.BEFORE_CHECK_RESOURCE_MANAGER, EventType.BEFORE_SEND_ADMIN_COMMAND, # application AppEventType.INITIAL_MODEL_LOADED, AppEventType.BEFORE_TRAIN_TASK, AppEventType.AFTER_AGGREGATION, ] self.pair_events = { EventType.START_WORKFLOW: "_workflow", EventType.END_WORKFLOW: "_workflow", EventType.START_RUN: "_run", EventType.END_RUN: "_run", EventType.JOB_STARTED: "_job", EventType.JOB_COMPLETED: "_job", EventType.JOB_ABORTED: "_job", EventType.JOB_CANCELLED: "_job", EventType.BEFORE_PULL_TASK: "_pull_task", EventType.AFTER_PULL_TASK: "_pull_task", EventType.BEFORE_PROCESS_TASK_REQUEST: "_process_task", EventType.AFTER_PROCESS_TASK_REQUEST: "_process_task", EventType.BEFORE_PROCESS_SUBMISSION: "_process_submission", EventType.AFTER_PROCESS_SUBMISSION: "_process_submission", EventType.BEFORE_TASK_DATA_FILTER: "_data_filter", EventType.AFTER_TASK_DATA_FILTER: "_data_filter", EventType.BEFORE_TASK_RESULT_FILTER: "_result_filter", EventType.AFTER_TASK_RESULT_FILTER: "_result_filter", EventType.BEFORE_TASK_EXECUTION: "_task_execution", EventType.AFTER_TASK_EXECUTION: "_task_execution", EventType.ABORT_TASK: "_task_execution", EventType.BEFORE_SEND_TASK_RESULT: "_send_task_result", EventType.AFTER_SEND_TASK_RESULT: "_send_task_result", EventType.BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK: "_process_result_of_unknown_task", EventType.AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK: "_process_result_of_unknown_task", # Application AppEventType.BEFORE_AGGREGATION: "_aggregation", AppEventType.END_AGGREGATION: "_aggregation", AppEventType.RECEIVE_BEST_MODEL: "_receive_best_model", AppEventType.BEFORE_TRAIN: "_train", AppEventType.AFTER_TRAIN: "_train", AppEventType.TRAIN_DONE: "_train_done", AppEventType.TRAINING_STARTED: "_training", AppEventType.TRAINING_FINISHED: "_training", AppEventType.ROUND_STARTED: "_round", AppEventType.ROUND_DONE: "_round", }
[docs] def handle_event(self, event: str, fl_ctx: FLContext): job_id = fl_ctx.get_job_id() tags: dict = self.tags tags["job_id"] = job_id super().collect_event_metrics(event=event, tags=tags, fl_ctx=fl_ctx)
[docs] def publish_metrics(self, metrics: dict, metric_name: str, tags: dict, fl_ctx: FLContext): collect_metrics(self, self.streaming_to_server, metrics, metric_name, tags, self.data_bus, fl_ctx)
[docs] def get_single_events(self): return self.single_events
[docs] def get_pair_events(self): return self.pair_events