Source code for nvflare.private.fed.server.info_coll_cmd

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import List

from nvflare.apis.fl_constant import AdminCommandNames, WorkspaceConstants
from nvflare.fuel.hci.conn import Connection
from nvflare.fuel.hci.reg import CommandModule, CommandModuleSpec, CommandSpec
from nvflare.private.defs import InfoCollectorTopic, RequestHeader
from nvflare.private.fed.server.admin import new_message
from nvflare.private.fed.server.server_engine_internal_spec import ServerEngineInternalSpec
from nvflare.widgets.info_collector import InfoCollector
from nvflare.widgets.widget import WidgetID

from .cmd_utils import CommandUtil


class InfoCollectorCommandModule(CommandModule, CommandUtil):
    """Admin command module for server-side info collector commands.

    NOTE: we only support Server side info collector commands for now,
    due to the complexity of client-side process/child-process architecture.

    The module registers three commands (show_stats, show_errors, reset_errors),
    all of which are pre-authorized by ``authorize_info_collection``.
    """

    # Connection property under which the resolved InfoCollector widget is stored.
    CONN_KEY_COLLECTOR = "collector"

    def get_spec(self):
        """Return the command module spec describing the "info" commands."""
        return CommandModuleSpec(
            name="info",
            cmd_specs=[
                CommandSpec(
                    name=AdminCommandNames.SHOW_STATS,
                    description="show current system stats for an actively running job",
                    usage="show_stats job_id server|client",
                    handler_func=self.show_stats,
                    authz_func=self.authorize_info_collection,
                    visible=True,
                ),
                CommandSpec(
                    name=AdminCommandNames.SHOW_ERRORS,
                    description="show latest errors in an actively running job",
                    usage="show_errors job_id server|client",
                    handler_func=self.show_errors,
                    authz_func=self.authorize_info_collection,
                    visible=True,
                ),
                CommandSpec(
                    name=AdminCommandNames.RESET_ERRORS,
                    description="reset errors",
                    # authorize_info_collection rejects anything but exactly
                    # three args, so the usage must advertise job_id and target.
                    usage="reset_errors job_id server|client",
                    handler_func=self.reset_errors,
                    authz_func=self.authorize_info_collection,
                    visible=True,
                ),
            ],
        )

    @staticmethod
    def _require_server_engine(conn: Connection):
        """Return ``conn.app_ctx`` after checking it is a ServerEngineInternalSpec.

        Raises:
            TypeError: if the connection's app context is not a
                ServerEngineInternalSpec.
        """
        engine = conn.app_ctx
        if not isinstance(engine, ServerEngineInternalSpec):
            raise TypeError(f"engine must be ServerEngineInternalSpec but got {type(engine)}")
        return engine

    def authorize_info_collection(self, conn: Connection, args: List[str]):
        """Validate the command args and prepare the connection for the handlers.

        On success the job id is stored on the connection under ``self.JOB_ID``
        and the info collector widget under ``CONN_KEY_COLLECTOR``.

        Args:
            conn: the admin connection; errors are appended to it.
            args: the command line split into ``[command, run_destination, target]``.

        Returns:
            ``(False, None)`` when validation fails; otherwise whatever
            ``self.authorize_view`` returns for ``[command, target]``.

        Raises:
            TypeError: if the connection's app context is not a
                ServerEngineInternalSpec.
        """
        if len(args) != 3:
            conn.append_error("syntax error: missing job_id and target")
            return False, None

        prefix = WorkspaceConstants.WORKSPACE_PREFIX
        run_destination = args[1].lower()
        if not run_destination.startswith(prefix):
            conn.append_error("syntax error: run_destination must be run_XXX")
            return False, None

        # The job id is whatever follows the workspace prefix.
        job_id = run_destination[len(prefix) :]
        conn.set_prop(self.JOB_ID, job_id)

        engine = self._require_server_engine(conn)

        collector = engine.get_widget(WidgetID.INFO_COLLECTOR)
        if not collector:
            conn.append_error("info collector not available")
            return False, None

        if not isinstance(collector, InfoCollector):
            conn.append_error("system error: info collector not right object")
            return False, None

        conn.set_prop(self.CONN_KEY_COLLECTOR, collector)

        if not engine.get_app_run_info(job_id):
            conn.append_string(
                "Cannot find job: {}. Please make sure the first arg following the command is a valid job_id.".format(
                    job_id
                )
            )
            return False, None

        # authorize_view is given the command name followed by the target,
        # i.e. the original args with the run destination removed.
        auth_args = [args[0], *args[2:]]
        return self.authorize_view(conn, auth_args)

    def show_stats(self, conn: Connection, args: List[str]):
        """Append the stats of the job to the connection.

        For the server target the stats come from ``engine.show_stats``; for the
        client target a SHOW_STATS request carrying the job id is sent to the
        clients and their replies are appended.

        Args:
            conn: the admin connection (already processed by
                ``authorize_info_collection``).
            args: ``[command, run_destination, target]``.

        Raises:
            TypeError: if the connection's app context is not a
                ServerEngineInternalSpec.
        """
        engine = self._require_server_engine(conn)
        job_id = conn.get_prop(self.JOB_ID)
        target_type = args[2]

        if target_type == self.TARGET_TYPE_SERVER:
            conn.append_any(engine.show_stats(job_id))
        elif target_type == self.TARGET_TYPE_CLIENT:
            message = new_message(conn, topic=InfoCollectorTopic.SHOW_STATS, body="")
            message.set_header(RequestHeader.JOB_ID, job_id)
            replies = self.send_request_to_clients(conn, message)
            self._process_stats_replies(conn, replies)
        else:
            # Do not finish silently when the target is neither server nor client.
            conn.append_error(f"unsupported target type: {target_type}")

    def show_errors(self, conn: Connection, args: List[str]):
        """Append the errors of the job to the connection.

        For the server target the errors come from ``engine.get_errors``; for
        the client target a SHOW_ERRORS request carrying the job id is sent to
        the clients and their replies are appended.

        Args:
            conn: the admin connection (already processed by
                ``authorize_info_collection``).
            args: ``[command, run_destination, target]``.

        Raises:
            TypeError: if the connection's app context is not a
                ServerEngineInternalSpec.
        """
        engine = self._require_server_engine(conn)
        job_id = conn.get_prop(self.JOB_ID)
        target_type = args[2]

        if target_type == self.TARGET_TYPE_SERVER:
            conn.append_any(engine.get_errors(job_id))
        elif target_type == self.TARGET_TYPE_CLIENT:
            message = new_message(conn, topic=InfoCollectorTopic.SHOW_ERRORS, body="")
            # Same as show_stats: the request must say which job it is about.
            message.set_header(RequestHeader.JOB_ID, job_id)
            replies = self.send_request_to_clients(conn, message)
            self._process_stats_replies(conn, replies)
        else:
            # Do not finish silently when the target is neither server nor client.
            conn.append_error(f"unsupported target type: {target_type}")

    def reset_errors(self, conn: Connection, args: List[str]):
        """Reset the errors held by the info collector stored on the connection.

        NOTE(review): the target in ``args`` is not consulted here; only the
        collector placed on the connection by ``authorize_info_collection`` is
        reset - confirm this is the intended scope.

        Args:
            conn: the admin connection (already processed by
                ``authorize_info_collection``).
            args: the command args (unused).
        """
        collector = conn.get_prop(self.CONN_KEY_COLLECTOR)
        collector.reset_errors()
        conn.append_string("errors reset")

    def _process_stats_replies(self, conn, replies):
        """Append each client's JSON reply body to the connection.

        Args:
            conn: the admin connection to write to.
            replies: the client replies; each is expected to expose
                ``client_token`` and ``reply.body`` (a JSON string).
        """
        if not replies:
            conn.append_error("no responses from clients")
            return

        engine = conn.app_ctx
        for r in replies:
            client_name = engine.get_client_name_from_token(r.client_token)
            conn.append_string(f"--- Client ---: {client_name}")
            try:
                body = json.loads(r.reply.body)
                conn.append_any(body)
            except Exception:
                # Best effort per client: a missing or malformed reply must not
                # abort the remaining clients. Exception (not BaseException) so
                # KeyboardInterrupt/SystemExit still propagate.
                conn.append_string("Bad responses from clients")