Source code for nvflare.private.fed.server.cmd_utils

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from nvflare.apis.job_def import JobMetaKey
from nvflare.apis.server_engine_spec import ServerEngineSpec
from nvflare.fuel.hci.conn import Connection
from nvflare.fuel.hci.proto import MetaKey, MetaStatusValue, make_meta
from nvflare.fuel.hci.server.authz import PreAuthzReturnCode
from nvflare.fuel.hci.server.constants import ConnProps
from nvflare.private.fed.server.admin import FedAdminServer


[docs]class CommandUtil(object): TARGET_CLIENTS = "target_clients" TARGET_CLIENT_TOKENS = "target_client_tokens" TARGET_CLIENT_NAMES = "target_client_names" TARGET_TYPE = "target_type" TARGET_TYPE_CLIENT = "client" TARGET_TYPE_SERVER = "server" TARGET_TYPE_ALL = "all" JOB_ID = "job_id" JOB = "job"
[docs] def command_authz_required(self, conn: Connection, args: List[str]) -> PreAuthzReturnCode: return PreAuthzReturnCode.REQUIRE_AUTHZ
[docs] def authorize_client_operation(self, conn: Connection, args: List[str]) -> PreAuthzReturnCode: auth_args = [args[0], self.TARGET_TYPE_CLIENT] auth_args.extend(args[1:]) err = self.validate_command_targets(conn, auth_args[1:]) if err: conn.append_error(err, meta=make_meta(MetaStatusValue.INVALID_TARGET, info=err)) return PreAuthzReturnCode.ERROR return PreAuthzReturnCode.REQUIRE_AUTHZ
[docs] def validate_command_targets(self, conn: Connection, args: List[str]) -> str: """Validate specified args and determine and set target type and target names in the Connection. The args must be like this: target_type client_names ... where target_type is one of 'all', 'client', 'server' Args: conn: A Connection object. args: Specified arguments. Returns: An error message. It is empty "" if no error found. """ # return target type and a list of target names if len(args) < 1: return "missing target type (server or client)" target_type = args[0] conn.set_prop(self.TARGET_TYPE, target_type) if target_type == self.TARGET_TYPE_SERVER: return "" if target_type == self.TARGET_TYPE_CLIENT: client_names = args[1:] elif target_type == self.TARGET_TYPE_ALL: client_names = [] else: return "unknown target type {}".format(target_type) engine = conn.app_ctx if not isinstance(engine, ServerEngineSpec): raise TypeError("engine must be ServerEngineSpec but got {}".format(type(engine))) if len(client_names) == 0: # get all clients clients = engine.get_clients() else: clients, invalid_inputs = engine.validate_targets(client_names) if invalid_inputs: return "invalid client(s): {}".format(" ".join(invalid_inputs)) if target_type == self.TARGET_TYPE_CLIENT and not clients: return "no clients available" valid_tokens = [] client_names = [] all_clients = {} for c in clients: valid_tokens.append(c.token) client_names.append(c.name) all_clients[c.token] = c.name conn.set_prop(self.TARGET_CLIENT_TOKENS, valid_tokens) # if clients: # client_names = [c.name for c in clients] # else: # client_names = [] conn.set_prop(self.TARGET_CLIENT_NAMES, client_names) conn.set_prop(self.TARGET_CLIENTS, all_clients) return ""
[docs] def must_be_project_admin(self, conn: Connection, args: List[str]): role = conn.get_prop(ConnProps.USER_ROLE, "") if role not in ["project_admin"]: conn.append_error(f"Not authorized for {role}", meta=make_meta(MetaStatusValue.NOT_AUTHORIZED)) return PreAuthzReturnCode.ERROR else: return PreAuthzReturnCode.OK
[docs] def authorize_server_operation(self, conn: Connection, args: List[str]): err = self.validate_command_targets(conn, args[1:]) if err: conn.append_error(err, meta=make_meta(MetaStatusValue.INVALID_TARGET, info=err)) return PreAuthzReturnCode.ERROR target_type = conn.get_prop(self.TARGET_TYPE) if target_type == self.TARGET_TYPE_SERVER or target_type == self.TARGET_TYPE_ALL: return PreAuthzReturnCode.REQUIRE_AUTHZ else: return PreAuthzReturnCode.OK
[docs] def send_request_to_clients(self, conn, message): client_tokens = conn.get_prop(self.TARGET_CLIENT_TOKENS) if not client_tokens: return None requests = {} for token in client_tokens: requests.update({token: message}) admin_server: FedAdminServer = conn.server cmd_timeout = conn.get_prop(ConnProps.CMD_TIMEOUT) if not cmd_timeout: cmd_timeout = admin_server.timeout with admin_server.sai.new_context() as fl_ctx: replies = admin_server.send_requests(requests, fl_ctx, timeout_secs=cmd_timeout) return replies
[docs] @staticmethod def get_job_name(meta: dict) -> str: """Gets job name from job meta.""" name = meta.get(JobMetaKey.JOB_NAME) if not name: name = meta.get(JobMetaKey.JOB_FOLDER_NAME, "No name") return name
[docs] def process_replies_to_table(self, conn: Connection, replies): """Process the clients' replies and put in a table format. Args: conn: A Connection object. replies: replies from clients """ if not replies: conn.append_string("no responses from clients") table = conn.append_table(["Client", "Response"]) for r in replies: if r.reply: resp = r.reply.body else: resp = "" client_name = r.client_name if not client_name: clients = conn.get_prop(self.TARGET_CLIENTS) client_name = clients.get(r.client_token, "") table.add_row([client_name, resp])
def _process_replies_to_string(self, conn: Connection, replies) -> str: """Process the clients replies and put in a string format. Args: conn: A Connection object. replies: replies from clients Returns: A string response. """ response = "no responses from clients" client_replies = {} if replies: response = "" for r in replies: client_name = r.client_name response += "client:" + client_name if r.reply: response += " : " + r.reply.body + "\n" client_replies[client_name] = r.reply.body else: response += " : No replies\n" client_replies[client_name] = MetaStatusValue.NO_REPLY conn.update_meta({MetaKey.CLIENT_STATUS: client_replies}) return response