# Source code for nvflare.private.fed.client.admin_msg_sender

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This is the FLAdmin Client to send the request message to the admin server."""

import logging
import threading
from multiprocessing.dummy import Pool as ThreadPool
from typing import List

import grpc

import nvflare.private.fed.protos.admin_pb2 as admin_msg
import nvflare.private.fed.protos.admin_pb2_grpc as admin_service
from nvflare.private.admin_defs import Message
from nvflare.private.fed.utils.messageproto import message_to_proto, proto_to_message

from .admin import Sender

# Module-level lock. It is not referenced by any code in this module;
# NOTE(review): presumably kept for importers of this module — verify before removing.
lock = threading.Lock()


class AdminMessageSender(Sender):
    """AdminMessageSender to send the request message to the admin server.

    Every configured server is contacted over its own short-lived gRPC channel.
    Only the process whose ``rank`` is 0 talks to the servers; on other ranks the
    send methods are no-ops and ``retrieve_requests`` returns an empty list.
    Communication is best-effort: failures are logged, not raised.
    """

    # Class-level logger so failures are reported instead of silently discarded.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        client_name,
        root_cert=None,
        ssl_cert=None,
        private_key=None,
        server_args=None,
        secure=False,
        is_multi_gpu=False,
        rank=0,
    ):
        """To init the AdminMessageSender.

        Args:
            client_name: client name; sent with every request and used as the
                ``x-custom-token`` call metadata on secure channels
            root_cert: path to the root certificate file (read when ``secure`` is True)
            ssl_cert: path to the SSL certificate file (read when ``secure`` is True)
            private_key: path to the private key file (read when ``secure`` is True)
            server_args: mapping of task name -> keyword arguments passed to
                ``grpc.secure_channel`` / ``grpc.insecure_channel``; ``None`` means no servers
            secure: True to use an SSL channel with client credentials
            is_multi_gpu: True/False
            rank: local process rank; only rank 0 communicates with the servers
        """
        self.client_name = client_name
        self.root_cert = root_cert
        self.ssl_cert = ssl_cert
        self.private_key = private_key
        self.secure = secure
        # A missing mapping is treated as "no servers" instead of failing on len(None).
        self.servers = server_args if server_args is not None else {}
        self.multi_gpu = is_multi_gpu
        self.rank = rank
        # ThreadPool rejects a size below 1, so always keep at least one worker.
        self.pool = ThreadPool(max(1, len(self.servers)))

    def send_reply(self, message: Message):
        """Send a reply message to every configured server.

        Args:
            message: reply message
        """
        if self.rank != 0:
            return
        for taskname in tuple(self.servers):
            self._send_client_reply(message, taskname)

    def _make_reply(self, message):
        """Wrap ``message`` in an ``admin_msg.Reply`` tagged with this client's name."""
        reply = admin_msg.Reply()
        reply.client_name = self.client_name
        reply.message.CopyFrom(message_to_proto(message))
        return reply

    def _send_client_reply(self, message, taskname):
        """Send ``message`` through SendReply to the server registered under ``taskname``."""
        try:
            with self._set_up_channel(self.servers[taskname]) as channel:
                stub = admin_service.AdminCommunicatingStub(channel)
                stub.SendReply(self._make_reply(message))
        except Exception as e:
            # Best-effort delivery: report the failure and continue with other servers.
            self._logger.warning("Failed to send reply to server %s: %s", taskname, e)

    def retrieve_requests(self) -> List[Message]:
        """Send the message to retrieve pending requests from the Server.

        All configured servers are queried in parallel on the thread pool.

        Returns:
            list of messages; empty when ``rank`` is not 0 or nothing was retrieved.
        """
        messages = []
        if self.rank == 0:
            for items in self.pool.map(self._retrieve_client_requests, tuple(self.servers)):
                messages.extend(items)
        return messages

    def _retrieve_client_requests(self, taskname):
        """Fetch pending requests from the server registered under ``taskname``.

        Returns:
            list of converted messages. If an error occurs part-way, the messages
            converted before the failure are still returned.
        """
        message_list = []
        try:
            with self._set_up_channel(self.servers[taskname]) as channel:
                stub = admin_service.AdminCommunicatingStub(channel)
                client = admin_msg.Client()
                client.client_name = self.client_name
                response = stub.Retrieve(client)
                for proto in response.message:
                    message_list.append(proto_to_message(proto))
        except Exception as e:
            # This is polled repeatedly; an unreachable server is not exceptional,
            # so keep the log at debug level.
            self._logger.debug("Failed to retrieve requests from server %s: %s", taskname, e)
        return message_list

    def send_result(self, message: Message):
        """Send the processor results to every configured server.

        Args:
            message: result message
        """
        if self.rank != 0:
            return
        for taskname in tuple(self.servers):
            self._send_client_result(message, taskname)

    def _send_client_result(self, message, taskname):
        """Send ``message`` through SendResult to the server registered under ``taskname``."""
        try:
            with self._set_up_channel(self.servers[taskname]) as channel:
                stub = admin_service.AdminCommunicatingStub(channel)
                stub.SendResult(self._make_reply(message))
        except Exception as e:
            # Best-effort delivery: report the failure and continue with other servers.
            self._logger.warning("Failed to send result to server %s: %s", taskname, e)

    def _set_up_channel(self, channel_dict):
        """Connect client to the server.

        Args:
            channel_dict: keyword arguments for the grpc channel constructor

        Returns:
            an initialised grpc channel (secure when ``self.secure`` is True)
        """
        if not self.secure:
            return grpc.insecure_channel(**channel_dict)

        # Certificate files are re-read on every call.
        with open(self.root_cert, "rb") as f:
            trusted_certs = f.read()
        with open(self.private_key, "rb") as f:
            private_key = f.read()
        with open(self.ssl_cert, "rb") as f:
            certificate_chain = f.read()

        # Attach the client name to each call as "x-custom-token" metadata.
        call_credentials = grpc.metadata_call_credentials(
            lambda context, callback: callback((("x-custom-token", self.client_name),), None)
        )
        credentials = grpc.ssl_channel_credentials(
            certificate_chain=certificate_chain, private_key=private_key, root_certificates=trusted_certs
        )
        composite_credentials = grpc.composite_channel_credentials(credentials, call_credentials)
        return grpc.secure_channel(**channel_dict, credentials=composite_credentials)

    def close(self):
        """Close the worker thread pool; no further retrievals may be submitted."""
        self.pool.close()