Source code for nvflare.private.fed.app.deployer.base_client_deployer

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import grpc

from nvflare.apis.fl_context import FLContext
from nvflare.private.fed.client.admin import RequestProcessor
from nvflare.private.fed.client.client_req_processors import ClientRequestProcessors
from nvflare.private.fed.client.fed_client import FederatedClient


class BaseClientDeployer:
    """Hold client build settings and create the ``FederatedClient`` from them."""

    def __init__(self):
        """To init the BaseClientDeployer."""
        self.multi_gpu = False
        self.outbound_filters = None
        self.inbound_filters = None
        self.federated_client = None
        self.model_validator = None
        self.cross_val_participating = False
        self.model_registry_path = None
        self.cross_val_timeout = None
        self.executors = None
        # NOTE: this is the class-level list object itself, not a copy, so
        # anything appended through this attribute is visible to every other
        # holder of ClientRequestProcessors.request_processors.
        self.req_processors = ClientRequestProcessors.request_processors

    def build(self, build_ctx):
        """Copy the deployment settings out of the build context.

        Args:
            build_ctx: mapping that provides the keys "server_config",
                "client_config", "secure_train", "client_name", "server_host",
                "overseer_agent", "client_components" and "client_handlers".

        Raises:
            KeyError: if any of the keys above is missing (plain subscripting
                is used on purpose so an incomplete context fails loudly).
        """
        self.server_config = build_ctx["server_config"]
        self.client_config = build_ctx["client_config"]
        self.secure_train = build_ctx["secure_train"]
        self.client_name = build_ctx["client_name"]
        self.host = build_ctx["server_host"]
        self.overseer_agent = build_ctx["overseer_agent"]
        self.components = build_ctx["client_components"]
        self.handlers = build_ctx["client_handlers"]

    def set_model_manager(self, model_manager):
        """Store ``model_manager`` on this deployer."""
        self.model_manager = model_manager

    @staticmethod
    def _first_server(servers):
        """Return the entry of ``servers`` whose name sorts first.

        Each entry is a single-key ``{name: service}`` dict. Dicts cannot be
        ordered with ``<`` in Python 3, so the entries are sorted by their
        name key instead of being compared directly.

        Raises:
            IndexError: if ``servers`` is empty.
        """
        return sorted(servers, key=lambda entry: next(iter(entry)))[0]

    def create_fed_client(self, args, sp_target=None):
        """Create the ``FederatedClient``, store it on ``self`` and return it.

        Must be called after ``build``, which supplies the configuration
        attributes read here.

        Args:
            args: passed through unchanged to ``FederatedClient``.
            sp_target: when truthy, overwrites the "target" of every service
                in ``self.server_config`` (the config entries are modified in
                place).

        Returns:
            The newly created ``FederatedClient``; the same object is kept in
            ``self.federated_client``.
        """
        if sp_target:
            for item in self.server_config:
                item["service"]["target"] = sp_target

        servers = [{entry["name"]: entry["service"]} for entry in self.server_config]

        retry_timeout = self.client_config.get("retry_timeout", 30)

        requested_compression = self.client_config.get("compression")
        if requested_compression == "Deflate":
            compression = grpc.Compression.Deflate
        elif requested_compression == "Gzip":
            compression = grpc.Compression.Gzip
        else:
            compression = grpc.Compression.NoCompression

        # req_processors is a shared list (see __init__), so guard by identity
        # to keep a repeated call from registering the same processor twice.
        for component in self.components.values():
            if not isinstance(component, RequestProcessor):
                continue
            if any(existing is component for existing in self.req_processors):
                continue
            self.req_processors.append(component)

        self.federated_client = FederatedClient(
            client_name=str(self.client_name),
            # We only deploy the first server right now .....
            server_args=self._first_server(servers),
            client_args=self.client_config,
            secure_train=self.secure_train,
            retry_timeout=retry_timeout,
            executors=self.executors,
            compression=compression,
            overseer_agent=self.overseer_agent,
            args=args,
            components=self.components,
            handlers=self.handlers,
        )
        return self.federated_client

    def finalize(self, fl_ctx: "FLContext"):
        """Finish the deployer by delegating to ``close``; ``fl_ctx`` is unused."""
        self.close()

    def close(self):
        """Currently a no-op; nothing is released here."""
        pass