# Source code for nvflare.fuel.hci.client.file_transfer

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import time
import uuid
from pathlib import Path

import nvflare.fuel.hci.file_transfer_defs as ftd
from nvflare.fuel.hci.base64_utils import (
    b64str_to_binary_file,
    b64str_to_bytes,
    b64str_to_text_file,
    binary_file_to_b64str,
    text_file_to_b64str,
)
from nvflare.fuel.hci.binary_proto import CT_BINARY, receive_all, send_binary_file
from nvflare.fuel.hci.client.event import EventType
from nvflare.fuel.hci.cmd_arg_utils import join_args
from nvflare.fuel.hci.proto import MetaKey, ProtoKey
from nvflare.fuel.hci.reg import CommandEntry, CommandModule, CommandModuleSpec, CommandSpec
from nvflare.fuel.hci.table import Table
from nvflare.fuel.utils.zip_utils import split_path, unzip_all_from_bytes, unzip_all_from_file, zip_directory_to_file
from nvflare.lighter.utils import load_private_key_file, sign_folders
from nvflare.security.logging import secure_format_exception, secure_log_traceback

from .api_spec import CommandContext, ReceiveBytesFromServer, ReplyProcessor, SendBytesToServer
from .api_status import APIStatus


def _server_cmd_name(name: str):
    """Qualify ``name`` with the server-side file transfer module name.

    Args:
        name: bare server command name.

    Returns:
        ``"<SERVER_MODULE_NAME>.<name>"``.
    """
    return ".".join([ftd.SERVER_MODULE_NAME, name])


class _SendFileToServer(SendBytesToServer):
    """Bytes sender that streams a local file to the server and then deletes that file."""

    def __init__(self, file_name: str):
        # Path of the file to upload; it is removed by send().
        self.file_name = file_name

    def send(self, sock, meta: str):
        """Send the file over ``sock`` together with ``meta``, then remove the file.

        The file is removed even if sending fails, so a failed upload does not leave
        the file behind.

        Args:
            sock: connection to the server.
            meta: metadata string passed through to ``send_binary_file``.
        """
        try:
            send_binary_file(sock, self.file_name, meta)
        finally:
            # Clean up regardless of outcome; guard against the file already being gone.
            if os.path.exists(self.file_name):
                os.remove(self.file_name)


class _ReceiveFileFromServer(ReceiveBytesFromServer):
    """Bytes receiver that stores one binary payload from the server at a fixed path."""

    def __init__(self, file_name: str):
        # Final location for the received payload.
        self.file_name = file_name
        # Size in bytes of the received payload; stays 0 until receive() succeeds.
        self.num_bytes_received = 0

    def receive(self, sock):
        """Read a binary reply from ``sock`` and move it to ``self.file_name``.

        Parent directories of the target path are created as needed.

        Raises:
            RuntimeError: if the reply is not of binary content type, or no file was received.
        """
        content_type, _, received_path = receive_all(sock)
        if content_type != CT_BINARY:
            raise RuntimeError(f"expecting BINARY type {CT_BINARY} but got {content_type}")
        if not received_path:
            raise RuntimeError("nothing received from the server")

        self.num_bytes_received = os.stat(received_path).st_size

        target_dir = Path(os.path.dirname(self.file_name))
        target_dir.mkdir(parents=True, exist_ok=True)
        shutil.move(received_path, self.file_name)


class _DownloadProcessor(ReplyProcessor):
    """Reply processor that saves files delivered inline in a table reply.

    Every non-header row of the reply table is expected to be
    ``[file_name, encoded_data]``. The data is written to ``download_dir/file_name``
    with ``str_to_file_func``. A summary table of ``file``/``size`` rows becomes the
    command result on success.
    """

    def __init__(self, download_dir: str, str_to_file_func):
        """
        Args:
            download_dir: local directory that downloaded files are written into.
            str_to_file_func: callable ``(encoded_str, full_path)`` that writes the file;
                its return value is reported as the file size.
        """
        self.download_dir = download_dir
        self.str_to_file_func = str_to_file_func
        self.data_received = False
        # Set once an error result has been recorded, so reply_done() does not overwrite it.
        self.error_reported = False
        self.table = None

    def reply_start(self, ctx: CommandContext, reply_json):
        """Reset per-reply state and start a fresh summary table."""
        self.data_received = False
        self.error_reported = False
        self.table = Table(["file", "size"])

    def reply_done(self, ctx: CommandContext):
        """Finalize the command result after the whole reply has been processed."""
        if self.error_reported:
            # Keep the specific error recorded during processing instead of masking it
            # with SUCCESS (or with a generic "no data" message).
            return

        if not self.data_received:
            ctx.set_command_result({"status": APIStatus.ERROR_PROTOCOL, "details": "protocol error - no data received"})
            return

        command_result = ctx.get_command_result()
        if command_result is None:
            command_result = {}
        command_result["status"] = APIStatus.SUCCESS
        command_result["details"] = self.table
        ctx.set_command_result(command_result)

    def _report_error(self, ctx: CommandContext, status, details: str):
        """Record an error result and remember that one was reported."""
        self.error_reported = True
        ctx.set_command_result({"status": status, "details": details})

    def process_table(self, ctx: CommandContext, table: Table):
        """Decode and write every file listed in ``table`` (first row is the header)."""
        try:
            rows = table.rows
            if len(rows) < 1:
                # no data
                self._report_error(ctx, APIStatus.ERROR_PROTOCOL, "protocol error - no file data")
                return

            base_dir = os.path.abspath(self.download_dir)
            for row in rows[1:]:  # rows[0] is the header
                if len(row) < 1:
                    self._report_error(ctx, APIStatus.ERROR_PROTOCOL, "protocol error - missing file name")
                    return

                if len(row) < 2:
                    self._report_error(ctx, APIStatus.ERROR_PROTOCOL, "protocol error - missing file data")
                    return

                file_name = row[0]
                encoded_str = row[1]
                full_path = os.path.join(self.download_dir, file_name)

                # The file name comes from the server reply: refuse names (absolute paths,
                # "..") that would write outside download_dir.
                resolved = os.path.abspath(full_path)
                if os.path.commonpath([base_dir, resolved]) != base_dir:
                    self._report_error(
                        ctx, APIStatus.ERROR_PROTOCOL, f"protocol error - invalid file name: {file_name}"
                    )
                    return

                num_bytes = self.str_to_file_func(encoded_str, full_path)
                self.table.add_row([file_name, str(num_bytes)])
                self.data_received = True
        except Exception as e:
            secure_log_traceback()
            self._report_error(
                ctx,
                APIStatus.ERROR_RUNTIME,
                f"exception processing file: {secure_format_exception(e)}",
            )


class _DownloadFolderProcessor(ReplyProcessor):
    """Reply processor for downloading a folder from the server.

    A string item in the reply is either a download URL (it starts with
    ``ftd.DOWNLOAD_URL_MARKER``), which is passed back as the result details, or a
    base64-encoded zip archive, which is extracted into ``download_dir``.
    """

    def __init__(self, download_dir: str):
        # Local directory that received archives are extracted into.
        self.download_dir = download_dir
        # Whether any reply item (string or error) has been handled.
        self.data_received = False

    def reply_start(self, ctx: CommandContext, reply_json):
        """Reset per-reply state."""
        self.data_received = False

    def reply_done(self, ctx: CommandContext):
        """Report a protocol error if the reply contained nothing this processor handled."""
        if not self.data_received:
            # Same status the file download processor uses for this identical condition.
            ctx.set_command_result({"status": APIStatus.ERROR_PROTOCOL, "details": "protocol error - no data received"})

    def process_error(self, ctx: CommandContext, err: str):
        """Surface a server-reported error as the command result."""
        self.data_received = True
        ctx.set_command_result({"status": APIStatus.ERROR_RUNTIME, "details": err})

    def process_string(self, ctx: CommandContext, item: str):
        """Handle a download URL or extract a base64-encoded zip into ``download_dir``."""
        try:
            self.data_received = True
            if item.startswith(ftd.DOWNLOAD_URL_MARKER):
                details = item
            else:
                data_bytes = b64str_to_bytes(item)
                # NOTE(review): archive content comes from the server; confirm that
                # unzip_all_from_bytes rejects entries that would extract outside download_dir.
                unzip_all_from_bytes(data_bytes, self.download_dir)
                details = "Downloaded to dir {}".format(self.download_dir)
            ctx.set_command_result({"status": APIStatus.SUCCESS, "details": details})
        except Exception as e:
            secure_log_traceback()
            ctx.set_command_result(
                {
                    "status": APIStatus.ERROR_RUNTIME,
                    "details": f"exception processing reply: {secure_format_exception(e)}",
                }
            )


class FileTransferModule(CommandModule):
    """Command module with commands relevant to file transfer."""

    def __init__(self, upload_dir: str, download_dir: str):
        """Create the module.

        Args:
            upload_dir: local directory that upload/push commands read from.
            download_dir: local directory that download/pull commands write into.

        Raises:
            ValueError: if either argument is not an existing directory.
        """
        if not os.path.isdir(upload_dir):
            raise ValueError("upload_dir {} is not a valid dir".format(upload_dir))
        if not os.path.isdir(download_dir):
            raise ValueError("download_dir {} is not a valid dir".format(download_dir))

        self.upload_dir = upload_dir
        self.download_dir = download_dir
        # Client-side handlers for server commands, looked up by the server spec's client_cmd.
        self.cmd_handlers = {
            ftd.PUSH_FOLDER_FQN: self.push_folder,
            ftd.DOWNLOAD_FOLDER_FQN: self.download_folder,
            ftd.PULL_BINARY_FQN: self.pull_binary_file,
            ftd.PULL_FOLDER_FQN: self.pull_folder,
        }

    def get_spec(self):
        """Return the spec of the local ``file_transfer`` commands."""
        return CommandModuleSpec(
            name="file_transfer",
            cmd_specs=[
                CommandSpec(
                    name="upload_text",
                    description="upload one or more text files in the upload_dir",
                    usage="upload_text file_name ...",
                    handler_func=self.upload_text_file,
                    visible=False,
                ),
                CommandSpec(
                    name="download_text",
                    description="download one or more text files in the download_dir",
                    usage="download_text file_name ...",
                    handler_func=self.download_text_file,
                    visible=False,
                ),
                CommandSpec(
                    name="upload_binary",
                    description="upload one or more binary files in the upload_dir",
                    usage="upload_binary file_name ...",
                    handler_func=self.upload_binary_file,
                    visible=False,
                ),
                CommandSpec(
                    name="download_binary",
                    description="download one or more binary files in the download_dir",
                    usage="download_binary file_name ...",
                    handler_func=self.download_binary_file,
                    visible=False,
                ),
                CommandSpec(
                    name="pull_binary",
                    description="download one binary files in the download_dir",
                    usage="pull_binary control_id file_name",
                    handler_func=self.pull_binary_file,
                    visible=False,
                ),
                CommandSpec(
                    name="push_folder",
                    description="Submit application to the server",
                    usage="submit_job job_folder",
                    handler_func=self.push_folder,
                    visible=False,
                ),
                CommandSpec(
                    name="download_folder",
                    description="download job contents from the server",
                    usage="download_job job_id",
                    handler_func=self.download_folder,
                    visible=False,
                ),
                CommandSpec(
                    name="info",
                    description="show folder setup info",
                    usage="info",
                    handler_func=self.info,
                ),
            ],
        )

    def generate_module_spec(self, server_cmd_spec: CommandSpec):
        """Generate a client-side module spec for a server command.

        Args:
            server_cmd_spec: spec of the server command; its ``client_cmd`` selects the
                local handler from ``self.cmd_handlers``.

        Returns:
            A ``CommandModuleSpec`` holding one command bound to the matching handler, or
            None if the server command has no ``client_cmd`` or no handler is registered.
        """
        if not server_cmd_spec.client_cmd:
            return None

        handler = self.cmd_handlers.get(server_cmd_spec.client_cmd)
        if handler is None:
            print("no cmd handler found for {}".format(server_cmd_spec.client_cmd))
            return None

        return CommandModuleSpec(
            name=server_cmd_spec.scope_name,
            cmd_specs=[
                CommandSpec(
                    name=server_cmd_spec.name,
                    description=server_cmd_spec.description,
                    usage=server_cmd_spec.usage,
                    handler_func=handler,
                    visible=server_cmd_spec.visible,
                )
            ],
        )

    def upload_file(self, args, ctx: CommandContext, cmd_name, file_to_str_func):
        """Upload files from upload_dir, encoded inline as command arguments.

        Args:
            args: command arguments; ``args[1:]`` are file names relative to upload_dir.
            ctx: command context.
            cmd_name: server command name without the module prefix.
            file_to_str_func: callable that encodes the file at a path into a string.

        Returns:
            The server reply, or an error dict if no file names are given or a file is missing.
        """
        full_cmd_name = _server_cmd_name(cmd_name)
        if len(args) < 2:
            return {"status": APIStatus.ERROR_SYNTAX, "details": "syntax error: missing file names"}

        parts = [full_cmd_name]
        for file_name in args[1:]:
            full_path = os.path.join(self.upload_dir, file_name)
            if not os.path.isfile(full_path):
                return {"status": APIStatus.ERROR_RUNTIME, "details": f"no such file: {full_path}"}
            encoded_string = file_to_str_func(full_path)
            parts.append(file_name)
            parts.append(encoded_string)

        command = join_args(parts)
        api = ctx.get_api()
        return api.server_execute(command)

    def upload_text_file(self, args, ctx: CommandContext):
        """Upload text files from upload_dir (base64-encoded)."""
        return self.upload_file(args, ctx, ftd.SERVER_CMD_UPLOAD_TEXT, text_file_to_b64str)

    def upload_binary_file(self, args, ctx: CommandContext):
        """Upload binary files from upload_dir (base64-encoded)."""
        return self.upload_file(args, ctx, ftd.SERVER_CMD_UPLOAD_BINARY, binary_file_to_b64str)

    def download_file(self, args, ctx: CommandContext, cmd_name, str_to_file_func):
        """Request files from the server and write them into download_dir.

        Args:
            args: command arguments; ``args[1:]`` are the requested file names.
            ctx: command context.
            cmd_name: server command name without the module prefix.
            str_to_file_func: callable ``(encoded_str, full_path)`` that writes one file.

        Returns:
            The result produced by ``_DownloadProcessor``, or a syntax error dict.
        """
        full_cmd_name = _server_cmd_name(cmd_name)
        if len(args) < 2:
            return {"status": APIStatus.ERROR_SYNTAX, "details": "syntax error: missing file names"}

        command = join_args([full_cmd_name, *args[1:]])
        reply_processor = _DownloadProcessor(self.download_dir, str_to_file_func)
        api = ctx.get_api()
        return api.server_execute(command, reply_processor)

    def download_text_file(self, args, ctx: CommandContext):
        """Download text files into download_dir."""
        return self.download_file(args, ctx, ftd.SERVER_CMD_DOWNLOAD_TEXT, b64str_to_text_file)

    def download_binary_file(self, args, ctx: CommandContext):
        """Download binary files into download_dir."""
        return self.download_file(args, ctx, ftd.SERVER_CMD_DOWNLOAD_BINARY, b64str_to_binary_file)

    def _tx_path(self, tx_id: str, folder_name: str):
        """Return the local staging folder of a pull transaction: ``<download_dir>/<folder_name>__<tx_id>``."""
        return os.path.join(self.download_dir, f"{folder_name}__{tx_id}")

    def pull_binary_file(self, args, ctx: CommandContext):
        """Download one file of a pull transaction into the transaction's staging folder.

        A downloaded ``.zip`` file is extracted into a directory named like the file
        without its extension, and the archive is then deleted.

        Args:
            args: ``[cmd_name, tx_id, folder_name, file_name]``, optionally followed by ``"end"``.
            ctx: command context; its command text is sent to the server as-is.

        Returns:
            The server reply, or a syntax error dict for a wrong argument count.
        """
        cmd_entry = ctx.get_command_entry()
        if len(args) < 4 or len(args) > 5:
            return {ProtoKey.STATUS: APIStatus.ERROR_SYNTAX, ProtoKey.DETAILS: "usage: {}".format(cmd_entry.usage)}

        tx_id = args[1]
        folder_name = args[2]
        file_name = args[3]
        tx_path = self._tx_path(tx_id, folder_name)
        file_path = os.path.join(tx_path, file_name)
        receiver = _ReceiveFileFromServer(file_path)

        api = ctx.get_api()
        api.fire_session_event(EventType.BEFORE_DOWNLOAD_FILE, f"downloading {file_name} ...")
        ctx.set_bytes_receiver(receiver)
        # monotonic clock: elapsed time is unaffected by wall-clock adjustments
        download_start = time.monotonic()
        result = api.server_execute(ctx.get_command(), cmd_ctx=ctx)
        if result.get(ProtoKey.STATUS) != APIStatus.SUCCESS:
            return result
        download_end = time.monotonic()
        api.fire_session_event(
            EventType.AFTER_DOWNLOAD_FILE,
            f"downloaded {file_name} ({receiver.num_bytes_received} bytes) in {download_end-download_start} seconds",
        )

        dir_name, ext = os.path.splitext(file_path)
        if ext == ".zip":
            # unzip the file
            api.debug(f"unzipping file {file_path} to {dir_name}")
            os.makedirs(dir_name, exist_ok=True)
            unzip_all_from_file(file_path, dir_name)

            # remove the zip file
            os.remove(file_path)
        return result

    def pull_folder(self, args, ctx: CommandContext):
        """Pull all files of a server folder into download_dir.

        The server first returns the list of files and a transaction id. Each file is then
        fetched with the returned command; the last request carries an extra ``"end"``
        argument. The staging folder is finally renamed to the destination name.

        Args:
            args: ``[cmd_name, folder_name]`` with an optional destination name as ``args[2]``.
            ctx: command context.

        Returns:
            A success dict whose meta holds the final location, the first failing reply,
            or the initial server reply if it carries no files.
        """
        cmd_entry = ctx.get_command_entry()
        if len(args) < 2:
            return {ProtoKey.STATUS: APIStatus.ERROR_SYNTAX, ProtoKey.DETAILS: "usage: {}".format(cmd_entry.usage)}

        folder_name = args[1]
        destination_name = args[2] if len(args) > 2 else folder_name

        command = join_args([cmd_entry.full_command_name(), folder_name])
        api = ctx.get_api()
        result = api.server_execute(command)
        if result.get(ProtoKey.STATUS) != APIStatus.SUCCESS:
            return result

        meta = result.get(ProtoKey.META)
        if not meta:
            return result

        file_names = meta.get(MetaKey.FILES)
        tx_id = meta.get(MetaKey.TX_ID)
        api.debug(f"received tx_id {tx_id}, file names: {file_names}")
        if not file_names:
            return result

        cmd_name = meta.get(MetaKey.CMD_NAME)
        last_index = len(file_names) - 1
        for i, file_name in enumerate(file_names):
            parts = [cmd_name, tx_id, folder_name, file_name]
            if i == last_index:
                # this is the last file
                parts.append("end")
            reply = api.do_command(join_args(parts))
            if reply.get(ProtoKey.STATUS) != APIStatus.SUCCESS:
                # Compare against None, not truthiness: an empty failing reply is still a failure.
                return reply

        tx_path = self._tx_path(tx_id, folder_name)
        destination_path = os.path.join(self.download_dir, destination_name)
        location = self._rename_folder(tx_path, destination_path)
        return {
            ProtoKey.STATUS: APIStatus.SUCCESS,
            ProtoKey.DETAILS: f"content downloaded to {location}",
            ProtoKey.META: {MetaKey.LOCATION: location},
        }

    @staticmethod
    def _rename_folder(src: str, destination: str):
        """Rename ``src`` to ``destination``, trying ``destination__<n>`` if the rename fails.

        Returns:
            The path where the folder now lives. If every attempt fails, the folder stays
            where it was and ``src`` is returned, so callers report the real location.
        """
        max_tries = 1000
        for i in range(max_tries):
            candidate = destination if i == 0 else f"{destination}__{i}"
            try:
                os.rename(src, candidate)
                return candidate
            except OSError:
                # e.g. candidate already exists as a non-empty dir - try the next name
                continue

        # all rename attempts have failed - the content is still at src
        return src

    def download_folder(self, args, ctx: CommandContext):
        """Download a folder (e.g. job contents) from the server into download_dir.

        Args:
            args: ``[cmd_name, job_id]``.
            ctx: command context.

        Returns:
            The result produced by ``_DownloadFolderProcessor``, or a usage error dict.
        """
        cmd_entry = ctx.get_command_entry()
        assert isinstance(cmd_entry, CommandEntry)
        if len(args) != 2:
            return {"status": APIStatus.ERROR_SYNTAX, "details": "usage: {}".format(cmd_entry.usage)}

        job_id = args[1]
        command = join_args([cmd_entry.full_command_name(), job_id])
        reply_processor = _DownloadFolderProcessor(self.download_dir)
        api = ctx.get_api()
        return api.server_execute(command, reply_processor)

    def info(self, args, ctx: CommandContext):
        """Show the local upload and download directories."""
        msg = f"Local Upload Source: {self.upload_dir}\n"
        msg += f"Local Download Destination: {self.download_dir}\n"
        # NOTE(review): other handlers report APIStatus.SUCCESS; "ok" is kept in case callers rely on it.
        return {"status": "ok", "details": msg}

    def push_folder(self, args, ctx: CommandContext):
        """Zip a folder from upload_dir and upload it to the server with the binary protocol.

        Unless the API is in insecure mode, the folder is signed with ``sign_folders``
        using the client's private key and certificate before zipping.

        Args:
            args: ``[cmd_name, folder_name]``; folder_name is relative to upload_dir.
            ctx: command context.

        Returns:
            The server reply, or an error dict for bad usage or a missing folder.
        """
        cmd_entry = ctx.get_command_entry()
        assert isinstance(cmd_entry, CommandEntry)
        if len(args) != 2:
            return {"status": APIStatus.ERROR_SYNTAX, "details": "usage: {}".format(cmd_entry.usage)}

        folder_name = args[1].rstrip("/")
        full_path = os.path.join(self.upload_dir, folder_name)
        if not os.path.isdir(full_path):
            return {"status": APIStatus.ERROR_RUNTIME, "details": f"'{full_path}' is not a valid folder."}

        api = ctx.get_api()
        if not api.insecure:
            # we are not in POC mode: sign folders and files
            private_key = load_private_key_file(api.client_key)
            sign_folders(full_path, private_key, api.client_cert)

        out_file = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
        try:
            # zip the data
            zip_directory_to_file(self.upload_dir, folder_name, out_file)
            folder_name = split_path(full_path)[1]
            command = join_args([cmd_entry.full_command_name(), folder_name])
            ctx.set_bytes_sender(_SendFileToServer(out_file))
            return api.server_execute(command, cmd_ctx=ctx)
        finally:
            # The sender deletes the archive after sending; also remove it when zipping
            # failed or the command ended before the sender ran.
            if os.path.exists(out_file):
                os.remove(out_file)