Source code for nvflare.fuel.hci.server.login

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import List

from nvflare.fuel.hci.conn import Connection
from nvflare.fuel.hci.reg import CommandModule, CommandModuleSpec, CommandSpec
from nvflare.fuel.hci.security import verify_password
from nvflare.fuel.hci.server.constants import ConnProps

from .reg import CommandFilter
from .sess import CHECK_SESSION_CMD_NAME, SessionManager

LOGIN_CMD_NAME = "_login"
CERT_LOGIN_CMD_NAME = "_cert_login"


class Authenticator(ABC):
    """Abstract interface for components that verify a user's credential."""

    @abstractmethod
    def authenticate(self, user_name: str, credential: str, credential_type: str) -> bool:
        """Check whether the supplied credential is valid for the given user.

        Args:
            user_name: login name of the user being authenticated
            credential: the credential presented on behalf of that user
            credential_type: the kind of credential being presented

        Returns:
            True when authentication succeeds, False otherwise
        """
class SimpleAuthenticator(Authenticator):
    def __init__(self, users):
        """Authenticator used by the LoginModule to authenticate admin clients at login.

        Args:
            users: user information; queried with ``users.get(user_name)`` to obtain
                the stored password hash of a user
        """
        self.users = users

    def authenticate_password(self, user_name: str, pwd: str):
        """Verify a password against the hash stored for the user.

        Args:
            user_name: user login name
            pwd: password supplied by the user

        Returns:
            False when no hash is stored for the user, otherwise the result of
            ``verify_password`` for the stored hash and the supplied password
        """
        stored_hash = self.users.get(user_name)
        # Short-circuit: an unknown user is rejected without running the hash check.
        return stored_hash is not None and verify_password(stored_hash, pwd)

    def authenticate_cn(self, user_name: str, cn):
        """Accept the login only when the user name equals the supplied common name.

        Args:
            user_name: user login name
            cn: common name to compare against

        Returns:
            True if the two values are equal, False otherwise
        """
        return user_name == cn

    def authenticate(self, user_name: str, credential, credential_type):
        """Authenticate a user, dispatching on the credential type.

        Args:
            user_name: user login name
            credential: a password when credential_type is "password", a common
                name when it is "cn"
            credential_type: "password" or "cn"; any other value is rejected

        Returns:
            The result of the matching check, or False for an unsupported type
        """
        if credential_type == "password":
            return self.authenticate_password(user_name, credential)
        if credential_type == "cn":
            return self.authenticate_cn(user_name, credential)
        return False
class LoginModule(CommandModule, CommandFilter):
    def __init__(self, authenticator: Authenticator, sess_mgr: SessionManager):
        """Login module.

        CommandModule containing the login commands to handle login and logout of admin clients, as well as the
        CommandFilter pre_command to check that a client is logged in with a valid session.

        Args:
            authenticator: Authenticator used to verify credentials. A falsy value is accepted; the login
                handlers then reply "OK" without checking credentials or creating a session.
            sess_mgr: SessionManager

        Raises:
            TypeError: if a truthy authenticator is not an Authenticator, or sess_mgr is not a SessionManager
        """
        if authenticator and not isinstance(authenticator, Authenticator):
            raise TypeError("authenticator must be Authenticator but got {}.".format(type(authenticator)))

        if not isinstance(sess_mgr, SessionManager):
            raise TypeError("sess_mgr must be SessionManager but got {}.".format(type(sess_mgr)))

        self.authenticator = authenticator
        self.session_mgr = sess_mgr

    def get_spec(self):
        """Describe the login, cert-login and logout commands provided by this module.

        Returns:
            A CommandModuleSpec named "login" whose three commands are all registered as not visible
        """
        return CommandModuleSpec(
            name="login",
            cmd_specs=[
                CommandSpec(
                    name=LOGIN_CMD_NAME,
                    description="login to server",
                    usage="login userName password",
                    handler_func=self.handle_login,
                    visible=False,
                ),
                CommandSpec(
                    name=CERT_LOGIN_CMD_NAME,
                    description="login to server with SSL cert",
                    usage="login userName",
                    handler_func=self.handle_cert_login,
                    visible=False,
                ),
                CommandSpec(
                    name="_logout",
                    description="logout from server",
                    usage="logout",
                    handler_func=self.handle_logout,
                    visible=False,
                ),
            ],
        )

    def _start_session(self, conn: Connection, user_name: str):
        """Create a session for an authenticated user and reply "OK" followed by the session token."""
        session = self.session_mgr.create_session(user_name)
        conn.append_string("OK")
        conn.append_token(session.token)

    @staticmethod
    def _extract_token(request):
        """Return the data of the first "token" item in a request, or None if there is none.

        The request comes from the client, so its shape is not trusted: anything that is not a dict
        holding a "data" list of dict items is treated as carrying no token instead of raising.
        """
        if not isinstance(request, dict):
            return None

        items = request.get("data")
        if not isinstance(items, (list, tuple)):
            return None

        for item in items:
            if isinstance(item, dict) and item.get("type") == "token":
                # Only the first token item counts, even if its data is missing.
                return item.get("data")
        return None

    def handle_login(self, conn: Connection, args: List[str]):
        """Handle password login.

        Replies "OK" plus a session token on success and "REJECT" on a malformed command or failed
        authentication. Without an authenticator, replies "OK" and creates no session.

        Args:
            conn: connection the reply is appended to
            args: command tokens, expected to be [command, user_name, password]
        """
        if not self.authenticator:
            conn.append_string("OK")
            return

        if len(args) != 3:
            conn.append_string("REJECT")
            return

        _, user_name, pwd = args
        if not self.authenticator.authenticate(user_name, pwd, "password"):
            conn.append_string("REJECT")
            return

        self._start_session(conn, user_name)

    def handle_cert_login(self, conn: Connection, args: List[str]):
        """Handle certificate login.

        The common name is read from the connection property "_client_cn" and checked against the
        requested user name. Replies "OK" plus a session token on success and "REJECT" on a malformed
        command, a missing common name, or failed authentication. Without an authenticator, replies "OK"
        and creates no session.

        Args:
            conn: connection carrying the "_client_cn" property; the reply is appended to it
            args: command tokens, expected to be [command, user_name]
        """
        if not self.authenticator:
            conn.append_string("OK")
            return

        if len(args) != 2:
            conn.append_string("REJECT")
            return

        cn = conn.get_prop("_client_cn", None)
        if cn is None:
            conn.append_string("REJECT")
            return

        user_name = args[1]
        if not self.authenticator.authenticate(user_name, cn, "cn"):
            conn.append_string("REJECT")
            return

        self._start_session(conn, user_name)

    def handle_logout(self, conn: Connection, args: List[str]):
        """Handle logout: end the session of the connection's token (if any) and always reply "OK".

        Args:
            conn: connection whose TOKEN property identifies the session to end
            args: command tokens (unused)
        """
        if self.authenticator and self.session_mgr:
            token = conn.get_prop(ConnProps.TOKEN)
            if token:
                self.session_mgr.end_session(token)
        conn.append_string("OK")

    def pre_command(self, conn: Connection, args: List[str]):
        """Check that a command comes from a client with a valid session.

        The login, cert-login and check-session commands pass without a session. For any other command
        the token is taken from the request; when it maps to a session, the session is marked active
        and the SESSION, USER_NAME and TOKEN properties are set on the connection.

        Args:
            conn: connection holding the request; errors are appended to it on rejection
            args: command tokens; args[0] is the command name

        Returns:
            True if the command may proceed, False if it was rejected
        """
        # An empty command cannot be one of the exempt commands, so it goes through token validation
        # rather than raising IndexError.
        if args and args[0] in [LOGIN_CMD_NAME, CERT_LOGIN_CMD_NAME, CHECK_SESSION_CMD_NAME]:
            # skip login and check session commands
            return True

        # validate token
        token = self._extract_token(conn.request)
        if token is None:
            conn.append_error("not authenticated - no token")
            return False

        sess = self.session_mgr.get_session(token)
        if not sess:
            conn.append_error("session_inactive")
            conn.append_string(
                "user not authenticated or session timed out after {} seconds of inactivity - logged out".format(
                    self.session_mgr.idle_timeout
                )
            )
            return False

        sess.mark_active()
        conn.set_prop(ConnProps.SESSION, sess)
        conn.set_prop(ConnProps.USER_NAME, sess.user_name)
        conn.set_prop(ConnProps.TOKEN, token)
        return True