# Source code for nvflare.fuel.f3.sfm.sfm_endpoint

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading
from typing import List, Optional

from nvflare.fuel.f3.endpoint import Endpoint
from nvflare.fuel.f3.sfm.sfm_conn import SfmConnection

# Fixed stream ID carried by packets sent before the handshake completes.
# SfmEndpoint also seeds its stream-ID counter with this value.
RESERVED_STREAM_ID = 0x10

# Cap on connections held per endpoint; SfmEndpoint.add_connection evicts the
# oldest connection(s) once the cap is reached.
MAX_CONN_PER_ENDPOINT = 1

# Logger for this module.
log = logging.getLogger(__name__)


class SfmEndpoint:
    """An endpoint wrapper to keep SFM internal data.

    Tracks the SFM connections attached to one ``Endpoint`` and hands out stream
    IDs for it. ``self.lock`` guards both ``connections`` and ``stream_id``.
    """

    def __init__(self, endpoint: Endpoint):
        self.endpoint = endpoint
        # Counter advanced by next_stream_id(); seeded with the pre-handshake ID.
        self.stream_id: int = RESERVED_STREAM_ID
        self.lock = threading.Lock()
        # Oldest connection first; add_connection() evicts from the front.
        self.connections: List[SfmConnection] = []

    def add_connection(self, sfm_conn: SfmConnection) -> None:
        """Attach ``sfm_conn`` to this endpoint.

        If the endpoint already holds ``MAX_CONN_PER_ENDPOINT`` connections, the
        oldest ones are detached and their underlying ``conn`` is closed.

        Evicted connections are removed from the list while holding the lock, but
        closed only after the lock is released. This has two effects:

        * A failing ``close()`` can no longer leave a dead connection registered
          or keep ``sfm_conn`` from being added.
        * A ``close()`` that calls back into this endpoint (e.g.
          ``remove_connection``) on the same thread cannot deadlock on the
          non-reentrant ``threading.Lock``.

        Args:
            sfm_conn: The connection to register.
        """
        evicted: List[SfmConnection] = []
        with self.lock:
            # The emptiness guard prevents pop(0) on an empty list when the limit <= 0.
            while self.connections and len(self.connections) >= MAX_CONN_PER_ENDPOINT:
                evicted.append(self.connections.pop(0))
            self.connections.append(sfm_conn)

        for first_conn in evicted:
            first_conn.conn.close()
            log.info(
                f"Connection {first_conn.get_name()} is evicted for {sfm_conn.get_name()} "
                f"from endpoint {self.endpoint.name} for exceeding limit {MAX_CONN_PER_ENDPOINT}"
            )

    def remove_connection(self, sfm_conn: SfmConnection) -> None:
        """Detach the connection whose ``get_name()`` matches ``sfm_conn``'s.

        This is a no-op (logged at debug level) if no such connection is registered.
        The emptiness check runs under the lock so it cannot race with concurrent
        add/remove calls.

        Args:
            sfm_conn: The connection to remove, matched by name.
        """
        name = sfm_conn.get_name()
        with self.lock:
            if not self.connections:
                log.debug(
                    f"Connection {sfm_conn.get_name()} is already removed. "
                    f"No connections for endpoint {self.endpoint.name}"
                )
                return

            found_index = next(
                (index for index, conn in enumerate(self.connections) if conn.get_name() == name),
                None,
            )
            if found_index is not None:
                self.connections.pop(found_index)
                log.debug(f"Connection {sfm_conn.get_name()} is removed from endpoint {self.endpoint.name}")
            else:
                log.debug(f"Connection {sfm_conn.get_name()} is already removed from endpoint {self.endpoint.name}")

    def get_connection(self, stream_id: int) -> Optional[SfmConnection]:
        """Pick a connection for ``stream_id`` by taking it modulo the connection count.

        The list is read under the lock. Otherwise a concurrent removal between the
        emptiness check and the indexing could raise ``ZeroDivisionError`` or
        ``IndexError``.

        Args:
            stream_id: Stream identifier used to select the connection.

        Returns:
            The selected connection, or ``None`` if the endpoint has none.
        """
        with self.lock:
            if not self.connections:
                return None
            return self.connections[stream_id % len(self.connections)]

    def next_stream_id(self) -> int:
        """Get next stream_id for the endpoint

        stream_id is used to assemble fragmented data
        """
        with self.lock:
            # Wraps within 16 bits (0..0xFFFF).
            # NOTE(review): after wrapping, this can return RESERVED_STREAM_ID
            # again; confirm that reuse is harmless once the handshake is done.
            self.stream_id = (self.stream_id + 1) & 0xFFFF
            return self.stream_id