Source code for nvflare.private.fed.client.admin

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The FedAdmin to communicate with the Admin server."""
from typing import List, Optional

from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReservedHeaderKey
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.core_cell import Message as CellMessage
from nvflare.fuel.hci.server.constants import ConnProps
from nvflare.fuel.sec.audit import Auditor, AuditService
from nvflare.fuel.sec.authz import AuthorizationService, AuthzContext, Person
from nvflare.private.admin_defs import Message, error_reply, ok_reply
from nvflare.private.defs import CellChannel, RequestHeader, new_cell_message
from nvflare.private.fed.server.site_security import SiteSecurity
from nvflare.security.logging import secure_format_exception, secure_log_traceback


class RequestProcessor(object):
    """The RequestProcessor is responsible for processing a request.

    Subclasses override ``get_topics`` to declare which request topics they
    handle and ``process`` to produce the reply for a request on one of
    those topics. The base implementations are no-op stubs that return None.
    """

    def get_topics(self) -> List[str]:
        """Get topics that this processor will handle.

        The returned collection is iterated when the processor is registered,
        so subclasses must return an iterable (returning None fails there).

        Returns:
            list of topics
        """
        pass

    def process(self, req: Message, app_ctx) -> Message:
        """Called to process the specified request.

        Args:
            req: request message
            app_ctx: application context

        Returns:
            reply message; returning None is treated by the dispatching agent
            as a plain acknowledgement.
        """
        pass
class FedAdminAgent(object):
    """FedAdminAgent communicate with the FedAdminServer.

    It receives admin requests over the cell on ``CellChannel.CLIENT_MAIN``,
    records an audit event for each one, runs site-level and (optionally)
    user-level authorization checks, and dispatches the request to the
    ``RequestProcessor`` registered for the request's topic.
    """

    def __init__(self, client_name: str, cell: Cell, app_ctx):
        """Init the FedAdminAgent.

        Args:
            client_name: client name
            app_ctx: application context
            cell: the Cell for communication

        Raises:
            TypeError: if the audit service does not provide an ``Auditor``.
        """
        auditor = AuditService.get_auditor()
        if not isinstance(auditor, Auditor):
            raise TypeError("auditor must be an instance of Auditor, but got {}".format(type(auditor)))

        self.name = client_name
        self.cell = cell
        self.auditor = auditor
        self.app_ctx = app_ctx
        self.processors = {}  # topic -> RequestProcessor
        self.asked_to_stop = False
        self.register_cell_cb()

    def register_cell_cb(self):
        """Route every topic on ``CellChannel.CLIENT_MAIN`` to ``_dispatch_request``."""
        self.cell.register_request_cb(
            channel=CellChannel.CLIENT_MAIN,
            topic="*",
            cb=self._dispatch_request,
        )

    def register_processor(self, processor: RequestProcessor):
        """To register the RequestProcessor.

        All topics are validated before any of them is registered, so a
        rejected processor leaves the existing registrations untouched.

        Args:
            processor: RequestProcessor

        Raises:
            TypeError: if ``processor`` is not a ``RequestProcessor``.
            ValueError: if any of its topics is already registered, or it
                lists the same topic more than once.
        """
        if not isinstance(processor, RequestProcessor):
            raise TypeError("processor must be an instance of RequestProcessor, but got {}".format(type(processor)))

        # Materialize once so a generator-returning get_topics is iterated only once.
        topics = list(processor.get_topics())
        seen = set()
        for topic in topics:
            # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
            # which would let a later processor silently overwrite an earlier one.
            if topic in self.processors or topic in seen:
                raise ValueError("duplicate processors for topic {}".format(topic))
            seen.add(topic)

        for topic in topics:
            self.processors[topic] = processor

    def _dispatch_request(
        self,
        request: CellMessage,
        # *args, **kwargs
    ) -> CellMessage:
        """Cell callback: audit, authorize and process one admin request.

        Args:
            request: cell message whose payload is an admin ``Message``.

        Returns:
            A cell message wrapping the reply. Unknown topics yield
            ``error_reply("invalid_request")``; any exception raised during
            authorization or processing is logged and turned into an error reply.

        Raises:
            TypeError: if ``request`` is not a cell message or its payload is
                not an admin ``Message``.
        """
        if not isinstance(request, CellMessage):
            raise TypeError("request must be CellMessage but got {}".format(type(request)))
        req = request.payload
        if not isinstance(req, Message):
            raise TypeError("request payload must be Message but got {}".format(type(req)))

        topic = req.topic
        self._audit_request(req, topic)

        processor: RequestProcessor = self.processors.get(topic)
        if not processor:
            return new_cell_message({}, error_reply("invalid_request"))

        with self.app_ctx.new_context() as fl_ctx:
            self._set_peer_context(req, fl_ctx)
            try:
                reply = self._authorize_request(req, fl_ctx)
                if reply is None:
                    reply = self._invoke_processor(processor, req, topic)
            except Exception as e:
                secure_log_traceback()
                reply = error_reply(f"exception_occurred: {secure_format_exception(e)}")

        return new_cell_message({}, reply)

    def _audit_request(self, req: Message, topic):
        """Record an audit event for the request (user name, topic, referenced event id)."""
        if self.auditor:
            user_name = req.get_header(RequestHeader.USER_NAME, "")
            ref_event_id = req.get_header(ConnProps.EVENT_ID, "")
            self.auditor.add_event(user=user_name, action=topic, ref=ref_event_id)

    @staticmethod
    def _set_peer_context(req: Message, fl_ctx: FLContext):
        """Attach the requester's public props (if sent) to ``fl_ctx`` as its peer context."""
        peer_props = req.get_header(ReservedHeaderKey.PEER_PROPS)
        if peer_props:
            peer_ctx = FLContext()
            peer_ctx.set_public_props(peer_props)
            fl_ctx.set_peer_context(peer_ctx)

    def _authorize_request(self, req: Message, fl_ctx: FLContext) -> Optional[Message]:
        """Run site-security and, if requested, user authorization checks.

        Returns:
            An error reply if the request must be rejected, otherwise None.
        """
        cmd = req.get_header(RequestHeader.ADMIN_COMMAND, None)
        if cmd:
            site_security = SiteSecurity()
            self._set_security_data(req, fl_ctx)
            authorized, messages = site_security.authorization_check(self.app_ctx, cmd, fl_ctx)
            if not authorized:
                return error_reply(messages)

        # The requester decides whether pre-authorization of the command is needed.
        require_authz = req.get_header(RequestHeader.REQUIRE_AUTHZ) == "true"
        if not require_authz:
            return None

        if not cmd:
            return error_reply("requires authz but missing admin command")

        user = Person(
            name=req.get_header(RequestHeader.USER_NAME, ""),
            org=req.get_header(RequestHeader.USER_ORG, ""),
            role=req.get_header(RequestHeader.USER_ROLE, ""),
        )
        submitter = Person(
            name=req.get_header(RequestHeader.SUBMITTER_NAME, ""),
            org=req.get_header(RequestHeader.SUBMITTER_ORG, ""),
            role=req.get_header(RequestHeader.SUBMITTER_ROLE, ""),
        )
        authz_ctx = AuthzContext(user=user, submitter=submitter, right=cmd)
        authorized, err = AuthorizationService.authorize(authz_ctx)
        if err:
            return error_reply(err)
        if not authorized:
            return error_reply("not authorized")
        return None

    def _invoke_processor(self, processor: RequestProcessor, req: Message, topic) -> Message:
        """Run the processor; a None result becomes ``ok_reply()``, a non-Message is an error."""
        reply = processor.process(req, self.app_ctx)
        if reply is None:
            # simply ack
            return ok_reply()
        if not isinstance(reply, Message):
            raise RuntimeError(f"processor for topic {topic} failed to produce valid reply")
        return reply

    def _set_security_data(self, req, fl_ctx: FLContext):
        """Copy user/submitter identity headers and job meta into ``fl_ctx``.

        The values are merged into the ``SECURITY_ITEMS`` prop, which is set as
        private and non-sticky.
        """
        security_items = fl_ctx.get_prop(FLContextKey.SECURITY_ITEMS, {})
        security_items[FLContextKey.USER_NAME] = req.get_header(RequestHeader.USER_NAME, "")
        security_items[FLContextKey.USER_ORG] = req.get_header(RequestHeader.USER_ORG, "")
        security_items[FLContextKey.USER_ROLE] = req.get_header(RequestHeader.USER_ROLE, "")
        security_items[FLContextKey.SUBMITTER_NAME] = req.get_header(RequestHeader.SUBMITTER_NAME, "")
        security_items[FLContextKey.SUBMITTER_ORG] = req.get_header(RequestHeader.SUBMITTER_ORG, "")
        security_items[FLContextKey.SUBMITTER_ROLE] = req.get_header(RequestHeader.SUBMITTER_ROLE, "")
        security_items[FLContextKey.JOB_META] = req.get_header(RequestHeader.JOB_META, {})

        fl_ctx.set_prop(FLContextKey.SECURITY_ITEMS, security_items, private=True, sticky=False)