# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as pb2
from nvflare.apis.fl_context import FLContext
from nvflare.app_opt.xgboost.histogram_based_v2.adaptors.xgb_adaptor import XGBServerAdaptor
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.grpc_client import GrpcClient
from nvflare.fuel.f3.drivers.net_utils import get_open_tcp_port
[docs]class GrpcServerAdaptor(XGBServerAdaptor):
def __init__(
self,
int_client_grpc_options=None,
xgb_server_ready_timeout=Constant.XGB_SERVER_READY_TIMEOUT,
in_process=True,
):
XGBServerAdaptor.__init__(self, in_process)
self.int_client_grpc_options = int_client_grpc_options
self.xgb_server_ready_timeout = xgb_server_ready_timeout
self.in_process = in_process
self.internal_xgb_client = None
self._server_stopped = False
self._exit_code = 0
def _start_server(self, addr: str, port: int, world_size: int, fl_ctx: FLContext):
runner_ctx = {
Constant.RUNNER_CTX_SERVER_ADDR: addr,
Constant.RUNNER_CTX_WORLD_SIZE: world_size,
Constant.RUNNER_CTX_PORT: port,
}
self.start_runner(runner_ctx, fl_ctx)
def _stop_server(self):
self._server_stopped = True
self.stop_runner()
def _is_stopped(self) -> (bool, int):
runner_stopped, ec = self.is_runner_stopped()
if runner_stopped:
return runner_stopped, ec
if self._server_stopped:
return True, self._exit_code
return False, 0
[docs] def start(self, fl_ctx: FLContext):
# we dynamically create server address on localhost
port = get_open_tcp_port(resources={})
if not port:
raise RuntimeError("failed to get a port for XGB server")
server_addr = f"127.0.0.1:{port}"
self._start_server(server_addr, port, self.world_size, fl_ctx)
# start XGB client
self.internal_xgb_client = GrpcClient(server_addr, self.int_client_grpc_options)
self.internal_xgb_client.start(ready_timeout=self.xgb_server_ready_timeout)
[docs] def stop(self, fl_ctx: FLContext):
client = self.internal_xgb_client
self.internal_xgb_client = None
if client:
self.log_info(fl_ctx, "Stopping internal XGB client")
client.stop()
self._stop_server()
[docs] def all_gather(self, rank: int, seq: int, send_buf: bytes, fl_ctx: FLContext) -> bytes:
result = self.internal_xgb_client.send_allgather(seq_num=seq, rank=rank, data=send_buf)
if isinstance(result, pb2.AllgatherReply):
return result.receive_buffer
else:
raise RuntimeError(f"bad result from XGB server: expect AllgatherReply but got {type(result)}")
[docs] def all_gather_v(self, rank: int, seq: int, send_buf: bytes, fl_ctx: FLContext) -> bytes:
result = self.internal_xgb_client.send_allgatherv(seq_num=seq, rank=rank, data=send_buf)
if isinstance(result, pb2.AllgatherVReply):
return result.receive_buffer
else:
raise RuntimeError(f"bad result from XGB server: expect AllgatherVReply but got {type(result)}")
[docs] def all_reduce(
self,
rank: int,
seq: int,
data_type: int,
reduce_op: int,
send_buf: bytes,
fl_ctx: FLContext,
) -> bytes:
result = self.internal_xgb_client.send_allreduce(
seq_num=seq,
rank=rank,
data=send_buf,
data_type=data_type,
reduce_op=reduce_op,
)
if isinstance(result, pb2.AllreduceReply):
return result.receive_buffer
else:
raise RuntimeError(f"bad result from XGB server: expect AllreduceReply but got {type(result)}")
[docs] def broadcast(self, rank: int, seq: int, root: int, send_buf: bytes, fl_ctx: FLContext) -> bytes:
result = self.internal_xgb_client.send_broadcast(seq_num=seq, rank=rank, data=send_buf, root=root)
if isinstance(result, pb2.BroadcastReply):
return result.receive_buffer
else:
raise RuntimeError(f"bad result from XGB server: expect BroadcastReply but got {type(result)}")