# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import time
import uuid
import nvflare.fuel.hci.file_transfer_defs as ftd
from nvflare.fuel.hci.client.event import EventType
from nvflare.fuel.hci.cmd_arg_utils import join_args
from nvflare.fuel.hci.proto import MetaKey, ProtoKey
from nvflare.fuel.hci.reg import CommandEntry, CommandModule, CommandModuleSpec, CommandSpec
from nvflare.fuel.utils.zip_utils import split_path, unzip_all_from_file, zip_directory_to_file
from nvflare.lighter.utils import load_private_key_file, sign_folders
from .api_spec import CommandContext, HCIRequester
from .api_status import APIStatus
class _FileSender(HCIRequester):
def __init__(self, file_name: str):
self.file_name = file_name
def send_request(self, api, conn, cmd_ctx):
result = api.upload_file(self.file_name, conn)
os.remove(self.file_name)
return result
class _FileReceiver(HCIRequester):
def __init__(self, source_fqcn: str, ref_id, file_name: str):
self.source_fqcn = source_fqcn
self.ref_id = ref_id
self.file_name = file_name
self.num_bytes_received = 0
def send_request(self, api, conn, cmd_ctx):
self.num_bytes_received = api.download_file(self.source_fqcn, self.ref_id, self.file_name)
if self.num_bytes_received is not None:
cmd_ctx.set_command_result({ProtoKey.STATUS: APIStatus.SUCCESS, ProtoKey.DETAILS: "OK"})
else:
cmd_ctx.set_command_result(
{ProtoKey.STATUS: APIStatus.ERROR_RUNTIME, ProtoKey.DETAILS: "error receiving file"}
)
return None
[docs]
class FileTransferModule(CommandModule):
"""Command module with commands relevant to file transfer."""
PULL_BINARY_FILE_CMD = "pull_binary_file"
def __init__(self, upload_dir: str, download_dir: str):
if not os.path.isdir(upload_dir):
raise ValueError("upload_dir {} is not a valid dir".format(upload_dir))
if not os.path.isdir(download_dir):
raise ValueError("download_dir {} is not a valid dir".format(download_dir))
self.upload_dir = upload_dir
self.download_dir = download_dir
self.cmd_handlers = {
ftd.PUSH_FOLDER_FQN: self.push_folder,
ftd.PULL_FOLDER_FQN: self.pull_folder,
}
[docs]
def get_spec(self):
return CommandModuleSpec(
name="file_transfer",
cmd_specs=[
CommandSpec(
name=self.PULL_BINARY_FILE_CMD,
description="download one binary files in the download_dir",
usage="pull_binary source_fqcn tx_id ref_id folder_name file_name",
handler_func=self.pull_binary_file,
visible=False,
),
CommandSpec(
name="push_folder",
description="Submit application to the server",
usage="submit_job job_folder",
handler_func=self.push_folder,
visible=False,
),
CommandSpec(
name="info",
description="show folder setup info",
usage="info",
handler_func=self.info,
),
],
)
[docs]
def generate_module_spec(self, server_cmd_spec: CommandSpec):
"""
Generate a new module spec based on a server command
Args:
server_cmd_spec:
Returns:
"""
# print('generating cmd module for {}'.format(server_cmd_spec.client_cmd))
if not server_cmd_spec.client_cmd:
return None
handler = self.cmd_handlers.get(server_cmd_spec.client_cmd)
if handler is None:
print("no cmd handler found for {}".format(server_cmd_spec.client_cmd))
return None
return CommandModuleSpec(
name=server_cmd_spec.scope_name,
cmd_specs=[
CommandSpec(
name=server_cmd_spec.name,
description=server_cmd_spec.description,
usage=server_cmd_spec.usage,
handler_func=handler,
visible=server_cmd_spec.visible,
)
],
)
def _tx_path(self, tx_id: str, folder_name: str):
return os.path.join(self.download_dir, f"{folder_name}__{tx_id}")
[docs]
def pull_binary_file(self, args, ctx: CommandContext):
"""
Args: cmd_name, source_fqcn, tx_id, ref_id, folder_name, file_name, [end]
"""
cmd_entry = ctx.get_command_entry()
if len(args) < 6 or len(args) > 7:
return {ProtoKey.STATUS: APIStatus.ERROR_SYNTAX, ProtoKey.DETAILS: "usage: {}".format(cmd_entry.usage)}
source_fqcn = args[1]
tx_id = args[2]
ref_id = args[3]
folder_name = args[4]
file_name = args[5]
tx_path = self._tx_path(tx_id, folder_name)
file_path = os.path.join(tx_path, file_name)
api = ctx.get_api()
receiver = _FileReceiver(source_fqcn, ref_id, file_path)
api.fire_session_event(EventType.BEFORE_DOWNLOAD_FILE, f"downloading {file_name} ...")
ctx.set_requester(receiver)
download_start = time.time()
result = api.server_execute(ctx.get_command(), cmd_ctx=ctx)
if result.get(ProtoKey.STATUS) != APIStatus.SUCCESS:
return result
download_end = time.time()
api.fire_session_event(
EventType.AFTER_DOWNLOAD_FILE,
f"downloaded {file_name} ({receiver.num_bytes_received} bytes) in {download_end - download_start} seconds",
)
dir_name, ext = os.path.splitext(file_path)
if ext == ".zip":
# unzip the file
api.debug(f"unzipping file {file_path} to {dir_name}")
os.makedirs(dir_name, exist_ok=True)
unzip_all_from_file(file_path, dir_name)
# remove the zip file
os.remove(file_path)
return result
[docs]
def pull_folder(self, args, ctx: CommandContext):
cmd_entry = ctx.get_command_entry()
if len(args) < 2:
return {ProtoKey.STATUS: APIStatus.ERROR_SYNTAX, ProtoKey.DETAILS: "usage: {}".format(cmd_entry.usage)}
folder_name = args[1]
destination_name = folder_name
if len(args) > 2:
destination_name = args[2]
parts = [cmd_entry.full_command_name(), folder_name]
command = join_args(parts)
api = ctx.get_api()
result = api.server_execute(command)
if result.get(ProtoKey.STATUS) != APIStatus.SUCCESS:
return result
meta = result.get(ProtoKey.META)
if not meta:
return result
files = meta.get(MetaKey.FILES)
tx_id = meta.get(MetaKey.TX_ID)
source_fqcn = meta.get(MetaKey.SOURCE_FQCN)
api.debug(f"received tx_id {tx_id}, file names: {files}")
if not files:
return result
cmd_name = self.PULL_BINARY_FILE_CMD
error = None
for i, f in enumerate(files):
file_name = f[0]
ref_id = f[1]
parts = [cmd_name, source_fqcn, tx_id, ref_id, folder_name, file_name]
if i == len(files) - 1:
# this is the last file
parts.append("end")
command = join_args(parts)
reply = api.do_command(command)
if reply.get(ProtoKey.STATUS) != APIStatus.SUCCESS:
error = reply
break
if not error:
tx_path = self._tx_path(tx_id, folder_name)
destination_path = os.path.join(self.download_dir, destination_name)
location = self._rename_folder(tx_path, destination_path)
reply = {
ProtoKey.STATUS: APIStatus.SUCCESS,
ProtoKey.DETAILS: f"content downloaded to {location}",
ProtoKey.META: {MetaKey.LOCATION: location},
}
else:
reply = error
return reply
@staticmethod
def _rename_folder(src: str, destination: str):
max_tries = 1000
for i in range(max_tries):
if i == 0:
d = destination
else:
d = f"{destination}__{i}"
try:
os.rename(src, d)
return d
except:
# try next
pass
# all rename attempts have failed - keep the original destination name
return destination
[docs]
def info(self, args, ctx: CommandContext):
msg = f"Local Upload Source: {self.upload_dir}\n"
msg += f"Local Download Destination: {self.download_dir}\n"
return {"status": "ok", "details": msg}
[docs]
def push_folder(self, args, ctx: CommandContext):
# upload with binary protocol
cmd_entry = ctx.get_command_entry()
assert isinstance(cmd_entry, CommandEntry)
if len(args) != 2:
return {"status": APIStatus.ERROR_SYNTAX, "details": "usage: {}".format(cmd_entry.usage)}
folder_name = args[1]
if folder_name.endswith("/"):
folder_name = folder_name.rstrip("/")
full_path = os.path.join(self.upload_dir, folder_name)
if not os.path.isdir(full_path):
return {"status": APIStatus.ERROR_RUNTIME, "details": f"'{full_path}' is not a valid folder."}
# sign folders and files
api = ctx.get_api()
client_key_file_path = api.client_key
private_key = load_private_key_file(client_key_file_path)
sign_folders(full_path, private_key, api.client_cert)
# zip the data
out_file = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
zip_directory_to_file(self.upload_dir, folder_name, out_file)
folder_name = split_path(full_path)[1]
parts = [cmd_entry.full_command_name(), folder_name]
command = join_args(parts)
sender = _FileSender(out_file)
ctx.set_requester(sender)
return api.server_execute(command, cmd_ctx=ctx)