# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import traceback
import nvflare.fuel.hci.file_transfer_defs as ftd
from nvflare.fuel.hci.base64_utils import (
b64str_to_binary_file,
b64str_to_bytes,
b64str_to_text_file,
binary_file_to_b64str,
bytes_to_b64str,
text_file_to_b64str,
)
from nvflare.fuel.hci.cmd_arg_utils import join_args
from nvflare.fuel.hci.reg import CommandModule, CommandModuleSpec, CommandSpec
from nvflare.fuel.hci.server.constants import ConnProps
from nvflare.fuel.hci.table import Table
from nvflare.fuel.hci.zip_utils import remove_leading_dotdot, split_path, unzip_all_from_bytes, zip_directory_to_bytes
from .api_spec import AdminAPISpec, ReplyProcessor
from .api_status import APIStatus
def _server_cmd_name(name: str):
    """Qualify a command name with the server-side file transfer module prefix.

    Args:
        name: bare command name, e.g. one of the ``ftd.SERVER_CMD_*`` constants.

    Returns:
        The string ``"<ftd.SERVER_MODULE_NAME>.<name>"``.
    """
    return ".".join([ftd.SERVER_MODULE_NAME, name])
class _DownloadProcessor(ReplyProcessor):
    """Reply processor to handle downloads.

    Each data row of a received table is ``[file_name, encoded_content]``. The content is
    decoded and written under ``download_dir`` via ``str_to_file_func``; a summary table
    with ``file`` / ``size`` columns becomes the command result when the reply is done.
    """

    def __init__(self, download_dir: str, str_to_file_func):
        """Create the processor.

        Args:
            download_dir: local directory that downloaded files are written into.
            str_to_file_func: callable ``(encoded_str, full_path)`` that decodes
                ``encoded_str``, writes it to ``full_path`` and returns the value shown in
                the ``size`` column (presumably the number of bytes written).
        """
        self.download_dir = download_dir
        self.str_to_file_func = str_to_file_func
        self.data_received = False
        self.table = None
        # True once an error result has been set, so reply_done() does not overwrite it.
        self.error_reported = False

    def reply_start(self, api, reply_json):
        """Reset per-reply state and start a fresh summary table."""
        self.data_received = False
        self.error_reported = False
        self.table = Table(["file", "size"])

    def reply_done(self, api):
        """Finalize the command result after the whole reply has been processed.

        An error already reported while processing the reply is kept unchanged.
        Otherwise the result is a protocol error if no file was received, or success
        with the summary table as details.
        """
        if self.error_reported:
            # keep the specific error set by process_table() instead of masking it
            return
        if not self.data_received:
            api.set_command_result({"status": APIStatus.ERROR_PROTOCOL, "details": "protocol error - no data received"})
            return
        command_result = api.get_command_result()
        if command_result is None:
            command_result = {}
        command_result["status"] = APIStatus.SUCCESS
        command_result["details"] = self.table
        api.set_command_result(command_result)

    def _report_error(self, api, status, details: str):
        """Set an error command result and remember that an error was reported."""
        self.error_reported = True
        api.set_command_result({"status": status, "details": details})

    def process_table(self, api, table: Table):
        """Write every file carried by ``table`` into ``download_dir``.

        Row 0 of the table is a header; every following row must contain at least a
        file name and its encoded content. Processing stops at the first malformed row,
        at a file name that would resolve outside ``download_dir``, or at the first
        exception; in each case an error command result is set.
        """
        try:
            rows = table.rows
            if len(rows) < 1:
                # no data
                self._report_error(api, APIStatus.ERROR_PROTOCOL, "protocol error - no file data")
                return
            download_root = os.path.realpath(self.download_dir)
            for row in rows[1:]:  # rows[0] is the header
                if len(row) < 1:
                    self._report_error(api, APIStatus.ERROR_PROTOCOL, "protocol error - missing file name")
                    return
                if len(row) < 2:
                    self._report_error(api, APIStatus.ERROR_PROTOCOL, "protocol error - missing file data")
                    return
                file_name = row[0]
                encoded_str = row[1]
                full_path = os.path.join(self.download_dir, file_name)
                # The name comes from the server reply: an absolute path or ".." components
                # would otherwise allow writing outside download_dir.
                if os.path.commonpath([download_root, os.path.realpath(full_path)]) != download_root:
                    self._report_error(
                        api, APIStatus.ERROR_PROTOCOL, f"protocol error - invalid file name: {file_name}"
                    )
                    return
                num_bytes = self.str_to_file_func(encoded_str, full_path)
                self.table.add_row([file_name, str(num_bytes)])
                self.data_received = True
        except Exception as ex:
            traceback.print_exc()
            self._report_error(api, APIStatus.ERROR_RUNTIME, f"exception processing file: {ex}")
class _DownloadFolderProcessor(ReplyProcessor):
    """Reply processor for handling downloading directories.

    A reply string is either a download URL (it starts with ``ConnProps.DOWNLOAD_JOB_URL``),
    which is passed back as the result details, or a base64-encoded zip archive, which is
    unzipped into ``download_dir``.
    """

    def __init__(self, download_dir: str):
        """Create the processor.

        Args:
            download_dir: local directory the received archive is unzipped into.
        """
        self.download_dir = download_dir
        self.data_received = False

    def reply_start(self, api, reply_json):
        """Reset the received flag at the start of a new reply."""
        self.data_received = False

    def reply_done(self, api):
        """Report a protocol error if the reply carried neither data nor an error."""
        if not self.data_received:
            # ERROR_PROTOCOL matches the message and the status _DownloadProcessor uses here.
            api.set_command_result({"status": APIStatus.ERROR_PROTOCOL, "details": "protocol error - no data received"})

    def process_error(self, api: AdminAPISpec, err: str):
        """Report an error sent in the reply as a runtime error."""
        self.data_received = True
        api.set_command_result({"status": APIStatus.ERROR_RUNTIME, "details": err})

    def process_string(self, api, item: str):
        """Handle one reply string: pass a download URL through, or unzip an archive."""
        try:
            self.data_received = True
            if item.startswith(ConnProps.DOWNLOAD_JOB_URL):
                api.set_command_result(
                    {
                        "status": APIStatus.SUCCESS,
                        "details": item,
                    }
                )
            else:
                data_bytes = b64str_to_bytes(item)
                # NOTE(review): the archive is supplied by the server; confirm that
                # unzip_all_from_bytes rejects entries that escape download_dir (zip slip).
                unzip_all_from_bytes(data_bytes, self.download_dir)
                api.set_command_result(
                    {
                        "status": APIStatus.SUCCESS,
                        "details": "Download to dir {}".format(self.download_dir),
                    }
                )
        except Exception as ex:
            traceback.print_exc()
            api.set_command_result(
                {
                    "status": APIStatus.ERROR_RUNTIME,
                    "details": "exception processing reply: {}".format(ex),
                }
            )
class FileTransferModule(CommandModule):
    """Command module with commands relevant to file transfer.

    Local files and folders are read from ``upload_dir`` when uploading and written to
    ``download_dir`` when downloading. File content travels base64-encoded as arguments
    of server-side commands named via ``_server_cmd_name``.
    """

    def __init__(self, upload_dir: str, download_dir: str):
        """Create the module.

        Args:
            upload_dir: existing local directory that uploads are read from.
            download_dir: existing local directory that downloads are written to.

        Raises:
            ValueError: if either path is not an existing directory.
        """
        if not os.path.isdir(upload_dir):
            raise ValueError("upload_dir {} is not a valid dir".format(upload_dir))
        if not os.path.isdir(download_dir):
            raise ValueError("download_dir {} is not a valid dir".format(download_dir))
        self.upload_dir = upload_dir
        self.download_dir = download_dir

    def get_spec(self):
        """Return the command spec; the raw text/binary upload/download commands are hidden."""
        return CommandModuleSpec(
            name="file_transfer",
            cmd_specs=[
                CommandSpec(
                    name="upload_text",
                    description="upload one or more text files in the upload_dir",
                    usage="upload_text file_name ...",
                    handler_func=self.upload_text_file,
                    visible=False,
                ),
                CommandSpec(
                    name="download_text",
                    description="download one or more text files in the download_dir",
                    usage="download_text file_name ...",
                    handler_func=self.download_text_file,
                    visible=False,
                ),
                CommandSpec(
                    name="upload_binary",
                    description="upload one or more binary files in the upload_dir",
                    usage="upload_binary file_name ...",
                    handler_func=self.upload_binary_file,
                    visible=False,
                ),
                CommandSpec(
                    name="download_binary",
                    description="download one or more binary files in the download_dir",
                    usage="download_binary file_name ...",
                    handler_func=self.download_binary_file,
                    visible=False,
                ),
                CommandSpec(
                    name="submit_job",
                    description="Submit application to the server",
                    usage="submit_job job_folder",
                    handler_func=self.submit_job,
                ),
                CommandSpec(
                    name="download_job",
                    description="download job contents from the server",
                    usage="download_job job_id",
                    handler_func=self.download_job,
                ),
                CommandSpec(
                    name="info",
                    description="show folder setup info",
                    usage="info",
                    handler_func=self.info,
                ),
            ],
        )

    def upload_file(self, args, api: AdminAPISpec, cmd_name, file_to_str_func):
        """Encode one or more local files and send them to the server in one command.

        Args:
            args: command arguments; ``args[0]`` is the command itself and the remaining
                items are file names relative to ``upload_dir``.
            api: admin API used to execute the server command.
            cmd_name: server-side command name, without the module prefix.
            file_to_str_func: callable that encodes the file at a given path into a string.

        Returns:
            A syntax/runtime error dict, or whatever ``api.server_execute`` returns.
        """
        if len(args) < 2:
            return {"status": APIStatus.ERROR_SYNTAX, "details": "syntax error: missing file names"}
        parts = [_server_cmd_name(cmd_name)]
        for file_name in args[1:]:
            full_path = os.path.join(self.upload_dir, file_name)
            if not os.path.isfile(full_path):
                return {"status": APIStatus.ERROR_RUNTIME, "details": f"no such file: {full_path}"}
            # the server command takes alternating (file_name, encoded_content) arguments
            parts.extend([file_name, file_to_str_func(full_path)])
        return api.server_execute(join_args(parts))

    def upload_text_file(self, args, api: AdminAPISpec):
        """Upload text files from ``upload_dir``."""
        return self.upload_file(args, api, ftd.SERVER_CMD_UPLOAD_TEXT, text_file_to_b64str)

    def upload_binary_file(self, args, api: AdminAPISpec):
        """Upload binary files from ``upload_dir``."""
        return self.upload_file(args, api, ftd.SERVER_CMD_UPLOAD_BINARY, binary_file_to_b64str)

    def download_file(self, args, api: AdminAPISpec, cmd_name, str_to_file_func):
        """Request one or more files from the server and write them into ``download_dir``.

        Args:
            args: command arguments; ``args[0]`` is the command itself and the remaining
                items are the file names to download.
            api: admin API used to execute the server command.
            cmd_name: server-side command name, without the module prefix.
            str_to_file_func: callable ``(encoded_str, full_path)`` that decodes received
                content into a local file.

        Returns:
            A syntax error dict, or whatever ``api.server_execute`` returns.
        """
        if len(args) < 2:
            return {"status": APIStatus.ERROR_SYNTAX, "details": "syntax error: missing file names"}
        command = join_args([_server_cmd_name(cmd_name), *args[1:]])
        reply_processor = _DownloadProcessor(self.download_dir, str_to_file_func)
        return api.server_execute(command, reply_processor)

    def download_text_file(self, args, api: AdminAPISpec):
        """Download text files into ``download_dir``."""
        return self.download_file(args, api, ftd.SERVER_CMD_DOWNLOAD_TEXT, b64str_to_text_file)

    def download_binary_file(self, args, api: AdminAPISpec):
        """Download binary files into ``download_dir``."""
        return self.download_file(args, api, ftd.SERVER_CMD_DOWNLOAD_BINARY, b64str_to_binary_file)

    def _resolve_upload_folder(self, folder_name: str):
        """Strip trailing slashes from ``folder_name`` and locate it under ``upload_dir``.

        Returns:
            ``(folder_name, full_path, error)`` where ``error`` is an error dict if
            ``full_path`` is not an existing directory, otherwise ``None``.
        """
        folder_name = folder_name.rstrip("/")
        full_path = os.path.join(self.upload_dir, folder_name)
        if not os.path.isdir(full_path):
            return folder_name, full_path, {"status": APIStatus.ERROR_RUNTIME, "details": f"'{full_path}' is not a valid folder."}
        return folder_name, full_path, None

    def upload_folder(self, args, api: AdminAPISpec):
        """Zip a folder under ``upload_dir`` and upload it to the server.

        NOTE: this handler is not registered in ``get_spec()``.

        Args:
            args: ``["upload_folder", folder_name]``.
            api: admin API used to execute the server command.

        Returns:
            A syntax/runtime error dict, or whatever ``api.server_execute`` returns.
        """
        if len(args) != 2:
            return {"status": APIStatus.ERROR_SYNTAX, "details": "usage: upload_folder folder_name"}
        folder_name, full_path, error = self._resolve_upload_folder(args[1])
        if error is not None:
            return error
        # zip the data
        data = zip_directory_to_bytes(self.upload_dir, folder_name)
        # prepare for upload: send the folder name relative to upload_dir without leading ".."
        rel_path = os.path.relpath(full_path, self.upload_dir)
        folder_name = remove_leading_dotdot(rel_path)
        parts = [_server_cmd_name(ftd.SERVER_CMD_UPLOAD_FOLDER), folder_name, bytes_to_b64str(data)]
        return api.server_execute(join_args(parts))

    def submit_job(self, args, api: AdminAPISpec):
        """Zip a job folder under ``upload_dir`` and submit it to the server.

        Args:
            args: ``["submit_job", job_folder]``.
            api: admin API used to execute the server command.

        Returns:
            A syntax/runtime error dict, or whatever ``api.server_execute`` returns.
        """
        if len(args) != 2:
            return {"status": APIStatus.ERROR_SYNTAX, "details": "usage: submit_job job_folder"}
        folder_name, full_path, error = self._resolve_upload_folder(args[1])
        if error is not None:
            return error
        # zip the data
        data = zip_directory_to_bytes(self.upload_dir, folder_name)
        # the server only receives the last path component of the job folder
        folder_name = split_path(full_path)[1]
        parts = [_server_cmd_name(ftd.SERVER_CMD_SUBMIT_JOB), folder_name, bytes_to_b64str(data)]
        return api.server_execute(join_args(parts))

    def download_job(self, args, api: AdminAPISpec):
        """Download the contents of a job from the server into ``download_dir``.

        Args:
            args: ``["download_job", job_id]``.
            api: admin API used to execute the server command.

        Returns:
            A syntax error dict, or whatever ``api.server_execute`` returns.
        """
        if len(args) != 2:
            return {"status": APIStatus.ERROR_SYNTAX, "details": "usage: download_job job_id"}
        command = join_args([_server_cmd_name(ftd.SERVER_CMD_DOWNLOAD_JOB), args[1]])
        reply_processor = _DownloadFolderProcessor(self.download_dir)
        return api.server_execute(command, reply_processor)

    def info(self, args, api: AdminAPISpec):
        """Show the local upload/download folders followed by the server's info reply.

        The combined response is both set as the command result and returned.
        """
        msg = f"Local Upload Source: {self.upload_dir}\n"
        msg += f"Local Download Destination: {self.download_dir}\n"
        resp = api.server_execute(_server_cmd_name(ftd.SERVER_CMD_INFO))
        if "details" not in resp:
            resp["details"] = msg
        else:
            # NOTE(review): assumes the server's "details" is a string - confirm.
            resp["details"] = msg + resp["details"]
        api.set_command_result(resp)
        return resp