# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import traceback
from abc import ABC
from typing import List
from nvflare.apis.utils.format_check import name_check
[docs]
class PropKey:
CONN_SECURITY = "connection_security"
CUSTOM_CA_CERT = "custom_ca_cert"
[docs]
class ConnSecurity:
CLEAR = "clear"
INSECURE = "insecure"
TLS = "tls"
MTLS = "mtls"
[docs]
class ConfigEntity:
def __init__(self, props):
if not props:
props = {}
self.props = props
self.conn_security = None
self.custom_ca_cert = None
# validate properties
conn_security = self.get_prop(PropKey.CONN_SECURITY)
if conn_security:
if not isinstance(conn_security, str):
raise ValueError(f"invalid value '{conn_security}' for {PropKey.CONN_SECURITY}")
conn_security = conn_security.lower().strip()
valid_values = [
ConnSecurity.INSECURE,
ConnSecurity.CLEAR,
ConnSecurity.TLS,
ConnSecurity.MTLS,
]
if conn_security not in valid_values:
raise ValueError(
f"invalid value for {PropKey.CONN_SECURITY}: {conn_security}. Must be one of {valid_values}"
)
self.conn_security = conn_security
custom_ca_cert = self.get_prop(PropKey.CUSTOM_CA_CERT)
if custom_ca_cert:
if not os.path.isfile(custom_ca_cert):
raise ValueError(f"specified {PropKey.CUSTOM_CA_CERT} {custom_ca_cert} is not a valid file.")
_, file_extension = os.path.splitext(custom_ca_cert)
if file_extension != ".pem":
raise ValueError(f"specified {PropKey.CUSTOM_CA_CERT} {custom_ca_cert} must have '.pem' extension.")
self.custom_ca_cert = custom_ca_cert
[docs]
def get_prop(self, key: str, default=None):
return self.props.get(key, default)
[docs]
class Participant(ConfigEntity):
def __init__(self, type: str, name: str, org: str, enable_byoc: bool = False, *args, **kwargs):
"""Class to represent a participant.
Each participant communicates to other participant. Therefore, each participant has its
own name, type, organization it belongs to, rules and other information.
Args:
type (str): server, client, admin or other string that builders can handle
name (str): system-wide unique name
org (str): system-wide unique organization
enable_byoc (bool, optional): whether this participant allows byoc codes to be loaded. Defaults to False.
Raises:
ValueError: if name or org is not compliant with characters or format specification.
"""
ConfigEntity.__init__(self, kwargs)
err, reason = name_check(name, type)
if err:
raise ValueError(reason)
err, reason = name_check(org, "org")
if err:
raise ValueError(reason)
self.type = type
self.name = name
self.org = org
self.subject = name
self.enable_byoc = enable_byoc
[docs]
class Project(ConfigEntity):
def __init__(self, name: str, description: str, participants: List[Participant], config: dict = None):
"""A container class to hold information about this FL project.
This class only holds information. It does not drive the workflow.
Args:
name (str): the project name
description (str): brief description on this name
participants (List[Participant]): All the participants that will join this project
config: the whole config dict of the project
Raises:
ValueError: when duplicate name found in participants list
"""
ConfigEntity.__init__(self, config)
self.name = name
all_names = list()
for p in participants:
if p.name in all_names:
raise ValueError(f"Unable to add a duplicate name {p.name} into this project.")
else:
all_names.append(p.name)
self.description = description
self.participants = participants
[docs]
def get_participants_by_type(self, type, first_only=True):
found = list()
for p in self.participants:
if p.type == type:
if first_only:
return p
else:
found.append(p)
return found
[docs]
class Builder(ABC):
[docs]
def initialize(self, ctx: dict):
pass
[docs]
def build(self, project: Project, ctx: dict):
pass
[docs]
def finalize(self, ctx: dict):
pass
[docs]
def get_wip_dir(self, ctx: dict):
return ctx.get("wip_dir")
[docs]
def get_ws_dir(self, participate: Participant, ctx: dict):
return os.path.join(self.get_wip_dir(ctx), participate.name)
[docs]
def get_kit_dir(self, participant: Participant, ctx: dict):
return os.path.join(self.get_ws_dir(participant, ctx), "startup")
[docs]
def get_transfer_dir(self, participant: Participant, ctx: dict):
return os.path.join(self.get_ws_dir(participant, ctx), "transfer")
[docs]
def get_local_dir(self, participant: Participant, ctx: dict):
return os.path.join(self.get_ws_dir(participant, ctx), "local")
[docs]
def get_state_dir(self, ctx: dict):
return ctx.get("state_dir")
[docs]
def get_resources_dir(self, ctx: dict):
return ctx.get("resources_dir")
[docs]
class Provisioner(object):
def __init__(self, root_dir: str, builders: List[Builder]):
"""Workflow class that drive the provision process.
Provisioner's tasks:
- Maintain the provision workspace folder structure;
- Invoke Builders to generate the content of each startup kit
ROOT_WORKSPACE Folder Structure::
root_workspace_dir_name: this is the root of the workspace
project_dir_name: the root dir of the project, could be named after the project
resources: stores resource files (templates, configs, etc.) of the Provisioner and Builders
prod: stores the current set of startup kits (production)
participate_dir: stores content files generated by builders
wip: stores the set of startup kits to be created (WIP)
participate_dir: stores content files generated by builders
state: stores the persistent state of the Builders
Args:
root_dir (str): the directory path to hold all generated or intermediate folders
builders (List[Builder]): all builders that will be called to build the content
"""
self.root_dir = root_dir
self.builders = builders
self.ctx = None
def _make_dir(self, dirs):
for dir in dirs:
if not os.path.exists(dir):
os.makedirs(dir)
def _prepare_workspace(self, ctx):
workspace = ctx.get("workspace")
wip_dir = os.path.join(workspace, "wip")
state_dir = os.path.join(workspace, "state")
resources_dir = os.path.join(workspace, "resources")
ctx.update(dict(wip_dir=wip_dir, state_dir=state_dir, resources_dir=resources_dir))
dirs = [workspace, resources_dir, wip_dir, state_dir]
self._make_dir(dirs)
[docs]
def provision(self, project: Project):
# ctx = {"workspace": os.path.join(self.root_dir, project.name), "project": project}
workspace = os.path.join(self.root_dir, project.name)
# project is more static information while ctx is dynamic
ctx = {
"workspace": workspace,
"project": project,
}
self._prepare_workspace(ctx)
try:
for b in self.builders:
b.initialize(ctx)
# call builders!
for b in self.builders:
b.build(project, ctx)
for b in self.builders[::-1]:
b.finalize(ctx)
except Exception as ex:
prod_dir = ctx.get("current_prod_dir")
if prod_dir:
shutil.rmtree(prod_dir)
print("Exception raised during provision. Incomplete prod_n folder removed.")
traceback.print_exc()
finally:
wip_dir = ctx.get("wip_dir")
if wip_dir:
shutil.rmtree(wip_dir)
return ctx