Source code for nvflare.private.fed.utils.fed_utils

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import shutil
from logging.handlers import RotatingFileHandler
from multiprocessing.connection import Listener
from typing import List

from nvflare.apis.fl_constant import WorkspaceConstants
from nvflare.apis.fl_context import FLContext
from nvflare.fuel.hci.zip_utils import unzip_all_from_bytes
from nvflare.fuel.sec.security_content_service import LoadResult, SecurityContentService
from nvflare.fuel.utils import fobs
from nvflare.private.defs import SSLConstants
from nvflare.private.fed.protos.federated_pb2 import ModelData
from nvflare.private.fed.utils.numproto import bytes_to_proto


[docs]def shareable_to_modeldata(shareable, fl_ctx): # make_init_proto message model_data = ModelData() # create an empty message model_data.params["data"].CopyFrom(make_shareable_data(shareable)) context_data = make_context_data(fl_ctx) model_data.params["fl_context"].CopyFrom(context_data) return model_data
[docs]def make_shareable_data(shareable): return bytes_to_proto(shareable.to_bytes())
[docs]def make_context_data(fl_ctx): shared_fl_ctx = FLContext() shared_fl_ctx.set_public_props(fl_ctx.get_all_public_props()) props = fobs.dumps(shared_fl_ctx) context_data = bytes_to_proto(props) return context_data
[docs]def deploy_app(app_name, site_name, workspace, app_data): try: dest = os.path.join(workspace, WorkspaceConstants.APP_PREFIX + site_name) # Remove the previous deployed app. if os.path.exists(dest): shutil.rmtree(dest) if not os.path.exists(dest): os.makedirs(dest) unzip_all_from_bytes(app_data, dest) app_file = os.path.join(workspace, "fl_app.txt") if os.path.exists(app_file): os.remove(app_file) with open(app_file, "wt") as f: f.write(f"{app_name}") return True except: return False
[docs]def add_logfile_handler(log_file): root_logger = logging.getLogger() main_handler = root_logger.handlers[0] file_handler = RotatingFileHandler(log_file, maxBytes=20 * 1024 * 1024, backupCount=10) file_handler.setLevel(main_handler.level) file_handler.setFormatter(main_handler.formatter) root_logger.addHandler(file_handler)
[docs]def listen_command(listen_port, engine, execute_func, logger): conn = None listener = None try: address = ("localhost", listen_port) listener = Listener(address, authkey="client process secret password".encode()) conn = listener.accept() execute_func(conn, engine) except Exception as e: logger.exception(f"Could not create the listener for this process on port: {listen_port}: {e}.", exc_info=True) finally: if conn: conn.close() if listener: listener.close()
[docs]def secure_content_check(config: str, site_type: str) -> List[str]: """To check the security contents. Args: config (str): The fed_XXX config site_type (str): "server" or "client" Returns: A list of insecure content. """ insecure_list = [] data, sig = SecurityContentService.load_json(config) if sig != LoadResult.OK: insecure_list.append(config) sites_to_check = data["servers"] if site_type == "server" else [data["client"]] for site in sites_to_check: for filename in [SSLConstants.CERT, SSLConstants.PRIVATE_KEY, SSLConstants.ROOT_CERT]: content, sig = SecurityContentService.load_content(site.get(filename)) if sig != LoadResult.OK: insecure_list.append(site.get(filename)) if site_type == "server": if "authorization.json" in SecurityContentService.security_content_manager.signature: data, sig = SecurityContentService.load_json("authorization.json") if sig != LoadResult.OK: insecure_list.append("authorization.json") return insecure_list