# Source code for nvflare.private.fed.utils.numproto

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NumPy ndarray to protobuf serialization and deserialization."""

from io import BytesIO

import numpy as np

from nvflare.private.fed.protos.federated_pb2 import NDArray


def ndarray_to_proto(nda: np.ndarray) -> NDArray:
    """Serializes a numpy array into an NDArray protobuf message.

    The array is encoded in NumPy's ``.npy`` format with pickling disabled,
    so object-dtype arrays are rejected by ``np.save`` rather than pickled.

    Args:
        nda (np.ndarray): numpy array to serialize.

    Returns:
        Returns an NDArray protobuf message whose ``ndarray`` field holds the
        ``.npy``-encoded bytes.
    """
    with BytesIO() as buffer:
        np.save(buffer, nda, allow_pickle=False)
        # Grab the encoded payload before the buffer is closed.
        payload = buffer.getvalue()
    return NDArray(ndarray=payload)
def proto_to_ndarray(nda_proto: NDArray) -> np.ndarray:
    """Deserializes an NDArray protobuf message into a numpy array.

    The message's ``ndarray`` field is decoded as ``.npy`` data with pickling
    disabled, so payloads that would require unpickling are refused by
    ``np.load``.

    Args:
        nda_proto (NDArray): NDArray protobuf message to deserialize.

    Returns:
        Returns a numpy.ndarray.
    """
    return np.load(BytesIO(nda_proto.ndarray), allow_pickle=False)
def bytes_to_proto(data: bytes) -> NDArray:
    """Wraps raw bytes in an NDArray protobuf message.

    The bytes are stored unchanged in the message's ``ndarray`` field; no
    NumPy encoding is applied.

    Args:
        data (bytes): raw bytes to wrap.

    Returns:
        Returns an NDArray protobuf message.

    Raises:
        TypeError: if ``data`` is not a ``bytes`` instance.
    """
    if isinstance(data, bytes):
        return NDArray(ndarray=data)
    raise TypeError(f"data must be bytes but got {type(data)}")
def proto_to_bytes(nda_proto: NDArray) -> bytes:
    """Extracts the raw bytes stored in an NDArray protobuf message.

    This is the inverse of ``bytes_to_proto``: the ``ndarray`` field is
    returned as-is, without any NumPy decoding.

    Args:
        nda_proto (NDArray): NDArray protobuf message to read from.

    Returns:
        Returns the bytes held in the message's ``ndarray`` field.
    """
    # The original routed the field through BytesIO(...).read(), which only
    # produced a redundant copy. bytes() returns an existing bytes object
    # unchanged and guarantees a bytes result otherwise.
    return bytes(nda_proto.ndarray)