# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from werkzeug.security import check_password_hash, generate_password_hash
from .blob import gen_client, gen_overseer, gen_server, gen_user
from .cert import Entity, make_root_cert
from .models import Capacity, Client, Organization, Project, Role, User, db
# Module-level logger; the Store methods below use it to report failed database commits.
log = logging.getLogger(__name__)
def check_role(id, claims, requester):
    """Work out which privileges ``requester`` holds with respect to user ``id``.

    Args:
        id: primary key of the user record being accessed.
        claims: mapping of token claims; only its ``"role"`` entry is read.
        requester: value compared against the email stored for user ``id``.

    Returns:
        A ``(is_creator, is_project_admin)`` tuple of booleans:
        ``is_creator`` is True when ``requester`` equals the email looked up
        for ``id``; ``is_project_admin`` is True when the ``"role"`` claim is
        exactly ``"project_admin"``.
    """
    owner_email = Store._get_email_by_id(id)
    role_claim = claims.get("role")
    return requester == owner_email, role_claim == "project_admin"
def _dict_or_empty(item):
    """Return ``item.asdict()``, or an empty dict when ``item`` is falsy (e.g. ``None``)."""
    if not item:
        return {}
    return item.asdict()
def get_or_create(session, model, **kwargs):
    """Fetch the first ``model`` row matching ``kwargs``, creating it when absent.

    Args:
        session: database session used for the lookup and, if needed, the insert.
        model: model class to query and, if needed, instantiate.
        **kwargs: column values used both as the ``filter_by`` criteria and as
            the constructor arguments of a newly created row.

    Returns:
        The existing row if one matches; otherwise a new instance that has
        been added to ``session`` and committed.
    """
    existing = session.query(model).filter_by(**kwargs).first()
    if existing:
        return existing

    # Nothing matched: build the row from the same keyword arguments and persist it.
    created = model(**kwargs)
    session.add(created)
    session.commit()
    return created
def add_ok(obj):
    """Mark ``obj`` as successful by setting ``obj["status"]`` to ``"ok"``.

    The mapping is updated in place (any existing ``"status"`` entry is
    overwritten) and the same object is returned for convenient chaining.
    """
    obj.update(status="ok")
    return obj
def inc_dl(model, id):
    """Increment the ``download_count`` of the ``model`` row with primary key ``id``.

    Args:
        model: model class exposing ``query.get`` and a ``download_count`` column.
        id: primary key of the row whose counter is bumped.

    The change is committed immediately. When no row with that ``id`` exists
    the call logs a warning and does nothing, instead of failing with an
    ``AttributeError`` on ``None``.
    """
    instance = model.query.get(id)
    if instance is None:
        # The counter is bookkeeping only; a missing row must not crash the caller.
        log.warning("Cannot count download: no %s with id %s", getattr(model, "__name__", model), id)
        return
    instance.download_count = instance.download_count + 1
    db.session.add(instance)
    db.session.commit()
class Store(object):
    """Database-backed operations on the project, clients, users and organizations.

    Every method is a classmethod that works through the module-level ``db``
    session and the models imported at the top of this file. Successful calls
    generally return a dict tagged with ``"status": "ok"`` (see ``add_ok``).
    """

    # Minimum ``approval_state`` at which a user or client counts as approved.
    _APPROVED_STATE = 100

    # Raw columns that a record's creator must not write through the generic
    # ``setattr`` loop of the patch methods. Without this filter the explicit
    # restrictions on ``approval_state``, ``role`` and ``organization`` could be
    # bypassed by sending the underlying column names instead.
    _CLIENT_CREATOR_PROTECTED = ("id", "creator_id", "approval_state", "download_count")
    _USER_CREATOR_PROTECTED = (
        "id",
        "approval_state",
        "role_id",
        "organization_id",
        "password_hash",
        "download_count",
    )

    @classmethod
    def ready(cls):
        """Return True when the user with id 1 exists and is approved, else False."""
        user = User.query.get(1)
        return user.approval_state >= cls._APPROVED_STATE if user else False

    @classmethod
    def seed_user(cls, email, pwd):
        """Create the initial ``project_admin`` user and an empty project.

        Args:
            email: email of the seeded user.
            pwd: plain-text password; it is hashed by ``create_user``.

        Returns:
            The ``(email, pwd)`` pair that was passed in.
        """
        seed_user = {
            "name": "super_name",
            "email": email,
            "password": pwd,
            "organization": "",
            "role": "project_admin",
            "approval_state": 200,
        }
        cls.create_user(seed_user)
        cls.create_project()
        return email, pwd

    @classmethod
    def init_db(cls):
        """Drop every table and recreate the schema. All stored data is lost."""
        db.drop_all()
        db.create_all()
        return add_ok({})

    @classmethod
    def create_project(cls):
        """Insert a new ``Project`` row with default values and return it as a dict."""
        project = Project()
        db.session.add(project)
        db.session.commit()
        return add_ok({"project": _dict_or_empty(project)})

    @classmethod
    def build_project(cls, project):
        """Generate a root certificate for ``project`` and store it on the row.

        The certificate entity is built from ``project.short_name``; the
        serialized certificate and private key are saved in ``root_cert`` and
        ``root_key`` and committed.
        """
        entity = Entity(project.short_name)
        cert_pair = make_root_cert(entity)
        project.root_cert = cert_pair.ser_cert
        project.root_key = cert_pair.ser_pri_key
        db.session.add(project)
        db.session.commit()
        return add_ok({"project": _dict_or_empty(project)})

    @classmethod
    def _add_registered_info(cls, project_dict):
        """Add the current client, organization and user counts to ``project_dict``."""
        project_dict["num_clients"] = Client.query.count()
        project_dict["num_orgs"] = Organization.query.count()
        project_dict["num_users"] = User.query.count()
        return project_dict

    @classmethod
    def set_project(cls, req):
        """Update the first project row from ``req`` unless it is already frozen.

        Args:
            req: dict of attribute names to new values. ``"id"`` is ignored and
                ``"short_name"`` is truncated to 16 characters; every other
                entry is set on the project as-is. ``req`` is modified in place.

        Returns:
            ``{"status": "Project is frozen"}`` when the project was already
            frozen; otherwise the updated project dict (with registration
            counts) tagged with ``"status": "ok"``.

        Assumes a project row exists (``seed_user`` creates one).
        """
        project = Project.query.first()
        if project.frozen:
            return {"status": "Project is frozen"}
        req.pop("id", None)
        short_name = req.pop("short_name", "")
        if short_name:
            # Slicing is a no-op for names that are already short enough.
            project.short_name = short_name[:16]
        for k, v in req.items():
            setattr(project, k, v)
        db.session.add(project)
        db.session.commit()
        # If this very update froze the project, its root certificate is created now.
        if project.frozen:
            cls.build_project(project)
        project_dict = cls._add_registered_info(_dict_or_empty(project))
        return add_ok({"project": project_dict})

    @classmethod
    def get_project(cls):
        """Return the first project (or an empty dict) plus registration counts."""
        project_dict = cls._add_registered_info(_dict_or_empty(Project.query.first()))
        return add_ok({"project": project_dict})

    @classmethod
    def get_overseer_blob(cls, key):
        """Return the ``(fileobj, filename)`` pair produced by ``gen_overseer(key)``."""
        fileobj, filename = gen_overseer(key)
        return fileobj, filename

    @classmethod
    def get_server_blob(cls, key, first_server=True):
        """Return the ``(fileobj, filename)`` pair produced by ``gen_server(key, first_server)``."""
        fileobj, filename = gen_server(key, first_server)
        return fileobj, filename

    @classmethod
    def get_orgs(cls):
        """Return all organizations as a list of dicts.

        NOTE(review): the result key is ``"client_list"`` although the entries
        are organizations; it is kept unchanged because callers may rely on it.
        """
        all_orgs = Organization.query.all()
        return add_ok({"client_list": [_dict_or_empty(org) for org in all_orgs]})

    @classmethod
    def _is_approved_by_client_id(cls, id):
        """Return True when client ``id`` exists and is approved."""
        client = Client.query.get(id)
        return client is not None and client.approval_state >= cls._APPROVED_STATE

    @classmethod
    def _is_approved_by_user_id(cls, id):
        """Return True when user ``id`` exists and is approved."""
        user = User.query.get(id)
        return user is not None and user.approval_state >= cls._APPROVED_STATE

    @classmethod
    def create_client(cls, req, creator):
        """Create a client owned by the user whose email is ``creator``.

        Args:
            req: dict with ``"name"`` and optional ``"organization"``,
                ``"capacity"`` (stored JSON-encoded) and ``"description"``.
            creator: email of the user creating the client.

        Returns:
            The new client as a dict tagged with ``"status": "ok"``, or
            ``None`` when the creator is unknown or the commit fails.
        """
        creator_user = User.query.filter_by(email=creator).first()
        if creator_user is None:
            log.error("Error while creating client: unknown creator %s", creator)
            return None
        name = req.get("name")
        organization = req.get("organization", "")
        capacity = req.get("capacity")
        description = req.get("description", "")
        org = get_or_create(db.session, Organization, name=organization)
        client = Client(name=name, description=description, creator_id=creator_user.id)
        client.organization_id = org.id
        # Capacity is optional; only link a Capacity row when one was supplied.
        if capacity is not None:
            cap = get_or_create(db.session, Capacity, capacity=json.dumps(capacity))
            client.capacity_id = cap.id
        try:
            db.session.add(client)
            db.session.commit()
        except Exception as e:
            # Roll back so the shared session stays usable after the failure.
            db.session.rollback()
            log.error("Error while creating client: %s", e)
            return None
        return add_ok({"client": _dict_or_empty(client)})

    @classmethod
    def get_clients(cls, org=None):
        """Return all clients, or only those of organization ``org`` when given.

        An unknown organization name yields an empty list.
        """
        if org is None:
            all_clients = Client.query.all()
        else:
            organization = Organization.query.filter_by(name=org).first()
            all_clients = organization.clients if organization else []
        return add_ok({"client_list": [_dict_or_empty(client) for client in all_clients]})

    @classmethod
    def get_creator_id_by_client_id(cls, id):
        """Return the ``creator_id`` of client ``id``, or ``None`` if it does not exist."""
        client = Client.query.get(id)
        return client.creator_id if client else None

    @classmethod
    def get_client(cls, id):
        """Return client ``id`` as a dict (empty when it does not exist)."""
        client = Client.query.get(id)
        return add_ok({"client": _dict_or_empty(client)})

    @classmethod
    def _patch_client(cls, id, req):
        """Apply ``req`` to client ``id`` and commit; shared by both patch entry points.

        ``"organization"`` and ``"capacity"`` are resolved to their rows; every
        remaining entry is set on the client as an attribute. Returns the
        updated client dict, or ``None`` when the client is missing or the
        commit fails.
        """
        client = Client.query.get(id)
        if client is None:
            log.error("Error while patching client: no client with id %s", id)
            return None
        organization = req.pop("organization", None)
        if organization is not None:
            org = get_or_create(db.session, Organization, name=organization)
            client.organization_id = org.id
        capacity = req.pop("capacity", None)
        if capacity is not None:
            cap = get_or_create(db.session, Capacity, capacity=json.dumps(capacity))
            client.capacity_id = cap.id
        for k, v in req.items():
            setattr(client, k, v)
        try:
            db.session.add(client)
            db.session.commit()
        except Exception as e:
            # Roll back so the shared session stays usable after the failure.
            db.session.rollback()
            log.error("Error while patching client: %s", e)
            return None
        return add_ok({"client": _dict_or_empty(client)})

    @classmethod
    def patch_client_by_project_admin(cls, id, req):
        """Update client ``id`` from ``req`` with project-admin rights.

        Returns the updated client dict, or ``None`` on failure. ``req`` is
        modified in place.
        """
        return cls._patch_client(id, req)

    @classmethod
    def patch_client_by_creator(cls, id, req):
        """Update client ``id`` from ``req`` on behalf of its creator.

        Protected fields (``approval_state`` and the raw bookkeeping columns
        listed in ``_CLIENT_CREATOR_PROTECTED``) are silently dropped from
        ``req`` before it is applied. Returns the updated client dict, or
        ``None`` on failure. ``req`` is modified in place.
        """
        for field in cls._CLIENT_CREATOR_PROTECTED:
            req.pop(field, None)
        return cls._patch_client(id, req)

    @classmethod
    def delete_client(cls, id):
        """Delete client ``id`` and commit."""
        client = Client.query.get(id)
        db.session.delete(client)
        db.session.commit()
        return add_ok({})

    @classmethod
    def get_client_blob(cls, key, id):
        """Return ``gen_client(key, id)`` and count the download on the client."""
        fileobj, filename = gen_client(key, id)
        inc_dl(Client, id)
        return fileobj, filename

    @classmethod
    def create_user(cls, req):
        """Create a user from ``req``.

        Args:
            req: dict with ``"email"`` and optional ``"name"``, ``"password"``
                (stored hashed), ``"organization"``, ``"role"``,
                ``"description"`` and ``"approval_state"`` (default 0).

        Returns:
            The new user as a dict tagged with ``"status": "ok"``, or ``None``
            when the user could not be stored.
        """
        name = req.get("name", "")
        email = req.get("email")
        password = req.get("password", "")
        password_hash = generate_password_hash(password)
        organization = req.get("organization", "")
        role_name = req.get("role", "")
        description = req.get("description", "")
        approval_state = req.get("approval_state", 0)
        org = get_or_create(db.session, Organization, name=organization)
        role = get_or_create(db.session, Role, name=role_name)
        try:
            user = User(
                email=email,
                name=name,
                password_hash=password_hash,
                description=description,
                approval_state=approval_state,
            )
            user.organization_id = org.id
            user.role_id = role.id
            db.session.add(user)
            db.session.commit()
        except Exception as e:
            # Roll back so the shared session stays usable after the failure.
            db.session.rollback()
            log.error("Error while creating user: %s", e)
            return None
        return add_ok({"user": _dict_or_empty(user)})

    @classmethod
    def verify_user(cls, email, password):
        """Return the user with ``email`` if ``password`` matches its hash, else ``None``."""
        user = User.query.filter_by(email=email).first()
        if user is not None and check_password_hash(user.password_hash, password):
            return user
        return None

    @classmethod
    def get_users(cls, org_name=None):
        """Return all users, or only those of organization ``org_name`` when given.

        An unknown organization name yields an empty list.
        """
        if org_name is None:
            all_users = User.query.all()
        else:
            org = Organization.query.filter_by(name=org_name).first()
            all_users = org.users if org else []
        return add_ok({"user_list": [_dict_or_empty(user) for user in all_users]})

    @classmethod
    def _get_email_by_id(cls, id):
        """Return the email of user ``id``, or ``None`` if it does not exist."""
        user = User.query.get(id)
        return user.email if user else None

    @classmethod
    def get_user(cls, id):
        """Return user ``id`` as a dict (empty when it does not exist)."""
        user = User.query.get(id)
        return add_ok({"user": _dict_or_empty(user)})

    @classmethod
    def _apply_user_patch(cls, user, req):
        """Finish a user patch: hash a new password, set the rest, commit.

        ``"password"`` is stored as a hash; every remaining entry of ``req`` is
        set on ``user`` as an attribute. Returns the updated user dict.
        """
        password = req.pop("password", None)
        if password is not None:
            user.password_hash = generate_password_hash(password)
        for k, v in req.items():
            setattr(user, k, v)
        db.session.add(user)
        db.session.commit()
        return add_ok({"user": _dict_or_empty(user)})

    @classmethod
    def patch_user_by_project_admin(cls, id, req):
        """Update user ``id`` from ``req`` with project-admin rights.

        ``"organization"`` and ``"role"`` names are resolved to rows (created
        on demand) and always applied. ``req`` is modified in place.
        """
        user = User.query.get(id)
        org_name = req.pop("organization", None)
        if org_name is not None:
            org = get_or_create(db.session, Organization, name=org_name)
            user.organization_id = org.id
        role_name = req.pop("role", None)
        if role_name is not None:
            role = get_or_create(db.session, Role, name=role_name)
            user.role_id = role.id
        return cls._apply_user_patch(user, req)

    @classmethod
    def patch_user_by_creator(cls, id, req):
        """Update user ``id`` from ``req`` on behalf of that user.

        Protected fields (``approval_state`` and the raw columns listed in
        ``_USER_CREATOR_PROTECTED``) are silently dropped. ``"role"`` and
        ``"organization"`` are only applied while the user's current role or
        organization name is the empty string. ``req`` is modified in place.
        """
        user = User.query.get(id)
        for field in cls._USER_CREATOR_PROTECTED:
            req.pop(field, None)
        role_name = req.pop("role", None)
        if role_name is not None and user.role.name == "":
            role = get_or_create(db.session, Role, name=role_name)
            user.role_id = role.id
        org_name = req.pop("organization", None)
        if org_name is not None and user.organization.name == "":
            org = get_or_create(db.session, Organization, name=org_name)
            user.organization_id = org.id
        return cls._apply_user_patch(user, req)

    @classmethod
    def delete_user(cls, id):
        """Delete user ``id`` together with every client that user created."""
        clients = Client.query.filter_by(creator_id=id).all()
        for client in clients:
            db.session.delete(client)
        user = User.query.get(id)
        db.session.delete(user)
        db.session.commit()
        return add_ok({})

    @classmethod
    def get_user_blob(cls, key, id):
        """Return ``gen_user(key, id)`` and count the download on the user."""
        fileobj, filename = gen_user(key, id)
        inc_dl(User, id)
        return fileobj, filename