Source code for nvflare.fuel.hci.server.file_transfer

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import tempfile
import traceback
from io import BytesIO
from typing import List
from zipfile import ZipFile

import nvflare.fuel.hci.file_transfer_defs as ftd
from nvflare.apis.job_def import JobDataKey, JobMetaKey
from nvflare.apis.job_def_manager_spec import JobDefManagerSpec
from nvflare.apis.utils.common_utils import get_size
from nvflare.fuel.hci.base64_utils import (
    b64str_to_binary_file,
    b64str_to_bytes,
    b64str_to_text_file,
    binary_file_to_b64str,
    bytes_to_b64str,
    text_file_to_b64str,
)
from nvflare.fuel.hci.conn import Connection
from nvflare.fuel.hci.reg import CommandModule, CommandModuleSpec, CommandSpec
from nvflare.fuel.hci.server.constants import ConnProps
from nvflare.fuel.hci.zip_utils import convert_legacy_zip, unzip_all_from_bytes, zip_directory_to_bytes
from nvflare.private.fed.server.cmd_utils import CommandUtil
from nvflare.private.fed.server.job_meta_validator import JobMetaValidator
from nvflare.security.security import Action

META_FILE = "meta.json"
MAX_DOWNLOAD_JOB__SIZE = 50 * 1024 * 1024 * 1204


[docs]class FileTransferModule(CommandModule, CommandUtil): def __init__(self, upload_dir: str, download_dir: str, upload_folder_authz_func=None, download_job_url=None): """Command module for file transfers. Args: upload_dir: download_dir: upload_folder_authz_func: """ if not os.path.isdir(upload_dir): raise ValueError("upload_dir {} is not a valid dir".format(upload_dir)) if not os.path.isdir(download_dir): raise ValueError("download_dir {} is not a valid dir".format(download_dir)) self.upload_dir = upload_dir self.download_dir = download_dir self.upload_folder_authz_func = upload_folder_authz_func self.download_url = download_job_url
[docs] def get_spec(self): return CommandModuleSpec( name=ftd.SERVER_MODULE_NAME, cmd_specs=[ CommandSpec( name=ftd.SERVER_CMD_UPLOAD_TEXT, description="upload one or more text files", usage="_upload name1 data1 name2 data2 ...", handler_func=self.upload_text_file, visible=False, ), CommandSpec( name=ftd.SERVER_CMD_DOWNLOAD_TEXT, description="download one or more text files", usage="download file_name ...", handler_func=self.download_text_file, visible=False, ), CommandSpec( name=ftd.SERVER_CMD_UPLOAD_BINARY, description="upload one or more binary files", usage="upload name1 data1 name2 data2 ...", handler_func=self.upload_binary_file, visible=False, ), CommandSpec( name=ftd.SERVER_CMD_DOWNLOAD_BINARY, description="download one or more binary files", usage="download file_name ...", handler_func=self.download_binary_file, visible=False, ), CommandSpec( name=ftd.SERVER_CMD_UPLOAD_FOLDER, description="upload a folder from client", usage="upload_folder folder_name", handler_func=self.upload_folder, authz_func=self._authorize_upload_folder, visible=False, ), CommandSpec( name=ftd.SERVER_CMD_SUBMIT_JOB, description="Submit a job", usage="submit_job job_folder", handler_func=self.submit_job, authz_func=self._authorize_submission, visible=False, ), CommandSpec( name=ftd.SERVER_CMD_DOWNLOAD_JOB, description="download a job", usage="download_job job_id", handler_func=self.download_job, visible=False, ), CommandSpec( name=ftd.SERVER_CMD_INFO, description="show info", usage="info", handler_func=self.info, visible=False, ), ], )
[docs] def upload_file(self, conn: Connection, args: List[str], str_to_file_func): if len(args) < 3: conn.append_error("syntax error: missing files") return if len(args) % 2 != 1: conn.append_error("syntax error: file name/data not paired") return table = conn.append_table(["file", "size"]) i = 1 while i < len(args): name = args[i] data = args[i + 1] i += 2 full_path = os.path.join(self.upload_dir, name) num_bytes = str_to_file_func(b64str=data, file_name=full_path) table.add_row([name, str(num_bytes)])
[docs] def upload_text_file(self, conn: Connection, args: List[str]): self.upload_file(conn, args, b64str_to_text_file)
[docs] def upload_binary_file(self, conn: Connection, args: List[str]): self.upload_file(conn, args, b64str_to_binary_file)
[docs] def download_file(self, conn: Connection, args: List[str], file_to_str_func): if len(args) < 2: conn.append_error("syntax error: missing file names") return table = conn.append_table(["name", "data"]) for i in range(1, len(args)): file_name = args[i] full_path = os.path.join(self.download_dir, file_name) if not os.path.exists(full_path): conn.append_error("no such file: {}".format(file_name)) continue if not os.path.isfile(full_path): conn.append_error("not a file: {}".format(file_name)) continue encoded_str = file_to_str_func(full_path) table.add_row([file_name, encoded_str])
[docs] def download_text_file(self, conn: Connection, args: List[str]): self.download_file(conn, args, text_file_to_b64str)
[docs] def download_binary_file(self, conn: Connection, args: List[str]): self.download_file(conn, args, binary_file_to_b64str)
def _authorize_upload_folder(self, conn: Connection, args: List[str]): if len(args) != 3: conn.append_error("syntax error: require data") return False, None folder_name = args[1] zip_b64str = args[2] tmp_dir = tempfile.mkdtemp() try: data_bytes = b64str_to_bytes(zip_b64str) unzip_all_from_bytes(data_bytes, tmp_dir) tmp_folder_path = os.path.join(tmp_dir, folder_name) if not os.path.isdir(tmp_folder_path): conn.append_error("logic error: unzip failed to create folder {}".format(tmp_folder_path)) return False, None if self.upload_folder_authz_func: err, authz_ctx = self.upload_folder_authz_func(tmp_folder_path) if err is None: err = "" elif not isinstance(err, str): # the validator failed to follow signature # assuming things are bad err = "folder validation failed" if len(err) > 0: conn.append_error(err) return False, None else: return True, authz_ctx else: return True, None except BaseException: traceback.print_exc() conn.append_error("exception occurred") return False, None finally: shutil.rmtree(tmp_dir)
[docs] def upload_folder(self, conn: Connection, args: List[str]): folder_name = args[1] zip_b64str = args[2] folder_path = os.path.join(self.upload_dir, folder_name) if os.path.exists(folder_path): shutil.rmtree(folder_path) data_bytes = b64str_to_bytes(zip_b64str) unzip_all_from_bytes(data_bytes, self.upload_dir) conn.set_prop("upload_folder_path", folder_path) conn.append_string("Created folder {}".format(folder_path))
[docs] def submit_job(self, conn: Connection, args: List[str]): folder_name = args[1] data_bytes = conn.get_prop(ConnProps.JOB_DATA) engine = conn.app_ctx try: with engine.new_context() as fl_ctx: job_validator = JobMetaValidator() valid, error, meta = job_validator.validate(folder_name, data_bytes) if not valid: conn.append_error(error) return job_def_manager = engine.job_def_manager if not isinstance(job_def_manager, JobDefManagerSpec): raise TypeError( f"job_def_manager in engine is not of type JobDefManagerSpec, but got {type(job_def_manager)}" ) meta = job_def_manager.create(meta, data_bytes, fl_ctx) conn.append_string("Submitted job: {}".format(meta.get(JobMetaKey.JOB_ID))) except Exception as e: conn.append_error("Exception occurred trying to submit job: " + str(e)) return conn.append_success("")
[docs] def download_job(self, conn: Connection, args: List[str]): if len(args) != 2: conn.append_error("syntax error: job ID required") return job_id = args[1] engine = conn.app_ctx try: job_def_manager = engine.job_def_manager if not isinstance(job_def_manager, JobDefManagerSpec): raise TypeError( f"job_def_manager in engine is not of type JobDefManagerSpec, but got {type(job_def_manager)}" ) with engine.new_context() as fl_ctx: job_data = job_def_manager.get_job_data(job_id, fl_ctx) size = get_size(job_data, seen=None) if size > MAX_DOWNLOAD_JOB__SIZE: conn.append_string(ConnProps.DOWNLOAD_JOB_URL + self.download_url + job_id) return self._unzip_data(job_data, job_id) except Exception as e: conn.append_error("Exception occurred trying to get job from store: " + str(e)) return try: data = zip_directory_to_bytes(self.download_dir, job_id) b64str = bytes_to_b64str(data) conn.append_string(b64str) except BaseException: traceback.print_exc() conn.append_error("Exception occurred during attempt to zip data to send for job: {}".format(job_id))
def _unzip_data(self, job_data, job_id): job_id_dir = os.path.join(self.download_dir, job_id) if os.path.exists(job_id_dir): shutil.rmtree(job_id_dir) os.mkdir(job_id_dir) data_bytes = job_data[JobDataKey.JOB_DATA.value] job_dir = os.path.join(job_id_dir, "job") os.mkdir(job_dir) unzip_all_from_bytes(data_bytes, job_dir) workspace_bytes = job_data[JobDataKey.WORKSPACE_DATA.value] workspace_dir = os.path.join(job_id_dir, "workspace") os.mkdir(workspace_dir) if workspace_bytes is not None: unzip_all_from_bytes(workspace_bytes, workspace_dir)
[docs] def info(self, conn: Connection, args: List[str]): conn.append_string("Server Upload Destination: {}".format(self.upload_dir)) conn.append_string("Server Download Source: {}".format(self.download_dir))
def _authorize_submission(self, conn: Connection, args: List[str]): folder_name = args[1] zip_b64str = args[2] data_bytes = convert_legacy_zip(b64str_to_bytes(zip_b64str)) conn.set_prop(ConnProps.JOB_DATA, data_bytes) meta_file = f"{folder_name}/{META_FILE}" with ZipFile(BytesIO(data_bytes), "r") as zf: meta_data = zf.read(meta_file) meta = json.loads(meta_data) conn.set_prop(ConnProps.JOB_META, meta) return self.authorize_job_meta(conn, meta, [Action.TRAIN])