# Source code for nvflare.fuel.hci.server.file_transfer

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
from typing import List

import nvflare.fuel.hci.file_transfer_defs as ftd
from nvflare.fuel.hci.base64_utils import (
    b64str_to_binary_file,
    b64str_to_bytes,
    b64str_to_text_file,
    binary_file_to_b64str,
    text_file_to_b64str,
)
from nvflare.fuel.hci.conn import Connection
from nvflare.fuel.hci.reg import CommandModule, CommandModuleSpec, CommandSpec
from nvflare.fuel.hci.server.constants import ConnProps
from nvflare.fuel.utils.zip_utils import unzip_all_from_bytes
from nvflare.private.fed.server.cmd_utils import CommandUtil
from nvflare.security.logging import secure_format_exception, secure_log_traceback


class FileTransferModule(CommandModule, CommandUtil):
    """Admin command module that transfers files and folders between an admin client and the server.

    Uploaded content is written under ``upload_dir``; downloads are read from ``download_dir``.
    Every client-supplied file/folder name is confined to its base directory (see ``_safe_join``).
    """

    def __init__(self, upload_dir: str, download_dir: str):
        """Command module for file transfers.

        Args:
            upload_dir: existing directory that receives uploaded files and folders.
            download_dir: existing directory from which files are served for download.

        Raises:
            ValueError: if either ``upload_dir`` or ``download_dir`` is not an existing directory.
        """
        if not os.path.isdir(upload_dir):
            raise ValueError("upload_dir {} is not a valid dir".format(upload_dir))
        if not os.path.isdir(download_dir):
            raise ValueError("download_dir {} is not a valid dir".format(download_dir))
        self.upload_dir = upload_dir
        self.download_dir = download_dir

    @staticmethod
    def _safe_join(base_dir: str, name: str):
        """Join ``name`` onto ``base_dir``, refusing names that escape ``base_dir``.

        Names containing ``..`` components, absolute paths, or symlinks that resolve outside
        ``base_dir`` are rejected, as is any name that resolves to ``base_dir`` itself.

        Args:
            base_dir: directory the result must stay strictly inside.
            name: client-supplied relative name.

        Returns:
            ``os.path.join(base_dir, name)`` if it resolves strictly inside ``base_dir``; otherwise None.
        """
        full_path = os.path.join(base_dir, name)
        real_base = os.path.realpath(base_dir)
        real_full = os.path.realpath(full_path)
        try:
            inside = os.path.commonpath([real_base, real_full]) == real_base
        except ValueError:
            # e.g. paths on different drives on Windows: definitely not inside base_dir
            return None
        if not inside or real_full == real_base:
            return None
        # Return the un-resolved join so callers see the same path form as before.
        return full_path

    def get_spec(self):
        """Return the spec of all (hidden) file transfer commands served by this module.

        The upload/download directories are also exposed to handlers as connection properties.
        """
        return CommandModuleSpec(
            name=ftd.SERVER_MODULE_NAME,
            cmd_specs=[
                CommandSpec(
                    name=ftd.SERVER_CMD_UPLOAD_TEXT,
                    description="upload one or more text files",
                    usage="_upload name1 data1 name2 data2 ...",
                    handler_func=self.upload_text_file,
                    visible=False,
                ),
                CommandSpec(
                    name=ftd.SERVER_CMD_DOWNLOAD_TEXT,
                    description="download one or more text files",
                    usage="download file_name ...",
                    handler_func=self.download_text_file,
                    visible=False,
                ),
                CommandSpec(
                    name=ftd.SERVER_CMD_UPLOAD_BINARY,
                    description="upload one or more binary files",
                    usage="upload name1 data1 name2 data2 ...",
                    handler_func=self.upload_binary_file,
                    visible=False,
                ),
                CommandSpec(
                    name=ftd.SERVER_CMD_DOWNLOAD_BINARY,
                    description="download one or more binary files",
                    usage="download file_name ...",
                    handler_func=self.download_binary_file,
                    visible=False,
                ),
                CommandSpec(
                    name=ftd.SERVER_CMD_UPLOAD_FOLDER,
                    description="upload a folder from client",
                    usage="upload_folder folder_name",
                    handler_func=self.upload_folder,
                    visible=False,
                ),
                CommandSpec(
                    name=ftd.SERVER_CMD_INFO,
                    description="show info",
                    usage="info",
                    handler_func=self.info,
                    visible=False,
                ),
            ],
            conn_props={ConnProps.DOWNLOAD_DIR: self.download_dir, ConnProps.UPLOAD_DIR: self.upload_dir},
        )

    def upload_file(self, conn: Connection, args: List[str], str_to_file_func):
        """Write one or more base64-encoded files into ``upload_dir``.

        Args:
            conn: connection used to report results and errors.
            args: ``[cmd, name1, data1, name2, data2, ...]``; each ``data`` is a base64 string.
            str_to_file_func: decoder called as ``func(b64str=..., file_name=...)`` that writes the
                file and returns the number of bytes written.

        A table of ``(file, size)`` rows is appended for every file written; names that would
        escape ``upload_dir`` are reported as errors and skipped.
        """
        if len(args) < 3:
            conn.append_error("syntax error: missing files")
            return
        if len(args) % 2 != 1:
            conn.append_error("syntax error: file name/data not paired")
            return

        table = conn.append_table(["file", "size"])
        # args[1::2] are the names, args[2::2] the matching payloads.
        for name, data in zip(args[1::2], args[2::2]):
            full_path = self._safe_join(self.upload_dir, name)
            if full_path is None:
                conn.append_error("invalid file name: {}".format(name))
                continue
            num_bytes = str_to_file_func(b64str=data, file_name=full_path)
            table.add_row([name, str(num_bytes)])

    def upload_text_file(self, conn: Connection, args: List[str]):
        """Handler for uploading text files; see ``upload_file``."""
        self.upload_file(conn, args, b64str_to_text_file)

    def upload_binary_file(self, conn: Connection, args: List[str]):
        """Handler for uploading binary files; see ``upload_file``."""
        self.upload_file(conn, args, b64str_to_binary_file)

    def download_file(self, conn: Connection, args: List[str], file_to_str_func):
        """Return one or more files from ``download_dir`` as encoded strings.

        Args:
            conn: connection used to report results and errors.
            args: ``[cmd, file_name1, file_name2, ...]``.
            file_to_str_func: encoder called with the full file path, returning the encoded content.

        A ``(name, data)`` row is appended for each readable file; missing files, non-files and
        names that would escape ``download_dir`` are reported as errors and skipped.
        """
        if len(args) < 2:
            conn.append_error("syntax error: missing file names")
            return

        table = conn.append_table(["name", "data"])
        for file_name in args[1:]:
            full_path = self._safe_join(self.download_dir, file_name)
            if full_path is None:
                conn.append_error("invalid file name: {}".format(file_name))
                continue
            if not os.path.exists(full_path):
                conn.append_error("no such file: {}".format(file_name))
                continue
            if not os.path.isfile(full_path):
                conn.append_error("not a file: {}".format(file_name))
                continue
            encoded_str = file_to_str_func(full_path)
            table.add_row([file_name, encoded_str])

    def download_text_file(self, conn: Connection, args: List[str]):
        """Handler for downloading text files; see ``download_file``."""
        self.download_file(conn, args, text_file_to_b64str)

    def download_binary_file(self, conn: Connection, args: List[str]):
        """Handler for downloading binary files; see ``download_file``."""
        self.download_file(conn, args, binary_file_to_b64str)

    def _authorize_upload_folder(self, conn: Connection, args: List[str]):
        """Pre-check an ``upload_folder`` request by unzipping it into a throwaway temp dir.

        Args:
            conn: connection used to report errors.
            args: ``[cmd, folder_name, zip_b64str]``.

        Returns:
            ``(ok, None)``: ``ok`` is True only when the folder name stays inside ``upload_dir`` and
            the zip payload unpacks to a directory named ``folder_name``. The second element is
            always None here (presumably an authorization context slot — confirm with caller).
        """
        if len(args) != 3:
            conn.append_error("syntax error: require data")
            return False, None
        folder_name = args[1]
        zip_b64str = args[2]
        # Without this check a name like ".." passes the isdir test below (it resolves to the
        # temp dir's parent) and upload_folder would then delete a directory outside upload_dir.
        if self._safe_join(self.upload_dir, folder_name) is None:
            conn.append_error("invalid folder name: {}".format(folder_name))
            return False, None

        tmp_dir = tempfile.mkdtemp()
        try:
            data_bytes = b64str_to_bytes(zip_b64str)
            unzip_all_from_bytes(data_bytes, tmp_dir)
            tmp_folder_path = os.path.join(tmp_dir, folder_name)
            if not os.path.isdir(tmp_folder_path):
                conn.append_error("logic error: unzip failed to create folder {}".format(tmp_folder_path))
                return False, None
            return True, None
        except Exception as e:
            secure_log_traceback()
            conn.append_error(f"exception occurred: {secure_format_exception(e)}")
            return False, None
        finally:
            shutil.rmtree(tmp_dir)

    def upload_folder(self, conn: Connection, args: List[str]):
        """Replace ``upload_dir/folder_name`` with the contents of a base64-encoded zip.

        Args:
            conn: connection used to report results; the created folder path is stored on it
                as the ``"upload_folder_path"`` property.
            args: ``[cmd, folder_name, zip_b64str]``.
        """
        if len(args) != 3:
            conn.append_error("syntax error: require data")
            return
        folder_name = args[1]
        zip_b64str = args[2]
        folder_path = self._safe_join(self.upload_dir, folder_name)
        if folder_path is None:
            conn.append_error("invalid folder name: {}".format(folder_name))
            return

        # Remove whatever currently occupies the target name; rmtree only works on real dirs.
        if os.path.islink(folder_path) or os.path.isfile(folder_path):
            os.remove(folder_path)
        elif os.path.isdir(folder_path):
            shutil.rmtree(folder_path)

        data_bytes = b64str_to_bytes(zip_b64str)
        unzip_all_from_bytes(data_bytes, self.upload_dir)
        conn.set_prop("upload_folder_path", folder_path)
        conn.append_string("Created folder {}".format(folder_path))

    def info(self, conn: Connection, args: List[str]):
        """Report the server's upload destination and download source directories."""
        conn.append_string("Server Upload Destination: {}".format(self.upload_dir))
        conn.append_string("Server Download Source: {}".format(self.download_dir))