# Source code for nvflare.apis.shareable

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

from ..fuel.utils import fobs
from .fl_constant import ReservedKey, ReturnCode

class ReservedHeaderKey:
    """Reserved key names used by Shareable.

    HEADERS is the top-level Shareable key under which the header dict is
    stored. The remaining names are presumably used as keys inside that
    header dict (RC and PEER_PROPS are, via set_header/get_header) — verify
    the others against their callers.
    """

    # Aliases of the framework-wide reserved keys from fl_constant.
    RC = ReservedKey.RC
    COOKIE_JAR = ReservedKey.COOKIE_JAR
    TASK_NAME = ReservedKey.TASK_NAME
    TASK_ID = ReservedKey.TASK_ID
    WORKFLOW = ReservedKey.WORKFLOW
    AUDIT_EVENT_ID = ReservedKey.AUDIT_EVENT_ID

    # Keys defined locally as dunder-style strings.
    HEADERS = "__headers__"
    TOPIC = "__topic__"
    PEER_PROPS = "__peer_props__"
    REPLY_IS_LATE = "__reply_is_late__"
    CONTENT_TYPE = "__content_type__"
    TASK_OPERATOR = "__task_operator__"
    ERROR = "__error__"
class Shareable(dict):
    """The information communicated between server and client.

    Shareable is just a dict that can have any keys and values, defined by developers and users.
    It is recommended that keys are strings. Values must be serializable.

    Header information lives in a nested dict stored under ``ReservedHeaderKey.HEADERS``;
    use ``set_header`` / ``get_header`` (and the convenience wrappers) to access it.
    """

    def __init__(self):
        """Init the Shareable with an empty header dict."""
        super().__init__()
        self[ReservedHeaderKey.HEADERS] = {}

    def set_header(self, key: str, value):
        """Set one entry in the header dict.

        Args:
            key: header key.
            value: header value.

        Raises:
            ValueError: if the existing header object is truthy but not a dict.
        """
        header = self.get(ReservedHeaderKey.HEADERS, None)
        if not header:
            # Missing or empty/falsy header: install a fresh dict.
            header = {}
            self[ReservedHeaderKey.HEADERS] = header
        elif not isinstance(header, dict):
            # Same validation and message as get_header, so a corrupt header
            # is reported consistently on both read and write.
            raise ValueError("header object must be a dict, but got {}".format(type(header)))
        header[key] = value

    def get_header(self, key: str, default=None):
        """Get one entry from the header dict.

        Args:
            key: header key.
            default: value returned when the header or the key is missing.

        Returns:
            The header value for ``key``, or ``default``.

        Raises:
            ValueError: if the header object is truthy but not a dict.
        """
        header = self.get(ReservedHeaderKey.HEADERS, None)
        if not header:
            return default
        if not isinstance(header, dict):
            raise ValueError("header object must be a dict, but got {}".format(type(header)))
        return header.get(key, default)

    # some convenience methods

    def get_return_code(self, default=ReturnCode.OK):
        """Return the return-code header, or ``default`` (ReturnCode.OK) if unset."""
        return self.get_header(ReservedHeaderKey.RC, default)

    def set_return_code(self, rc):
        """Store ``rc`` as the return-code header."""
        self.set_header(ReservedHeaderKey.RC, rc)

    def set_peer_props(self, props: dict):
        """Store the peer properties dict in the header."""
        self.set_header(ReservedHeaderKey.PEER_PROPS, props)

    def get_peer_props(self):
        """Return the peer properties stored in the header, or None if unset."""
        return self.get_header(ReservedHeaderKey.PEER_PROPS, None)

    def get_peer_prop(self, key: str, default=None):
        """Return one peer property.

        Args:
            key: property name.
            default: value returned when peer props are missing, are not a dict,
                or lack ``key``.

        Returns:
            The property value, or ``default``.
        """
        props = self.get_peer_props()
        if not isinstance(props, dict):
            return default
        return props.get(key, default)

    def to_bytes(self) -> bytes:
        """Serialize this Shareable into bytes using FOBS.

        Returns:
            object serialized in bytes.
        """
        return fobs.dumps(self)

    @classmethod
    def from_bytes(cls, data: bytes):
        """Deserialize bytes (e.g. produced by ``to_bytes``) using FOBS.

        Args:
            data: a bytes object

        Returns:
            the object loaded by FOBS from data. No check against ``cls`` is
            performed here; the result is whatever was serialized.
        """
        return fobs.loads(data)
# some convenience functions
def make_reply(rc, headers=None) -> Shareable:
    """Build a new Shareable that carries a return code and optional headers.

    Args:
        rc: return code stored via ``set_return_code``.
        headers: optional dict of additional header entries. Anything falsy or
            not a dict is ignored.

    Returns:
        The newly created Shareable.
    """
    reply = Shareable()
    reply.set_return_code(rc)
    # Nothing more to add unless we received a non-empty dict.
    if not headers or not isinstance(headers, dict):
        return reply
    for name, val in headers.items():
        reply.set_header(name, val)
    return reply
def make_copy(source: Shareable) -> Shareable:
    """Make a copy from the source.

    The content (non-header entries) is copied shallowly: the new object holds
    the same value objects as ``source``. The header dict is deep-copied, so
    header changes on the copy do not affect ``source``.

    Args:
        source: the Shareable to copy.

    Returns:
        The object produced by ``copy.copy(source)``, with its own header dict.

    Raises:
        TypeError: if ``source`` is not a Shareable.
    """
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if not isinstance(source, Shareable):
        raise TypeError("source must be a Shareable, but got {}".format(type(source)))
    c = copy.copy(source)
    headers = source.get(ReservedHeaderKey.HEADERS, None)
    # A missing or empty header becomes a fresh empty dict on the copy.
    c[ReservedHeaderKey.HEADERS] = copy.deepcopy(headers) if headers else {}
    return c