# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import List
from nvflare.fuel.hci.conn import Connection
from nvflare.fuel.hci.proto import CredentialType, InternalCommands
from nvflare.fuel.hci.reg import CommandModule, CommandModuleSpec, CommandSpec
from nvflare.fuel.hci.security import IdentityKey, verify_password
from nvflare.fuel.hci.server.constants import ConnProps
from .reg import CommandFilter
from .sess import Session, SessionManager
class Authenticator(ABC):
    """Abstract interface for checking a user's credential.

    Concrete subclasses decide how each kind of ``CredentialType`` is verified.
    """

    @abstractmethod
    def authenticate(self, user_name: str, credential: str, credential_type: CredentialType) -> bool:
        """Check whether ``credential`` proves the identity claimed by ``user_name``.

        Args:
            user_name: login name the client claims.
            credential: the credential supplied for that name.
            credential_type: which kind of credential ``credential`` is.

        Returns:
            True when the credential is accepted, False otherwise.
        """
        ...
class SimpleAuthenticator(Authenticator):
    """Authenticator backed by an in-memory user table; used by LoginModule to verify admin logins.

    Two credential kinds are supported:
    * password credentials, checked against the stored hash for the user;
    * certificate credentials, accepted when the certificate name equals the claimed user name.
    Any other credential type is rejected.
    """

    def __init__(self, users):
        """Create the authenticator.

        Args:
            users: user table, queried with ``users.get(user_name)``; expected to yield the stored
                password hash for a known user and None for an unknown one.
        """
        self.users = users

    def authenticate_password(self, user_name: str, pwd: str):
        """Return whether ``pwd`` matches the stored hash for ``user_name`` (False for unknown users)."""
        stored_hash = self.users.get(user_name)
        return False if stored_hash is None else verify_password(stored_hash, pwd)

    def authenticate_cn(self, user_name: str, cn):
        """Return whether the certificate name ``cn`` is exactly ``user_name``."""
        return user_name == cn

    def authenticate(self, user_name: str, credential, credential_type):
        """Route ``credential`` to the checker matching ``credential_type``; unsupported types fail."""
        if credential_type == CredentialType.PASSWORD:
            return self.authenticate_password(user_name, credential)
        if credential_type == CredentialType.CERT:
            return self.authenticate_cn(user_name, credential)
        return False
class LoginModule(CommandModule, CommandFilter):
    """Login/logout commands plus a session-checking command filter for admin clients.

    As a CommandModule it provides the hidden password-login, cert-login and logout commands.
    As a CommandFilter, ``pre_command`` rejects every other command unless the request carries
    a token for a live session.
    """

    def __init__(self, authenticator: Authenticator, sess_mgr: SessionManager):
        """Login module.

        Args:
            authenticator: Authenticator used to verify credentials. If falsy (e.g. None), the login
                commands answer "OK" to every request without creating a session.
            sess_mgr: SessionManager that creates, looks up and ends sessions.

        Raises:
            TypeError: if a truthy ``authenticator`` is not an Authenticator, or if ``sess_mgr`` is
                not a SessionManager.
        """
        if authenticator and not isinstance(authenticator, Authenticator):
            raise TypeError("authenticator must be Authenticator but got {}.".format(type(authenticator)))
        if not isinstance(sess_mgr, SessionManager):
            raise TypeError("sess_mgr must be SessionManager but got {}.".format(type(sess_mgr)))
        self.authenticator = authenticator
        self.session_mgr = sess_mgr

    def get_spec(self):
        """Return the spec of the hidden (``visible=False``) login, cert-login and logout commands."""
        return CommandModuleSpec(
            name="login",
            cmd_specs=[
                CommandSpec(
                    name=InternalCommands.PWD_LOGIN,
                    description="login to server",
                    usage="login userName password",
                    handler_func=self.handle_login,
                    visible=False,
                ),
                CommandSpec(
                    name=InternalCommands.CERT_LOGIN,
                    description="login to server with SSL cert",
                    usage="login userName",
                    handler_func=self.handle_cert_login,
                    visible=False,
                ),
                CommandSpec(
                    name="_logout",
                    description="logout from server",
                    usage="logout",
                    handler_func=self.handle_logout,
                    visible=False,
                ),
            ],
        )

    def handle_login(self, conn: Connection, args: List[str]):
        """Handle a password login: ``<cmd> userName password``.

        Replies "REJECT" on a wrong argument count or failed authentication. Otherwise it replies
        "OK" followed by the new session's token.
        """
        if not self.authenticator:
            # No authenticator configured: accept without creating a session.
            conn.append_string("OK")
            return

        if len(args) != 3:
            conn.append_string("REJECT")
            return

        user_name = args[1]
        pwd = args[2]
        if not self.authenticator.authenticate(user_name, pwd, CredentialType.PASSWORD):
            conn.append_string("REJECT")
            return

        # Password logins always get this fixed org/role.
        session = self.session_mgr.create_session(user_name=user_name, user_org="global", user_role="project_admin")
        conn.append_string("OK")
        conn.append_token(session.token)

    def handle_cert_login(self, conn: Connection, args: List[str]):
        """Handle a certificate login: ``<cmd> userName``.

        The claimed user name is checked against the name in the client identity stored on the
        connection (``ConnProps.CLIENT_IDENTITY``). The session takes its name, org and role from
        that identity.
        """
        if not self.authenticator:
            # No authenticator configured: accept without creating a session.
            conn.append_string("OK")
            return

        if len(args) != 2:
            conn.append_string("REJECT")
            return

        identity = conn.get_prop(ConnProps.CLIENT_IDENTITY, None)
        # Use .get so an identity lacking a name is rejected rather than raising KeyError.
        cert_name = identity.get(IdentityKey.NAME) if identity is not None else None
        if cert_name is None:
            conn.append_string("REJECT")
            return

        user_name = args[1]
        if not self.authenticator.authenticate(user_name, cert_name, CredentialType.CERT):
            conn.append_string("REJECT")
            return

        session = self.session_mgr.create_session(
            user_name=cert_name,
            user_org=identity.get(IdentityKey.ORG, ""),
            user_role=identity.get(IdentityKey.ROLE, ""),
        )
        conn.append_string("OK")
        conn.append_token(session.token)

    def handle_logout(self, conn: Connection, args: List[str]):
        """End the session whose token is on the connection (if any); always replies "OK"."""
        if self.authenticator and self.session_mgr:
            token = conn.get_prop(ConnProps.TOKEN)
            if token:
                self.session_mgr.end_session(token)
        conn.append_string("OK")

    @staticmethod
    def _extract_token(request):
        """Return the "data" of the first "token" item in ``request["data"]``, or None if absent.

        The request originates from the client, so malformed shapes are treated as "no token"
        instead of raising.
        """
        if not isinstance(request, dict):
            return None
        for item in request.get("data") or ():
            if isinstance(item, dict) and item.get("type") == "token":
                return item.get("data")
        return None

    def pre_command(self, conn: Connection, args: List[str]):
        """Allow a command only if the request carries the token of a live session.

        On success the session and its user name/org/role and token are stored as connection
        properties, and the method returns True. Otherwise an error is appended to ``conn`` and
        the method returns False.
        """
        # Login and session-check commands must be able to run before a session exists.
        if args and args[0] in (InternalCommands.PWD_LOGIN, InternalCommands.CERT_LOGIN, InternalCommands.CHECK_SESSION):
            return True

        # NOTE(review): when no authenticator is configured, the login handlers reply "OK" without
        # issuing a token, yet the check below still requires one. Confirm that this filter is not
        # registered in that configuration.
        token = self._extract_token(conn.request)
        if token is None:
            conn.append_error("not authenticated - no token")
            return False

        sess = self.session_mgr.get_session(token)
        if not sess:
            conn.append_error("session_inactive")
            conn.append_string(
                "user not authenticated or session timed out after {} seconds of inactivity - logged out".format(
                    self.session_mgr.idle_timeout
                )
            )
            return False

        assert isinstance(sess, Session)
        # Refresh the session's activity before the command runs.
        sess.mark_active()
        conn.set_prop(ConnProps.SESSION, sess)
        conn.set_prop(ConnProps.USER_NAME, sess.user_name)
        conn.set_prop(ConnProps.USER_ORG, sess.user_org)
        conn.set_prop(ConnProps.USER_ROLE, sess.user_role)
        conn.set_prop(ConnProps.TOKEN, token)
        return True

    def close(self):
        """Shut down the session manager."""
        self.session_mgr.shutdown()