# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from ..fuel.utils import fobs
from .fl_constant import ReservedKey, ReturnCode
class Shareable(dict):
    """The information communicated between server and client.

    Shareable is just a dict that can have any keys and values, defined by developers and users.
    It is recommended that keys are strings. Values must be serializable.

    Headers are kept in a separate dict stored under the ``ReservedHeaderKey.HEADERS`` key.
    """

    # NOTE(review): the methods below rely on ``ReservedHeaderKey`` and on
    # ``self.get_header`` / ``self.set_header``, none of which are defined in this
    # view of the file -- confirm they exist elsewhere in the module.

    def __init__(self):
        """Init the Shareable with an empty headers dict."""
        super().__init__()
        self[ReservedHeaderKey.HEADERS] = {}

    # some convenience methods

    def get_return_code(self, default=ReturnCode.OK):
        """Get the return code stored in the headers.

        Args:
            default: value returned when no return code header is present.

        Returns:
            the stored return code, or ``default``.
        """
        return self.get_header(ReservedHeaderKey.RC, default)

    def set_return_code(self, rc):
        """Store ``rc`` as the return code header.

        Args:
            rc: the return code to store.
        """
        self.set_header(ReservedHeaderKey.RC, rc)

    def add_cookie(self, name: str, data):
        """Add a cookie that is to be sent to the client and echoed back in response.

        This method is intended to be called by the Server side.

        Args:
            name: the name of the cookie
            data: the data of the cookie, which must be serializable
        """
        cookie_jar = self.get_cookie_jar()
        if not cookie_jar:
            # No jar yet: create one and attach it as the cookie-jar header before filling it.
            cookie_jar = {}
            self.set_header(key=ReservedHeaderKey.COOKIE_JAR, value=cookie_jar)
        cookie_jar[name] = data

    def get_cookie_jar(self):
        """Get the cookie jar header.

        Returns:
            the cookie jar, or None if no jar has been set.
        """
        return self.get_header(key=ReservedHeaderKey.COOKIE_JAR, default=None)

    def set_cookie_jar(self, jar):
        """Replace the cookie jar header with ``jar``.

        Args:
            jar: the cookie jar to store.
        """
        self.set_header(key=ReservedHeaderKey.COOKIE_JAR, value=jar)

    def get_cookie(self, name: str, default=None):
        """Get a single cookie from the cookie jar.

        Args:
            name: the name of the cookie
            default: value returned when there is no usable jar or no such cookie

        Returns:
            the cookie data, or ``default``.
        """
        jar = self.get_cookie_jar()
        # Same guard as get_peer_prop: a missing or non-dict jar yields the default
        # instead of failing on ``.get``.
        if not isinstance(jar, dict):
            return default
        return jar.get(name, default)

    def set_peer_props(self, props: dict):
        """Store the peer properties header.

        Args:
            props: dict of peer properties.
        """
        self.set_header(ReservedHeaderKey.PEER_PROPS, props)

    def get_peer_props(self):
        """Get the peer properties header.

        Returns:
            the peer properties, or None if not set.
        """
        return self.get_header(ReservedHeaderKey.PEER_PROPS, None)

    def get_peer_prop(self, key: str, default=None):
        """Get a single peer property.

        Args:
            key: the property name
            default: value returned when peer properties are missing/not a dict
                or do not contain ``key``

        Returns:
            the property value, or ``default``.
        """
        props = self.get_peer_props()
        if not isinstance(props, dict):
            return default
        return props.get(key, default)

    def to_bytes(self) -> bytes:
        """Serialize this Shareable into bytes using FOBS.

        Returns:
            object serialized in bytes.
        """
        return fobs.dumps(self)

    @classmethod
    def from_bytes(cls, data: bytes):
        """Deserialize bytes (e.g. produced by ``to_bytes``) using FOBS.

        Args:
            data: a bytes object

        Returns:
            an object loaded by FOBS from data (presumably a Shareable when ``data``
            came from ``to_bytes`` -- the result type is decided by FOBS, not ``cls``)
        """
        return fobs.loads(data)
# some convenience functions
def make_reply(rc, headers=None) -> Shareable:
    """Build a new Shareable that carries a return code and optional extra headers.

    Args:
        rc: the return code to store in the reply.
        headers: optional dict of additional header key/value pairs; it is ignored
            unless it is a non-empty dict.

    Returns:
        the newly created Shareable.
    """
    result = Shareable()
    result.set_return_code(rc)
    if not (headers and isinstance(headers, dict)):
        return result
    for name, value in headers.items():
        result.set_header(name, value)
    return result
def make_copy(source: Shareable) -> Shareable:
    """Make a copy from the source.

    The content (non-headers) is shallow-copied: the new instance refers to the same
    value objects as ``source``. Headers are deep-copied into the new instance, so
    modifying the copy's headers does not affect ``source``.

    Args:
        source: the Shareable to copy.

    Returns:
        a new instance of ``source``'s class with the same content and independent headers.

    Raises:
        TypeError: if ``source`` is not a Shareable.
    """
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``,
    # which would silently let non-Shareable input through.
    if not isinstance(source, Shareable):
        raise TypeError(f"source must be Shareable but got {type(source)}")
    c = copy.copy(source)
    headers = source.get(ReservedHeaderKey.HEADERS, None)
    # Deep-copy so nested header values are not shared with the source; a missing
    # or empty headers entry becomes a fresh empty dict.
    c[ReservedHeaderKey.HEADERS] = copy.deepcopy(headers) if headers else {}
    return c