# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import nvflare.fuel.utils.fobs as fobs
from nvflare.fuel.f3.cellnet.defs import Encoding, MessageHeaderKey
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.streaming.stream_const import StreamHeaderKey
from nvflare.fuel.utils.buffer_list import BufferList
# Short abbreviations for the cell-addressing header keys
# (origin, destination, from-cell, to-cell).
# NOTE(review): the consumers of this table are not visible here; presumably it
# is used to render compact header summaries — confirm at the call sites.
cell_mapping = dict(
    O=MessageHeaderKey.ORIGIN,
    D=MessageHeaderKey.DESTINATION,
    F=MessageHeaderKey.FROM_CELL,
    T=MessageHeaderKey.TO_CELL,
)

# Short abbreviations for the message channel/topic header keys and the
# streaming channel/topic/sequence header keys.
msg_mapping = dict(
    CH=MessageHeaderKey.CHANNEL,
    TP=MessageHeaderKey.TOPIC,
    SCH=StreamHeaderKey.CHANNEL,
    STP=StreamHeaderKey.TOPIC,
    SEQ=StreamHeaderKey.SEQUENCE,
)
def make_reply(rc: str, error: str = "", body=None) -> Message:
    """Build a reply Message carrying a return code.

    Args:
        rc: return code, stored under ``MessageHeaderKey.RETURN_CODE``.
        error: optional error text; stored under ``MessageHeaderKey.ERROR``
            only when it is non-empty.
        body: payload for the reply message.

    Returns:
        A new Message built from those headers and ``body`` as payload.
    """
    reply_headers = {MessageHeaderKey.RETURN_CODE: rc}
    if error:
        reply_headers.update({MessageHeaderKey.ERROR: error})
    return Message(reply_headers, payload=body)
def shorten_string(string):
    """Abbreviate a string for compact display.

    Strings of at most 8 characters are returned unchanged. Longer strings
    are replaced by ``":"`` followed by their last 7 characters, so the
    result is never longer than 8 characters.
    """
    if len(string) <= 8:
        return string
    return ":" + string[-7:]
def buffer_len(buffer: Any) -> int:
    """Return the size of a payload buffer.

    Args:
        buffer: the payload. It may be None or empty, a ``memoryview``, a
            list (treated as a collection of buffers), or any other object
            supporting ``len()`` such as bytes or bytearray.

    Returns:
        0 for None or an empty buffer.
        ``nbytes`` for a memoryview.
        The size reported by ``BufferList.get_size()`` for a list.
        ``len(buffer)`` otherwise (the byte count for bytes/bytearray).
    """
    if isinstance(buffer, memoryview):
        # len() of a memoryview counts elements along the first dimension,
        # not bytes. It under-reports for non-byte formats (e.g. a view over
        # an int/float array) and raises for 0-dim views. nbytes is the
        # actual byte count.
        return buffer.nbytes
    if not buffer:
        return 0
    if isinstance(buffer, list):
        # A list is a collection of buffers; BufferList computes the total.
        return BufferList(buffer).get_size()
    return len(buffer)
def shorten_fqcn(fqcn):
    """Shorten every dot-separated component of ``fqcn``.

    Each component is passed through ``shorten_string`` and the results are
    re-joined with ".".
    """
    return ".".join(map(shorten_string, fqcn.split(".")))
def encode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCODING):
    """Prepare a message payload for sending and record its encoding and size.

    When the message has no value under ``encoding_key``, an encoding is
    chosen from the payload and written to that header:

    - ``Encoding.NONE`` when the payload is None;
    - ``Encoding.BYTES`` for bytes, bytearray or memoryview;
    - ``Encoding.FOBS`` otherwise, in which case the payload is replaced by
      ``fobs.dumps(payload, buffer_list=True)``.

    An encoding already present in the header is trusted as-is and the
    payload is left untouched. In every case the payload size (as computed
    by ``buffer_len``) is stored under ``MessageHeaderKey.PAYLOAD_LEN``.

    Args:
        message: the message whose payload and headers are updated in place.
        encoding_key: header key under which the encoding is read/stored.
    """
    if not message.get_header(encoding_key):
        payload = message.payload
        if payload is None:
            chosen = Encoding.NONE
        elif isinstance(payload, (bytes, bytearray, memoryview)):
            chosen = Encoding.BYTES
        else:
            # Arbitrary objects are serialized with FOBS before sending.
            chosen = Encoding.FOBS
            message.payload = fobs.dumps(payload, buffer_list=True)
        message.set_header(encoding_key, chosen)

    message.set_header(MessageHeaderKey.PAYLOAD_LEN, buffer_len(message.payload))
def decode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCODING):
    """Restore a received payload according to its recorded encoding.

    The size of the payload as received (per ``buffer_len``) is always
    stored under ``MessageHeaderKey.PAYLOAD_LEN`` first. If no encoding is
    recorded under ``encoding_key``, nothing else happens. Otherwise:

    - ``Encoding.FOBS``: the payload is replaced by ``fobs.loads(payload)``;
    - ``Encoding.NONE``: the payload is set to None;
    - any other value: the payload is left as-is (assumed to be bytes).

    The encoding header is then removed from the message.

    Args:
        message: the message whose payload and headers are updated in place.
        encoding_key: header key holding the payload encoding.
    """
    message.set_header(MessageHeaderKey.PAYLOAD_LEN, buffer_len(message.payload))

    encoding = message.get_header(encoding_key)
    if not encoding:
        return

    if encoding == Encoding.FOBS:
        message.payload = fobs.loads(message.payload)
    elif encoding == Encoding.NONE:
        message.payload = None
    # Any other encoding is taken to be raw bytes and left untouched.

    message.remove_header(encoding_key)