Source code for nvflare.apis.utils.fl_context_utils

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging

from nvflare.apis.fl_constant import FLContextKey, NonSerializableKeys
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.fuel.sec.audit import AuditService
from nvflare.fuel.utils import fobs
from nvflare.security.logging import secure_format_exception

# Module-level logger; get_serializable_data uses it to warn about discarded props.
logger = logging.getLogger("fl_context_utils")


def get_serializable_data(fl_ctx: FLContext):
    """Build a new FLContext holding only the props of ``fl_ctx`` that serialize.

    Keys listed in ``NonSerializableKeys.KEYS`` are skipped outright. Every other
    value is trial-serialized with ``fobs.dumps``; a value that raises is left out
    of the result and a warning (prefixed via ``generate_log_message``) is logged.

    Args:
        fl_ctx: the context whose ``props`` are filtered.

    Returns:
        A new FLContext whose ``props`` hold the surviving key/value pairs. The
        values are the same objects as in ``fl_ctx`` (they are not copied).
    """
    serializable_ctx = FLContext()

    for key, value in fl_ctx.props.items():
        if key in NonSerializableKeys.KEYS:
            continue

        try:
            # The encoded result is thrown away; only success or failure matters.
            fobs.dumps(value)
            serializable_ctx.props[key] = value
        except Exception as ex:
            reason = (
                f"Object in FLContext with key {key} and type {type(value)} "
                f"is not serializable (discarded): {secure_format_exception(ex)}"
            )
            logger.warning(generate_log_message(fl_ctx, reason))

    return serializable_ctx
def gen_new_peer_ctx(fl_ctx: FLContext, need_deep_copy=False):
    """Create a context carrying the serializable public props of ``fl_ctx``.

    The public props are read from ``fl_ctx``, optionally deep-copied, installed
    on a temporary FLContext, and then filtered by ``get_serializable_data``.

    Args:
        fl_ctx: the context whose public props are used.
        need_deep_copy: when true, the public props are deep-copied before being
            placed on the temporary context.

    Returns:
        The FLContext returned by ``get_serializable_data``.
    """
    shared_props = fl_ctx.get_all_public_props()
    public_props = copy.deepcopy(shared_props) if need_deep_copy else shared_props

    staging_ctx = FLContext()
    staging_ctx.set_public_props(public_props)
    return get_serializable_data(staging_ctx)
def generate_log_message(fl_ctx: FLContext, msg: str):
    """Prefix ``msg`` with a bracketed summary of the context it was logged in.

    The result has the form ``[k1=v1, k2=v2, ...]: msg``. Fields appear in a fixed
    order (identity, run, wf, peer, peer_run, peer_rc, task_name, task_id) and a
    field is omitted when it was not collected.

    Args:
        fl_ctx: the context to describe; when falsy, ``msg`` is returned unchanged.
        msg: the message text to prefix.

    Returns:
        The prefixed message, or ``msg`` itself when ``fl_ctx`` is falsy.

    Raises:
        TypeError: if ``fl_ctx`` has a truthy peer context that is not an FLContext.
    """
    if not fl_ctx:
        return msg

    # identity and run are always present; a missing job id is shown as "?".
    fields = {
        "identity": fl_ctx.get_identity_name(),
        "run": fl_ctx.get_job_id() or "?",
    }

    task_name = fl_ctx.get_prop(FLContextKey.TASK_NAME, None)
    task_id = fl_ctx.get_prop(FLContextKey.TASK_ID, None)
    if task_name:
        fields["task_name"] = task_name
    if task_id:
        fields["task_id"] = task_id

    # Unlike the task fields, the workflow is kept for any value other than None.
    workflow = fl_ctx.get_prop(FLContextKey.WORKFLOW, None)
    if workflow is not None:
        fields["wf"] = workflow

    peer_ctx = fl_ctx.get_peer_context()
    if peer_ctx:
        if not isinstance(peer_ctx, FLContext):
            raise TypeError(f"peer_ctx must be an instance of FLContext, but got {type(peer_ctx)}")
        fields["peer_run"] = peer_ctx.get_job_id() or "?"
        fields["peer"] = peer_ctx.get_identity_name() or "?"

    reply = fl_ctx.get_prop(FLContextKey.REPLY, None)
    if isinstance(reply, Shareable):
        fields["peer_rc"] = reply.get_return_code("OK")

    display_order = ("identity", "run", "wf", "peer", "peer_run", "peer_rc", "task_name", "task_id")
    rendered = [f"{name}={fields[name]!s}" for name in display_order if name in fields]
    return "[" + ", ".join(rendered) + "]: " + msg
def add_job_audit_event(fl_ctx: FLContext, ref: str = "", msg: str = "") -> str:
    """Record a job audit event described by the given context.

    The job id, effective job scope name, task name and task id are read from
    ``fl_ctx`` (the three props fall back to ``"?"``) and passed, together with
    ``ref`` and ``msg``, to ``AuditService.add_job_event``.

    Args:
        fl_ctx: the context supplying the job and task details.
        ref: forwarded unchanged as the ``ref`` argument.
        msg: forwarded unchanged as the ``msg`` argument.

    Returns:
        The value returned by ``AuditService.add_job_event``.
    """
    job_id = fl_ctx.get_job_id()
    scope_name = fl_ctx.get_prop(FLContextKey.EFFECTIVE_JOB_SCOPE_NAME, "?")
    task_name = fl_ctx.get_prop(FLContextKey.TASK_NAME, "?")
    task_id = fl_ctx.get_prop(FLContextKey.TASK_ID, "?")

    return AuditService.add_job_event(
        job_id=job_id,
        scope_name=scope_name,
        task_name=task_name,
        task_id=task_id,
        ref=ref,
        msg=msg,
    )