# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import struct
from .checksum import Checksum
# Wire layout of a chunk header: big-endian, one unsigned byte followed by two
# unsigned 32-bit integers (9 bytes, no padding). For a data header the two
# integers are the sequence number and the payload size; the end header is
# built with 0 and the stream checksum instead (see Sender.close).
HEADER_STRUCT = struct.Struct(">BII") # marker(1), seq_num(4), size(4)
# Number of bytes occupied by one packed header.
HEADER_LEN = HEADER_STRUCT.size
# Marker byte of a header that announces a chunk of payload bytes.
MARKER_DATA = 101
# Marker byte of the final header, which terminates the stream.
MARKER_END = 102
# NOTE(review): not referenced anywhere in the visible code - presumably the
# upper bound (1 MiB) for the payload size of one chunk; confirm where, if
# anywhere, it is enforced.
MAX_CHUNK_SIZE = 1024 * 1024
def get_slice(buf, start: int, length: int) -> memoryview:
    """Return a view of ``length`` bytes of ``buf`` beginning at ``start``.

    The result is a ``memoryview`` slice, so no bytes are copied. Normal
    slicing rules apply: if ``start + length`` runs past the end of ``buf``
    the view is simply shorter than ``length``.

    Args:
        buf: any object supporting the buffer protocol (e.g. bytes, bytearray).
        start: index of the first byte of the slice.
        length: number of bytes requested.

    Returns:
        A memoryview over ``buf[start : start + length]``.
    """
    view = memoryview(buf)
    return view[start : start + length]
class ChunkState:
    """Progress of the chunk that is currently being received.

    A chunk is a fixed-size header followed by ``header.size`` payload bytes.
    Both may arrive split over several buffers, so this object remembers how
    much of the header and of the payload has been seen so far.
    """

    def __init__(self, expect_seq=1):
        # Raw header bytes collected so far (complete at HEADER_LEN bytes).
        self.header_bytes = bytearray()
        # Parsed header; stays None until unpack_header() has run.
        self.header = None
        # Number of payload bytes of this chunk consumed so far.
        self.received = 0
        # Sequence number a data header of this chunk is expected to carry.
        self.expect_seq = expect_seq

    def __str__(self):
        d = {
            "header": f"{self.header}",
            "header_bytes": f"{self.header_bytes}",
            "received": self.received,
            "expect_seq": self.expect_seq,
        }
        return f"{d}"

    def unpack_header(self):
        """Parse the collected header bytes into ``self.header``.

        Must be called once ``header_bytes`` holds exactly HEADER_LEN bytes.
        The three unpacked numbers are handed to ``Header`` in the same order
        the sender uses to build a header.

        Raises:
            RuntimeError: if the header bytes are incomplete, or a data header
                carries a sequence number other than ``expect_seq``.
        """
        if len(self.header_bytes) != HEADER_LEN:
            raise RuntimeError(f"header needs {HEADER_LEN} bytes but got {len(self.header_bytes)}")

        marker, num1, num2 = HEADER_STRUCT.unpack(self.header_bytes)
        self.header = Header(marker, num1, num2)

        # Only data headers are numbered; the end header is sent with 0.
        if marker == MARKER_DATA and self.header.seq != self.expect_seq:
            raise RuntimeError(
                f"Protocol Error: received seq {self.header.seq} does not match expected seq {self.expect_seq}"
            )

    def is_last(self):
        """Return True if the parsed header is the end-of-stream header."""
        return self.header is not None and self.header.marker == MARKER_END
class Receiver:
    """Reassembles a chunked byte stream and verifies its checksum.

    Buffers are fed in through receive(); every piece of payload found in them
    is added to the running checksum and reported to ``receive_data_func``.
    """

    def __init__(self, receive_data_func):
        """
        Args:
            receive_data_func: callable invoked as ``f(data, start, length)``
                for each run of payload bytes located in a received buffer;
                may be falsy (e.g. None) to only verify the checksum.
        """
        self.receive_data_func = receive_data_func
        self.checksum = Checksum()
        self.current_state = ChunkState()
        self.done = False

    def receive(self, data) -> bool:
        """Process one received buffer.

        Args:
            data: the bytes received; may hold partial and/or several chunks.

        Returns:
            True if the end-of-stream header has been reached, else False.

        Raises:
            RuntimeError: if called again after the stream has ended, or if
                the checksum carried by the end header does not match the
                checksum computed over the received payload.
        """
        if self.done:
            raise RuntimeError("this receiver is already done")

        s = chunk_it(self.current_state, data, 0, self._process_chunk)
        self.current_state = s

        # is_last() may yield a non-bool falsy value while no header has been
        # parsed yet; normalize so the annotated return type holds.
        done = bool(s.is_last())
        if done:
            # Mark as finished before verifying, so that a receiver with a
            # bad checksum cannot be fed any further.
            self.done = True

            # compare checksum
            expected_checksum = self.checksum.result()
            if expected_checksum != s.header.checksum:
                raise RuntimeError(f"checksum mismatch: expect {expected_checksum} but received {s.header.checksum}")
        return done

    def _process_chunk(self, c: "ChunkState", data, start: int, length: int):
        """Callback for chunk_it: account for ``length`` payload bytes at ``start``."""
        self.checksum.update(get_slice(data, start, length))
        if self.receive_data_func:
            self.receive_data_func(data, start, length)
class Sender:
    """Writes data as a stream of chunks terminated by a checksum header.

    Each send() emits a data header followed by the payload; close() emits the
    end header carrying the checksum of all payload bytes sent.
    """

    def __init__(self, send_data_func):
        """
        Args:
            send_data_func: callable invoked with each bytes object that is
                to be put on the wire (headers and payloads alike).
        """
        self.send_data_func = send_data_func
        self.checksum = Checksum()
        # Sequence number for the next data chunk; numbering starts at 1.
        self.next_seq = 1
        self.closed = False

    def send(self, data):
        """Send ``data`` as one chunk.

        Args:
            data: the payload bytes; None is treated as an empty payload.

        Raises:
            RuntimeError: if the sender has already been closed.
        """
        if self.closed:
            raise RuntimeError("this sender is already closed")

        if data is None:
            data = b""

        header = Header(MARKER_DATA, self.next_seq, len(data))
        self.next_seq += 1
        self.checksum.update(data)

        # The header and the payload are handed over in two separate calls.
        header_bytes = header.to_bytes()
        self.send_data_func(header_bytes)
        self.send_data_func(data)

    def close(self):
        """Terminate the stream by sending the end header with the checksum.

        Raises:
            RuntimeError: if the sender has already been closed.
        """
        if self.closed:
            raise RuntimeError("this sender is already closed")
        self.closed = True

        # The end header is built with 0 in place of a sequence number and
        # the checksum in place of a size.
        cs = self.checksum.result()
        header = Header(MARKER_END, 0, cs)
        header_bytes = header.to_bytes()
        self.send_data_func(header_bytes)
def chunk_it(c: "ChunkState", data, cursor: int, process_chunk_func) -> "ChunkState":
    """Walk through ``data`` and split it into chunk headers and payloads.

    Header bytes are accumulated into the chunk state until a full header is
    available and parsed via ``c.unpack_header()``; payload bytes are reported
    to ``process_chunk_func``. A buffer may end in the middle of a header or
    payload, and may contain several chunks. The chunks of one buffer are
    handled in a loop rather than recursively, so a buffer holding very many
    small chunks cannot exhaust the interpreter's recursion limit.

    Args:
        c: state of the chunk in progress when ``data`` arrived.
        data: the received bytes (``bytes`` or ``bytearray``).
        cursor: index in ``data`` at which to start.
        process_chunk_func: callable invoked as ``f(state, data, start, length)``
            for every run of payload bytes found.

    Returns:
        The state to continue with when the next buffer arrives: the same
        state if its chunk is still incomplete (or is the end header), or a
        fresh state expecting the next sequence number.

    Raises:
        ValueError: if ``data`` is not bytes/bytearray, or ``cursor`` lies
            outside of non-empty ``data``.
    """
    if not isinstance(data, (bytearray, bytes)):
        raise ValueError(f"can only chunk bytes data but got {type(data)}")

    total_len = len(data)
    if total_len <= 0:
        return c

    if cursor < 0 or cursor >= total_len:
        raise ValueError(f"cursor {cursor} is out of data range [0, {total_len - 1}]")

    while True:
        remaining = total_len - cursor

        num_bytes_needed = HEADER_LEN - len(c.header_bytes)
        if num_bytes_needed > 0:
            # header not completed yet
            if remaining < num_bytes_needed:
                # not enough bytes to finish the header: stash what is there
                c.header_bytes.extend(get_slice(data, cursor, remaining))
                return c

            c.header_bytes.extend(get_slice(data, cursor, num_bytes_needed))
            cursor += num_bytes_needed
            remaining -= num_bytes_needed
            c.unpack_header()  # header bytes are ready

        if remaining == 0 or c.is_last():
            return c

        lack = c.header.size - c.received
        if remaining <= lack:
            # everything left in the buffer belongs to the current chunk
            c.received += remaining
            process_chunk_func(c, data, cursor, remaining)
            if c.received == c.header.size:
                # this chunk is completed: start a new chunk
                return ChunkState(c.header.seq + 1)
            # this chunk is not done
            return c

        # Part of the remaining data finishes this chunk; the rest belongs to
        # the next one. Here lack < remaining, so the cursor stays in range.
        c.received += lack
        process_chunk_func(c, data, cursor, lack)
        cursor += lack
        c = ChunkState(c.header.seq + 1)