Source code for nvflare.fuel.hci.client.api

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import os
import shutil
import threading
import time
import traceback
from datetime import datetime
from pathlib import Path
from typing import List, Optional

import nvflare.fuel.f3.streaming.file_downloader as downloader
from nvflare.apis.fl_constant import ConnectionSecurity, FLContextKey, ProcessType, ReservedKey, ReturnCode
from nvflare.apis.fl_context import FLContext, FLContextManager
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamableEngine, StreamContext
from nvflare.apis.utils.decomposers import flare_decomposers
from nvflare.app_common.streamers.file_streamer import FileStreamer
from nvflare.fuel.common.excepts import ConfigError
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.defs import CellChannel, MessageHeaderKey
from nvflare.fuel.f3.cellnet.defs import ReturnCode as CellReturnCode
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.f3.cellnet.net_agent import NetAgent
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.message import Message as CellMessage
from nvflare.fuel.hci.client.event import EventContext, EventHandler, EventPropKey, EventType
from nvflare.fuel.hci.cmd_arg_utils import split_to_args
from nvflare.fuel.hci.conn import Connection
from nvflare.fuel.hci.proto import (
    ConfirmMethod,
    InternalCommands,
    MetaKey,
    ProtoKey,
    ReplyKeyword,
    StreamChannel,
    StreamTopic,
    make_error,
    validate_proto,
)
from nvflare.fuel.hci.reg import CommandEntry, CommandModule, CommandRegister
from nvflare.fuel.hci.table import Table
from nvflare.fuel.sec.authn import set_add_auth_headers_filters
from nvflare.fuel.utils.admin_name_utils import new_admin_client_name
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.private.aux_runner import AuxMsgTarget, AuxRunner
from nvflare.private.defs import ClientType
from nvflare.private.fed.authenticator import Authenticator, validate_auth_headers
from nvflare.private.fed.utils.identity_utils import IdentityAsserter, TokenVerifier, get_cn_from_cert, load_cert_file
from nvflare.private.stream_runner import HeaderKey, ObjectStreamer
from nvflare.security.logging import secure_format_exception, secure_log_traceback

from .api_spec import (
    AdminAPISpec,
    AdminConfigKey,
    CommandContext,
    CommandCtxKey,
    CommandInfo,
    ReplyProcessor,
    UidSource,
)
from .api_status import APIStatus

_CMD_TYPE_UNKNOWN = 0
_CMD_TYPE_CLIENT = 1
_CMD_TYPE_SERVER = 2

MAX_AUTO_LOGIN_TRIES = 300
AUTO_LOGIN_INTERVAL = 1.5


class FileWaiter(threading.Event):
    """A threading.Event tagged with a transfer id and an optional stream context.

    Attributes:
        tx_id: identifier of the transaction this waiter belongs to.
        stream_ctx: stream context attached to the waiter; None until assigned.
        last_progress_time: timestamp taken with time.time() when the waiter is
            created (presumably refreshed by the download flow - confirm at call sites).
    """

    def __init__(self, tx_id):
        super().__init__()
        # the id and an empty stream context are assigned together
        self.tx_id, self.stream_ctx = tx_id, None
        self.last_progress_time = time.time()

    def get_stream_ctx(self):
        """Return whatever stream context is currently attached (None when unset)."""
        current_ctx = self.stream_ctx
        return current_ctx
class ResultKey:
    """Keys of the result dicts returned by AdminAPI command methods.

    Each key is an alias of the ProtoKey constant with the same name.
    """

    META = ProtoKey.META
    DETAILS = ProtoKey.DETAILS
    STATUS = ProtoKey.STATUS
class _ServerReplyJsonProcessor(object):
    """Feeds the items of a decoded server reply to the command's ReplyProcessor."""

    def __init__(self, ctx: CommandContext):
        if not isinstance(ctx, CommandContext):
            raise TypeError(f"ctx is not an instance of CommandContext. but get {type(ctx)}")
        self.ctx = ctx

    def process_server_reply(self, resp):
        """Process the server reply and store the status/details into the command context.

        NOTE: this method receives only the reply; the CommandContext it works on is the
        one bound at construction time. AdminAPI uses it as the reply callback when sending
        commands.

        Args:
            resp: the decoded reply dict from the server, or None. Items are expected
                under ProtoKey.DATA, each with a ProtoKey.TYPE.
        """
        api = self.ctx.get_api()
        api.debug("Server Reply: {}".format(resp))

        ctx = self.ctx

        # this resp is what is usually directly used to return, straight from server
        ctx.set_command_result(resp)

        reply_processor = ctx.get_reply_processor()
        if reply_processor is None:
            reply_processor = _DefaultReplyProcessor()

        reply_processor.reply_start(ctx, resp)

        if resp is not None:
            data = resp[ProtoKey.DATA]
            for item in data:
                it = item[ProtoKey.TYPE]
                if it == ProtoKey.STRING:
                    reply_processor.process_string(ctx, item[ProtoKey.DATA])
                elif it == ProtoKey.SUCCESS:
                    reply_processor.process_success(ctx, item[ProtoKey.DATA])
                elif it == ProtoKey.ERROR:
                    reply_processor.process_error(ctx, item[ProtoKey.DATA])
                    # an error item terminates processing of the reply
                    break
                elif it == ProtoKey.TABLE:
                    table = Table(None)
                    table.set_rows(item[ProtoKey.ROWS])
                    reply_processor.process_table(ctx, table)
                elif it == ProtoKey.DICT:
                    reply_processor.process_dict(ctx, item[ProtoKey.DATA])
                elif it == ProtoKey.TOKEN:
                    reply_processor.process_token(ctx, item[ProtoKey.DATA])
                elif it == ProtoKey.SHUTDOWN:
                    reply_processor.process_shutdown(ctx, item[ProtoKey.DATA])
                    break
                else:
                    # f-string rather than "+": a malformed reply may carry a non-str type,
                    # and concatenation would raise TypeError instead of reporting it
                    reply_processor.protocol_error(ctx, f"Invalid item type: {it}")
                    break

            meta = resp.get(ProtoKey.META)
            if meta:
                ctx.set_meta(meta)
        else:
            reply_processor.protocol_error(ctx, "Protocol Error")

        reply_processor.reply_done(ctx)


class _DefaultReplyProcessor(ReplyProcessor):
    """Fallback reply processor; records shutdown notifications on the API object."""

    def process_shutdown(self, ctx: CommandContext, msg: str):
        api = ctx.get_prop(CommandCtxKey.API)
        api.shutdown_received = True
        api.shutdown_msg = msg


class _LoginReplyProcessor(ReplyProcessor):
    """Reply processor for handling login and setting the token for the admin client."""

    def process_string(self, ctx: CommandContext, item: str):
        api = ctx.get_api()
        api.login_result = item

    def process_token(self, ctx: CommandContext, token: str):
        api = ctx.get_api()
        api.token = token


class _CmdListReplyProcessor(ReplyProcessor):
    """Reply processor to register available commands after getting back a table of commands from the server."""

    def process_table(self, ctx: CommandContext, table: Table):
        """Register each command row of the server's command table.

        Row layout: scope, cmd_name, desc, usage, confirm[, client_cmd[, visible]].
        A row with fewer than 5 columns aborts processing without marking the command
        list as received.
        """
        api = ctx.get_api()
        for i, row in enumerate(table.rows):
            if i == 0:
                # this is header
                continue

            if len(row) < 5:
                return

            scope, cmd_name, desc, usage, confirm = row[0], row[1], row[2], row[3], row[4]
            client_cmd = row[5] if len(row) > 5 else None
            # str() so that a JSON boolean visibility value is accepted as well as a string
            visible = str(row[6]).lower() in ["true", "yes"] if len(row) > 6 else True

            api.server_cmd_reg.add_command(
                scope_name=scope,
                cmd_name=cmd_name,
                desc=desc,
                usage=usage,
                handler=None,
                authz_func=None,
                visible=visible,
                confirm=confirm,
                client_cmd=client_cmd,
                map_client_cmd=True,
            )

        api.server_cmd_received = True
class AdminAPI(AdminAPISpec, StreamableEngine):
    """Admin client API: connects to the FL server through a cell, logs in, and executes admin commands."""

    def __init__(
        self,
        user_name: str,
        admin_config: dict,
        cmd_modules: Optional[List] = None,
        debug: bool = False,
        auto_login_max_tries: int = 15,
        event_handlers=None,
    ):
        """API to keep certs, keys and connection information and to execute admin commands through do_command.

        Args:
            user_name: Username to authenticate with FL server. When the configured uid_source is
                UidSource.CERT, it is replaced by the CN of the client certificate.
            admin_config: admin configuration dict, read with AdminConfigKey keys (certs, host, port,
                connection scheme/security, timeouts, upload/download dirs, debug flag).
            cmd_modules: command modules to load and register. Note that FileTransferModule is initialized here
                with upload_dir and download_dir if cmd_modules is None.
            debug: Whether to print debug messages, which can help with diagnosing problems. When False, the
                admin_config WITH_DEBUG value is used.
            auto_login_max_tries: maximum number of tries to auto-login; must be in [0, MAX_AUTO_LOGIN_TRIES].
            event_handlers: optional list of EventHandler objects. Command modules that are EventHandlers are
                added to the API's handler list; the caller's list itself is not modified.

        Raises:
            TypeError: if cmd_modules or event_handlers (or their elements) have the wrong type.
            ConfigError: if the CA cert, client cert or client key is not configured.
            ValueError: if auto_login_max_tries is out of range.
        """
        super().__init__()
        if cmd_modules is None:
            from .file_transfer import FileTransferModule

            upload_dir = admin_config.get(AdminConfigKey.UPLOAD_DIR, "transfer")
            download_dir = admin_config.get(AdminConfigKey.DOWNLOAD_DIR, "transfer")
            cmd_modules = [FileTransferModule(upload_dir=upload_dir, download_dir=download_dir)]
        elif not isinstance(cmd_modules, list):
            raise TypeError("cmd_modules must be a list, but got {}".format(type(cmd_modules)))
        else:
            for m in cmd_modules:
                if not isinstance(m, CommandModule):
                    raise TypeError(
                        "cmd_modules must be a list of CommandModule, but got element of type {}".format(type(m))
                    )

        if not event_handlers:
            event_handlers = []
        else:
            if not isinstance(event_handlers, list):
                raise TypeError(f"event_handlers must be a list but got {type(event_handlers)}")
            for h in event_handlers:
                if not isinstance(h, EventHandler):
                    raise TypeError(f"item in event_handlers must be EventHandler but got {type(h)}")
            # copy: the command-module handlers appended below must not leak into the caller's list
            event_handlers = list(event_handlers)

        for m in cmd_modules:
            if isinstance(m, EventHandler):
                event_handlers.append(m)

        self.logger = get_obj_logger(self)
        self.conn_sec = admin_config.get(AdminConfigKey.CONNECTION_SECURITY)
        self.project_name = admin_config.get(AdminConfigKey.PROJECT_NAME)
        self.server_identity = admin_config.get(AdminConfigKey.SERVER_IDENTITY, "server")
        self.scheme = admin_config.get(AdminConfigKey.CONNECTION_SCHEME, "grpc")
        self.ca_cert = admin_config.get(AdminConfigKey.CA_CERT)
        self.client_cert = admin_config.get(AdminConfigKey.CLIENT_CERT)
        self.client_key = admin_config.get(AdminConfigKey.CLIENT_KEY)
        self.uid_source = admin_config.get(AdminConfigKey.UID_SOURCE, UidSource.USER_INPUT)
        self.host = admin_config.get(AdminConfigKey.HOST, "localhost")
        self.port = admin_config.get(AdminConfigKey.PORT, 8002)
        self.default_login_timeout = admin_config.get(AdminConfigKey.LOGIN_TIMEOUT, 10.0)
        self.file_download_progress_timeout = admin_config.get(AdminConfigKey.FILE_DOWNLOAD_PROGRESS_TIMEOUT, 5.0)
        self.authenticate_msg_timeout = admin_config.get(AdminConfigKey.AUTHENTICATE_MSG_TIMEOUT, 5.0)
        self.user_name = user_name
        self.event_handlers = event_handlers

        if not self.ca_cert:
            raise ConfigError("missing CA Cert file name")

        if not self.client_cert:
            raise ConfigError("missing Client Cert file name")

        if not self.client_key:
            raise ConfigError("missing Client Key file name")

        if self.uid_source == UidSource.CERT:
            # We'll find the username from the client cert
            cert = load_cert_file(self.client_cert)
            self.user_name = get_cn_from_cert(cert)

        if not self.user_name:
            raise Exception("user_name is required.")

        if debug:
            self._debug = debug
        else:
            self._debug = admin_config.get(AdminConfigKey.WITH_DEBUG, False)

        self.cmd_timeout = None

        # for login
        self.token = None
        self.login_result = None

        self.server_cmd_reg = CommandRegister(app_ctx=self)
        self.client_cmd_reg = CommandRegister(app_ctx=self)
        self.server_cmd_received = False

        self.all_cmds = []
        self.cmd_modules = cmd_modules

        # for shutdown
        self.shutdown_received = False
        self.shutdown_msg = None

        self.server_sess_active = False
        self.shutdown_asked = False

        self.sess_monitor_thread = None
        self.sess_monitor_active = False

        if auto_login_max_tries < 0 or auto_login_max_tries > MAX_AUTO_LOGIN_TRIES:
            raise ValueError(f"auto_login_max_tries is out of range: [0, {MAX_AUTO_LOGIN_TRIES}]")
        self.auto_login_max_tries = auto_login_max_tries

        self.closed = False
        self.in_logout = False
        self.cell = None
        self.aux_runner = None
        self.object_streamer = None
        self.fl_ctx_mgr = FLContextManager(
            engine=self,
            identity_name=self.user_name,
            private_stickers={FLContextKey.PROCESS_TYPE: ProcessType.CLIENT_PARENT},
        )
        self.file_download_waiters = {}  # tx_id => Threading.Event

    def new_context(self):
        """Return a new FLContext from this API's FLContextManager."""
        return self.fl_ctx_mgr.new_context()

    def connect(self, timeout=None):
        """Create, start and authenticate the admin cell (no-op if a cell already exists).

        Args:
            timeout: authentication timeout in seconds; defaults to the configured login timeout.

        Raises:
            ValueError: if timeout is not a positive number.
            Exception: any failure while starting or authenticating the cell is re-raised after the
                cell has been stopped and discarded, so that a later connect() can retry from scratch.
        """
        if timeout is not None:
            # validate provided timeout value
            if not isinstance(timeout, (int, float)):
                raise ValueError(f"timeout must be a number but got {type(timeout)}")
            if timeout <= 0:
                raise ValueError(f"timeout must be a number > 0 but got {timeout}")
        else:
            # use value configured in admin config
            timeout = self.default_login_timeout

        print("Connecting to FLARE ...")
        if self.cell:
            return

        my_fqcn = new_admin_client_name()
        credentials = {
            DriverParams.CA_CERT.value: self.ca_cert,
            DriverParams.CLIENT_CERT.value: self.client_cert,
            DriverParams.CLIENT_KEY.value: self.client_key,
        }

        root_url = f"{self.scheme}://{self.host}:{self.port}"

        secure_conn = True
        if self.conn_sec:
            conn_sec = self.conn_sec.lower()
            credentials[DriverParams.CONNECTION_SECURITY.value] = conn_sec
            if conn_sec == ConnectionSecurity.CLEAR:
                secure_conn = False

        flare_decomposers.register()

        self.debug(f"Creating cell: {my_fqcn=} {root_url=} {secure_conn=} {credentials=}")
        self.cell = Cell(
            fqcn=my_fqcn,
            root_url=root_url,
            secure=secure_conn,
            credentials=credentials,
            create_internal_listener=False,
            parent_url=None,
        )

        try:
            self.cell.register_request_cb(
                channel=CellChannel.HCI,
                topic="SESSION_EXPIRED",
                cb=self._handle_session_expired,
            )
            NetAgent(self.cell)
            self.cell.start()
            self._authenticate(timeout)
        except Exception:
            # Without this rollback self.cell stays set and every later connect() would
            # return early at "if self.cell", leaving an unauthenticated cell in place.
            self._discard_cell()
            raise

        self.aux_runner = AuxRunner(self)
        self.object_streamer = ObjectStreamer(self.aux_runner)
        self.cell.register_request_cb(
            channel=CellChannel.AUX_COMMUNICATION,
            topic="*",
            cb=self._handle_aux_message,
        )

    def _authenticate(self, timeout):
        """Authenticate self.cell with the server and install the auth header filters.

        Raises:
            RuntimeError: if the authenticator does not return a TokenVerifier.
        """
        authenticator = Authenticator(
            cell=self.cell,
            project_name=self.project_name,
            client_name=self.user_name,
            client_type=ClientType.ADMIN,
            expected_sp_identity=self.server_identity,
            secure_mode=True,  # always True to authenticate the cell endpoint!
            root_cert_file=self.ca_cert,
            private_key_file=self.client_key,
            cert_file=self.client_cert,
            msg_timeout=self.authenticate_msg_timeout,
            retry_interval=1.0,
            timeout=timeout,
        )

        abort_signal = Signal()
        shared_fl_ctx = FLContext()
        shared_fl_ctx.set_public_props({ReservedKey.IDENTITY_NAME: self.user_name})
        token, token_signature, ssid, token_verifier = authenticator.authenticate(
            shared_fl_ctx=shared_fl_ctx,
            abort_signal=abort_signal,
        )

        if not isinstance(token_verifier, TokenVerifier):
            raise RuntimeError(f"expect token_verifier to be TokenVerifier but got {type(token_verifier)}")

        set_add_auth_headers_filters(self.cell, self.user_name, token, token_signature, ssid)

        self.cell.core_cell.add_incoming_filter(
            channel="*",
            topic="*",
            cb=validate_auth_headers,
            token_verifier=token_verifier,
            logger=self.logger,
        )
        # the session token is a credential: do not echo it, even in debug output
        self.debug(f"Successfully authenticated to {self.server_identity}: {ssid=}")

    def _discard_cell(self):
        """Stop and forget the current cell (best effort) so that connect() can start over."""
        cell, self.cell = self.cell, None
        if cell is None:
            return
        try:
            cell.stop()
        except Exception as ex:
            self.logger.warning(f"error stopping cell after failed connect: {secure_format_exception(ex)}")

    def _handle_aux_message(self, request: CellMessage) -> CellMessage:
        """Dispatch an incoming AUX_COMMUNICATION request to the aux runner and wrap its reply."""
        assert isinstance(request, CellMessage), "request must be CellMessage but got {}".format(type(request))
        data = request.payload

        topic = request.get_header(MessageHeaderKey.TOPIC)
        with self.new_context() as fl_ctx:
            reply = self.aux_runner.dispatch(topic=topic, request=data, fl_ctx=fl_ctx)
            if reply is not None:
                return_message = CellMessage({}, reply)
                return_message.set_header(MessageHeaderKey.RETURN_CODE, CellReturnCode.OK)
            else:
                return_message = CellMessage({}, None)
            return return_message

    def download_file(self, source_fqcn: str, ref_id: str, file_name: str):
        """Download a file by reference from source_fqcn and move it to file_name.

        Returns:
            number of bytes received, or None if the download failed.
        """
        err, file_path = downloader.download_file(
            cell=self.cell,
            ref_id=ref_id,
            from_fqcn=source_fqcn,
            per_request_timeout=self.file_download_progress_timeout,
        )
        if err:
            print(f"failed to receive file {file_name}: {err}")
            return None

        file_stats = os.stat(file_path)
        num_bytes_received = file_stats.st_size
        Path(os.path.dirname(file_name)).mkdir(parents=True, exist_ok=True)
        shutil.move(file_path, file_name)
        return num_bytes_received

    def get_cell(self):
        """Return the underlying cell (None before connect())."""
        return self.cell

    def _handle_session_expired(self, message: CellMessage):
        """Close the API and fire SESSION_TIMEOUT when the server reports an expired session."""
        self.debug("received session timeout from server")
        self.close()
        self.fire_session_event(EventType.SESSION_TIMEOUT, message.payload)

    def debug(self, msg):
        """Print msg with a DEBUG prefix when debug mode is on."""
        if self._debug:
            print(f"DEBUG: {msg}")

    def fire_event(self, event_type: str, ctx: EventContext):
        """Deliver the event to every registered event handler, in order."""
        self.debug(f"firing event {event_type}")
        if self.event_handlers:
            for h in self.event_handlers:
                h.handle_event(event_type, ctx)

    def set_command_timeout(self, timeout: float):
        """Set the timeout (seconds, > 0) sent with and used for subsequent server commands."""
        if not isinstance(timeout, (int, float)):
            raise TypeError(f"timeout must be a number but got {type(timeout)}")

        timeout = float(timeout)
        if timeout <= 0.0:
            raise ValueError(f"invalid timeout value {timeout} - must be > 0.0")

        self.cmd_timeout = timeout

    def unset_command_timeout(self):
        """Clear the command timeout (the 5-second send default is used instead)."""
        self.cmd_timeout = None

    def _new_event_context(self):
        ctx = EventContext()
        ctx.set_prop(EventPropKey.USER_NAME, self.user_name)
        ctx.set_prop(EventPropKey.API, self)
        return ctx

    def fire_session_event(self, event_type: str, msg: str = ""):
        """Fire a session event, attaching msg to the event context when non-empty."""
        ctx = self._new_event_context()

        if msg:
            ctx.set_prop(EventPropKey.MSG, msg)

        self.fire_event(event_type, ctx)

    def _try_login(self):
        """Attempt login up to auto_login_max_tries times.

        Stops early on success, authentication failure or cert failure. Waits
        AUTO_LOGIN_INTERVAL seconds between attempts, but not after the last one.
        """
        resp = None
        for i in range(self.auto_login_max_tries):
            try:
                self.fire_session_event(EventType.TRYING_LOGIN, "Trying to login, please wait ...")
            except Exception as ex:
                print(f"exception handling event {EventType.TRYING_LOGIN}: {secure_format_exception(ex)}")
                return {
                    ResultKey.STATUS: APIStatus.ERROR_RUNTIME,
                    ResultKey.DETAILS: f"exception handling event {EventType.TRYING_LOGIN}",
                }

            resp = self._user_login()
            status = resp.get(ResultKey.STATUS)
            if status in [APIStatus.SUCCESS, APIStatus.ERROR_AUTHENTICATION, APIStatus.ERROR_CERT]:
                if status == APIStatus.SUCCESS:
                    self.fire_session_event(EventType.LOGIN_SUCCESS)
                else:
                    self.fire_session_event(EventType.LOGIN_FAILURE)
                return resp

            if i < self.auto_login_max_tries - 1:
                # only wait when another attempt follows
                time.sleep(AUTO_LOGIN_INTERVAL)

        if resp is None:
            resp = {
                ResultKey.STATUS: APIStatus.ERROR_RUNTIME,
                ResultKey.DETAILS: f"Auto login failed after {self.auto_login_max_tries} tries",
            }
        self.fire_session_event(EventType.LOGIN_FAILURE)
        return resp

    def login(self):
        """Log in to the server; always returns a result dict (exceptions become ERROR_RUNTIME)."""
        try:
            self.fire_session_event(EventType.BEFORE_LOGIN)
            result = self._try_login()
            self.debug(f"login result is {result}")
        except Exception as e:
            result = {
                ResultKey.STATUS: APIStatus.ERROR_RUNTIME,
                ResultKey.DETAILS: f"Exception occurred ({secure_format_exception(e)}) when trying to login - please try later",
            }
        return result

    def _load_client_cmds_from_modules(self, cmd_modules):
        if cmd_modules:
            for m in cmd_modules:
                self.client_cmd_reg.register_module(m, include_invisible=True)

    def _load_client_cmds_from_module_specs(self, cmd_module_specs):
        if cmd_module_specs:
            for m in cmd_module_specs:
                self.client_cmd_reg.register_module_spec(m, include_invisible=True)

    def register_command(self, cmd_entry):
        """Record the command name (callback used when finalizing the command registers)."""
        self.all_cmds.append(cmd_entry.name)

    def logout(self):
        """Send logout command to server, then close the API.

        Returns:
            the server's reply, or None if a logout is already in progress.
        """
        if self.in_logout:
            return None

        self.in_logout = True
        resp = None
        try:
            resp = self.server_execute(InternalCommands.LOGOUT)
        finally:
            # make sure to close
            self.close()
        return resp

    def close(self):
        """Deactivate the session, shut down the streamer and stop the cell."""
        # this method can be called multiple times
        if self.closed:
            return

        self.closed = True
        self.server_sess_active = False
        self.shutdown_asked = True
        self.shutdown_streamer()
        if self.cell:
            self.cell.stop()

    def _get_command_list_from_server(self) -> bool:
        self.server_cmd_received = False
        self.server_execute(InternalCommands.GET_CMD_LIST, _CmdListReplyProcessor())
        self.server_cmd_reg.finalize(self.register_command)
        if not self.server_cmd_received:
            return False
        return True

    def _after_login(self) -> dict:
        """Fetch server commands, register client commands and activate the session."""
        result = self._get_command_list_from_server()
        if not result:
            return {
                ResultKey.STATUS: APIStatus.ERROR_RUNTIME,
                ResultKey.DETAILS: "Can't fetch command list from server.",
            }

        # prepare client modules
        # we may have additional dynamically created cmd modules based on server commands
        extra_module_specs = []
        if self.server_cmd_reg.mapped_cmds:
            for c in self.server_cmd_reg.mapped_cmds:
                for m in self.cmd_modules:
                    new_module_spec = m.generate_module_spec(c)
                    if new_module_spec is not None:
                        extra_module_specs.append(new_module_spec)

        self._load_client_cmds_from_modules(self.cmd_modules)
        if extra_module_specs:
            self._load_client_cmds_from_module_specs(extra_module_specs)
        self.client_cmd_reg.finalize(self.register_command)
        self.server_sess_active = True
        return {ResultKey.STATUS: APIStatus.SUCCESS, ResultKey.DETAILS: "Login success"}

    def is_ready(self) -> bool:
        """Whether the API is ready for executing commands."""
        return self.server_sess_active

    def _user_login(self):
        """Login user

        Returns:
            A dict of login status and details
        """
        command = f"{InternalCommands.CERT_LOGIN} {self.user_name}"
        id_asserter = IdentityAsserter(private_key_file=self.client_key, cert_file=self.client_cert)
        cn_signature = id_asserter.sign_common_name(nonce="")

        headers = {
            "user_name": self.user_name,
            "cert": id_asserter.cert_data,
            "signature": cn_signature,
        }

        self.login_result = None
        self.server_execute(command, _LoginReplyProcessor(), headers=headers)
        if self.login_result is None:
            return {
                ResultKey.STATUS: APIStatus.ERROR_RUNTIME,
                ResultKey.DETAILS: "Communication Error - please try later",
            }
        elif self.login_result == "REJECT":
            return {
                ResultKey.STATUS: APIStatus.ERROR_AUTHENTICATION,
                ResultKey.DETAILS: "Incorrect user name or password",
            }
        return self._after_login()

    def _send_to_cell(self, ctx: CommandContext):
        """Send the context's command (via its requester, or the cell) and process the reply."""
        command = ctx.get_command()
        json_processor = ctx.get_json_processor()
        process_json_func = json_processor.process_server_reply

        conn = Connection()
        conn.append_command(command)
        if self.token:
            conn.append_token(self.token)

        if self.cmd_timeout:
            conn.update_meta({MetaKey.CMD_TIMEOUT: self.cmd_timeout})

        custom_props = ctx.get_custom_props()
        if custom_props:
            conn.update_meta({MetaKey.CUSTOM_PROPS: custom_props})

        cmd_props = ctx.get_command_props()
        if cmd_props:
            conn.update_meta({MetaKey.CMD_PROPS: cmd_props})

        timeout = self.cmd_timeout
        if not timeout:
            timeout = 5.0

        requester = ctx.get_requester()
        if requester:
            try:
                reply = requester.send_request(self, conn, ctx)
            except Exception:
                # Exception (not bare except) so Ctrl-C / SystemExit are not swallowed
                traceback.print_exc()
                process_json_func(make_error(f"{type(requester)} failed to send request to Admin Server"))
                return
        else:
            request = CellMessage(payload=conn.close(), headers=ctx.get_command_headers())
            cell_reply = self.cell.send_request(
                channel=CellChannel.HCI,
                topic="command",
                target=FQCN.ROOT_SERVER,
                request=request,
                timeout=timeout,
            )
            reply = cell_reply.payload

        if reply:
            try:
                json_data = validate_proto(reply)
                process_json_func(json_data)
            except Exception:
                traceback.print_exc()
                process_json_func(make_error(f"{ReplyKeyword.COMM_FAILURE} with Admin Server"))

    def _try_command(self, cmd_ctx: CommandContext):
        """Try to execute a command on server side.

        Args:
            cmd_ctx: The command to execute.
        """
        self.debug(f"sending command '{cmd_ctx.get_command()}'")

        json_processor = _ServerReplyJsonProcessor(cmd_ctx)
        process_json_func = json_processor.process_server_reply
        cmd_ctx.set_json_processor(json_processor)

        event_ctx = self._new_event_context()
        event_ctx.set_prop(EventPropKey.CMD_NAME, cmd_ctx.get_command_name())
        event_ctx.set_prop(EventPropKey.CMD_CTX, cmd_ctx)

        try:
            self.fire_event(EventType.BEFORE_EXECUTE_CMD, event_ctx)
        except Exception as ex:
            secure_log_traceback()
            process_json_func(
                make_error(f"exception handling event {EventType.BEFORE_EXECUTE_CMD}: {secure_format_exception(ex)}")
            )
            return

        # see whether any event handler has set "custom_props"
        custom_props = event_ctx.get_prop(EventPropKey.CUSTOM_PROPS)
        if custom_props:
            cmd_ctx.set_custom_props(custom_props)

        try:
            self._send_to_cell(cmd_ctx)
        except Exception as e:
            if self._debug:
                secure_log_traceback()
                traceback.print_exc()

            process_json_func(
                make_error(f"{ReplyKeyword.COMM_FAILURE} with Admin Server: {secure_format_exception(e)}")
            )

    def _get_command_detail(self, command):
        """Get command details

        Args:
            command (str): command

        Returns:
            tuple of (cmd_type, cmd_name, args, entries)
        """
        args = split_to_args(command)
        if not args:
            # blank command line: nothing to look up (avoids IndexError on args[0])
            return _CMD_TYPE_UNKNOWN, "", args, None
        cmd_name = args[0]

        # check client side commands
        entries = self.client_cmd_reg.get_command_entries(cmd_name)
        if len(entries) > 0:
            return _CMD_TYPE_CLIENT, cmd_name, args, entries

        # check server side commands
        entries = self.server_cmd_reg.get_command_entries(cmd_name)
        if len(entries) > 0:
            return _CMD_TYPE_SERVER, cmd_name, args, entries

        return _CMD_TYPE_UNKNOWN, cmd_name, args, None

    def check_command(self, command: str) -> CommandInfo:
        """Checks the specified command for processing info

        Args:
            command: command to be checked

        Returns: command processing info

        """
        cmd_type, cmd_name, _, entries = self._get_command_detail(command)

        if cmd_type == _CMD_TYPE_UNKNOWN:
            return CommandInfo.UNKNOWN

        if len(entries) > 1:
            return CommandInfo.AMBIGUOUS

        ent = entries[0]
        assert isinstance(ent, CommandEntry)
        if ent.confirm == ConfirmMethod.AUTH:
            return CommandInfo.CONFIRM_AUTH
        elif ent.confirm == ConfirmMethod.YESNO:
            return CommandInfo.CONFIRM_YN
        else:
            return CommandInfo.OK

    def _new_command_context(self, command, args, ent: CommandEntry):
        ctx = CommandContext()
        ctx.set_api(self)
        ctx.set_command(command)
        ctx.set_command_args(args)
        ctx.set_command_entry(ent)
        return ctx

    def _do_client_command(self, command, args, ent: CommandEntry):
        """Run a client-side command handler and return its result dict."""
        ctx = self._new_command_context(command, args, ent)
        return_result = ent.handler(args, ctx)
        result = ctx.get_command_result()
        if return_result:
            return return_result
        if result is None:
            return {ResultKey.STATUS: APIStatus.ERROR_RUNTIME, ResultKey.DETAILS: "Client did not respond"}
        return result

    def upload_file(self, file_name: str, conn: Connection):
        """Stream a file to the server together with the command connection data.

        Returns:
            the END_RESULT header of the server's reply, or None on failure.
        """
        stream_ctx = {"conn_data": conn.close()}
        with self.new_context() as fl_ctx:
            rc, replies = FileStreamer.stream_file(
                channel=StreamChannel.UPLOAD,
                topic=StreamTopic.FOLDER,
                stream_ctx=stream_ctx,
                file_name=file_name,
                fl_ctx=fl_ctx,
                targets=[FQCN.ROOT_SERVER],  # to server
            )

            if rc != ReturnCode.OK:
                self.logger.error(f"failed to stream file to server: {rc}")
                return None

            reply = replies.get(FQCN.ROOT_SERVER) if replies else None
            # explicit check instead of assert: a missing reply must not crash (or vanish under -O)
            if not isinstance(reply, Shareable):
                self.logger.error(f"no valid reply from server for uploaded file {file_name}: got {type(reply)}")
                return None
            end_result = reply.get_header(HeaderKey.END_RESULT)
            return end_result

    def do_command(self, command: str, props=None):
        """A convenient method to call commands using string.

        Args:
            command (str): command
            props: additional props

        Returns:
            Object containing status and details (or direct response from server, which originally was just time and data)
        """
        cmd_type, cmd_name, args, entries = self._get_command_detail(command)
        if cmd_type == _CMD_TYPE_UNKNOWN:
            return {
                ResultKey.STATUS: APIStatus.ERROR_SYNTAX,
                ResultKey.DETAILS: f"Command {cmd_name} not found",
            }

        if len(entries) > 1:
            return {
                ResultKey.STATUS: APIStatus.ERROR_SYNTAX,
                ResultKey.DETAILS: f"Ambiguous command {cmd_name} - qualify with scope",
            }

        ent = entries[0]
        if cmd_type == _CMD_TYPE_CLIENT:
            return self._do_client_command(command=command, args=args, ent=ent)

        # server command
        if not self.server_sess_active:
            return {
                ResultKey.STATUS: APIStatus.ERROR_INACTIVE_SESSION,
                ResultKey.DETAILS: "Session is inactive, please try later",
            }

        return self.server_execute(command, cmd_entry=ent, props=props)

    def server_execute(self, command, reply_processor=None, cmd_entry=None, cmd_ctx=None, props=None, headers=None):
        """Execute a command on the server and return the result dict (STATUS always present)."""
        if self.in_logout and command != InternalCommands.LOGOUT:
            return {ResultKey.STATUS: APIStatus.SUCCESS, ResultKey.DETAILS: "session is logging out"}

        if cmd_ctx:
            ctx = cmd_ctx
        else:
            # args are only needed when building a fresh context
            args = split_to_args(command)
            ctx = self._new_command_context(command, args, cmd_entry)

        ctx.set_command(command)
        if props:
            self.debug(f"server_execute: set cmd props to ctx {props}")
            ctx.set_command_props(props)

        if headers:
            self.debug(f"setting cmd headers: {headers}")
            ctx.set_command_headers(headers)

        start = time.time()
        ctx.set_reply_processor(reply_processor)
        self._try_command(ctx)
        secs = time.time() - start
        usecs = int(secs * 1000000)

        self.debug(f"server_execute Done [{usecs} usecs] {datetime.now()}")

        result = ctx.get_command_result()
        meta = ctx.get_meta()
        if result is None:
            return {ResultKey.STATUS: APIStatus.ERROR_SERVER_CONNECTION, ResultKey.DETAILS: "Server did not respond"}
        if meta:
            result[ResultKey.META] = meta

        if ResultKey.STATUS not in result:
            result[ResultKey.STATUS] = self._determine_api_status(result)
        return result

    def _determine_api_status(self, result):
        """Derive an APIStatus from reply items: any SUCCESS item wins, else keyword matching on text."""
        status = result.get(ResultKey.STATUS)
        if status:
            return status

        data = result.get(ProtoKey.DATA)
        if not data:
            return APIStatus.ERROR_RUNTIME

        reply_data_list = []
        for d in data:
            if isinstance(d, dict):
                t = d.get(ProtoKey.TYPE)
                if t == ProtoKey.SUCCESS:
                    return APIStatus.SUCCESS
                if t == ProtoKey.STRING or t == ProtoKey.ERROR:
                    reply_data_list.append(d[ProtoKey.DATA])

        reply_data_full_response = "\n".join(reply_data_list)
        if ReplyKeyword.SESSION_INACTIVE in reply_data_full_response:
            return APIStatus.ERROR_INACTIVE_SESSION
        if ReplyKeyword.WRONG_SERVER in reply_data_full_response:
            return APIStatus.ERROR_SERVER_CONNECTION
        if ReplyKeyword.COMM_FAILURE in reply_data_full_response:
            return APIStatus.ERROR_SERVER_CONNECTION
        if ReplyKeyword.INVALID_CLIENT in reply_data_full_response:
            return APIStatus.ERROR_INVALID_CLIENT
        if ReplyKeyword.UNKNOWN_SITE in reply_data_full_response:
            return APIStatus.ERROR_INVALID_CLIENT
        if ReplyKeyword.NOT_AUTHORIZED in reply_data_full_response:
            return APIStatus.ERROR_AUTHORIZATION
        return APIStatus.SUCCESS

    def stream_objects(
        self,
        channel: str,
        topic: str,
        stream_ctx: StreamContext,
        targets: List[str],
        producer: ObjectProducer,
        fl_ctx: FLContext,
        optional=False,
        secure=False,
    ):
        """Send a stream of Shareable objects to the server.

        Args:
            channel: the channel for this stream
            topic: topic of the stream
            stream_ctx: context of the stream
            targets: accepted for interface compatibility but IGNORED - the admin client
                always streams to the server only
            producer: the ObjectProducer that can produces the stream of Shareable objects
            fl_ctx: the FLContext object
            optional: whether the stream is optional
            secure: whether to use P2P security

        Returns: result from the generator's reply processing

        """
        assert isinstance(self.object_streamer, ObjectStreamer)
        return self.object_streamer.stream(
            channel=channel,
            topic=topic,
            stream_ctx=stream_ctx,
            producer=producer,
            fl_ctx=fl_ctx,
            optional=optional,
            secure=secure,
            targets=[AuxMsgTarget.server_target()],  # only stream to server!
        )

    def register_stream_processing(
        self,
        channel: str,
        topic: str,
        factory: ConsumerFactory,
        stream_done_cb=None,
        consumed_cb=None,
        **cb_kwargs,
    ):
        """Register a ConsumerFactory for specified app channel and topic.
        Once a new streaming request is received for the channel/topic, the registered factory will be used
        to create an ObjectConsumer object to handle the new stream.

        Note: the factory should generate a new ObjectConsumer every time get_consumer() is called. This is because
        multiple streaming sessions could be going on at the same time. Each streaming session should have its
        own ObjectConsumer.

        Args:
            channel: app channel
            topic: app topic
            factory: the factory to be registered
            stream_done_cb: the callback to be called when streaming is done on receiving side
            consumed_cb: the callback to be called after a chunk is processed

        Returns: None

        """
        assert isinstance(self.object_streamer, ObjectStreamer)
        self.object_streamer.register_stream_processing(
            channel=channel,
            topic=topic,
            factory=factory,
            stream_done_cb=stream_done_cb,
            consumed_cb=consumed_cb,
            **cb_kwargs,
        )

    def shutdown_streamer(self):
        """Shutdown the engine's streamer.

        Returns: None

        """
        if self.object_streamer:
            assert isinstance(self.object_streamer, ObjectStreamer)
            self.object_streamer.shutdown()