# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import yaml
from nvflare.lighter import utils
from .constants import CtxKey, PropKey, ProvisionMode
from .entity import Entity, Project
class ProvisionContext(dict):
    """Dict-backed provisioning context keyed by ``CtxKey`` constants.

    On construction it records the workspace layout (the workspace root plus
    its ``wip``, ``state`` and ``resources`` subdirectories), the project, and
    commonly used server properties (admin port, FL port, server name).
    The template and the provision mode are stored later via their setters.
    """

    def __init__(self, workspace_root_dir: str, project: Project):
        """Initialize the context and set up the workspace directories.

        Args:
            workspace_root_dir: root directory of the provisioning workspace.
            project: the project being provisioned; its server supplies the
                admin port, FL port and server name stored in the context.
        """
        super().__init__()
        self[CtxKey.WORKSPACE] = workspace_root_dir
        wip_dir = os.path.join(workspace_root_dir, "wip")
        state_dir = os.path.join(workspace_root_dir, "state")
        resources_dir = os.path.join(workspace_root_dir, "resources")
        self.update({CtxKey.WIP: wip_dir, CtxKey.STATE: state_dir, CtxKey.RESOURCES: resources_dir})
        dirs = [workspace_root_dir, resources_dir, wip_dir, state_dir]
        # presumably creates any of these directories that do not exist yet
        utils.make_dirs(dirs)

        # set commonly used data into ctx
        self[CtxKey.PROJECT] = project
        # NOTE(review): assumes the project defines a server; a None return
        # here would fail on the get_prop calls below.
        server = project.get_server()
        # Fall back to 8003 (admin) and 8002 (FL) when the server does not
        # set these properties explicitly.
        admin_port = server.get_prop(PropKey.ADMIN_PORT, 8003)
        self[CtxKey.ADMIN_PORT] = admin_port
        fed_learn_port = server.get_prop(PropKey.FED_LEARN_PORT, 8002)
        self[CtxKey.FED_LEARN_PORT] = fed_learn_port
        self[CtxKey.SERVER_NAME] = server.name

    def get_project(self):
        """Return the project stored at construction time."""
        return self.get(CtxKey.PROJECT)

    def set_template(self, template: dict):
        """Store the template dict whose sections are used to build files."""
        self[CtxKey.TEMPLATE] = template

    def get_template(self):
        """Return the stored template dict, or None if none has been set."""
        return self.get(CtxKey.TEMPLATE)

    def get_template_section(self, section_key: str):
        """Return the content of one template section.

        Args:
            section_key: key of the section within the template.

        Returns:
            The section value stored under ``section_key``.

        Raises:
            RuntimeError: if no template is set, or the section is missing
                or empty.
        """
        template = self.get_template()
        if not template:
            raise RuntimeError("template is not available")
        section = template.get(section_key)
        if not section:
            # Report the requested key; `section` itself is None/empty here.
            raise RuntimeError(f"missing section {section_key} in template")
        return section

    def set_provision_mode(self, mode: str):
        """Record the provision mode.

        Args:
            mode: either ``ProvisionMode.POC`` or ``ProvisionMode.NORMAL``.

        Raises:
            ValueError: if ``mode`` is not one of the valid modes.
        """
        valid_modes = [ProvisionMode.POC, ProvisionMode.NORMAL]
        if mode not in valid_modes:
            raise ValueError(f"invalid provision mode {mode}: must be one of {valid_modes}")
        self[CtxKey.PROVISION_MODE] = mode

    def get_provision_mode(self):
        """Return the provision mode, or None if it has not been set."""
        return self.get(CtxKey.PROVISION_MODE)

    def get_wip_dir(self):
        """Return the ``<workspace>/wip`` directory path."""
        return self.get(CtxKey.WIP)

    def get_ws_dir(self, entity: Entity) -> str:
        """Return ``<wip>/<entity.name>``, the per-entity working directory."""
        return os.path.join(self.get_wip_dir(), entity.name)

    def get_kit_dir(self, entity: Entity) -> str:
        """Return the entity's ``startup`` directory under its working directory."""
        return os.path.join(self.get_ws_dir(entity), "startup")

    def get_transfer_dir(self, entity: Entity) -> str:
        """Return the entity's ``transfer`` directory under its working directory."""
        return os.path.join(self.get_ws_dir(entity), "transfer")

    def get_local_dir(self, entity: Entity) -> str:
        """Return the entity's ``local`` directory under its working directory."""
        return os.path.join(self.get_ws_dir(entity), "local")

    def get_state_dir(self):
        """Return the ``<workspace>/state`` directory path."""
        return self.get(CtxKey.STATE)

    def get_resources_dir(self):
        """Return the ``<workspace>/resources`` directory path."""
        return self.get(CtxKey.RESOURCES)

    def get_workspace(self):
        """Return the workspace root directory path."""
        return self.get(CtxKey.WORKSPACE)

    def yaml_load_template_section(self, section_key: str):
        """Parse a template section as YAML with ``yaml.safe_load``.

        Raises:
            RuntimeError: if the template or the section is unavailable.
        """
        return yaml.safe_load(self.get_template_section(section_key))

    def json_load_template_section(self, section_key: str):
        """Parse a template section as JSON with ``json.loads``.

        Raises:
            RuntimeError: if the template or the section is unavailable.
        """
        return json.loads(self.get_template_section(section_key))

    def build_from_template(
        self,
        dest_dir: str,
        temp_section: str,
        file_name,
        replacement=None,
        mode="t",
        exe=False,
        content_modify_cb=None,
        **cb_kwargs,
    ):
        """Build a file from a template section and write it to ``dest_dir/file_name``.

        Args:
            dest_dir: destination directory.
            temp_section: template section key.
            file_name: name of the file to write inside ``dest_dir``.
            replacement: optional replacement dict applied with ``utils.sh_replace``.
            mode: write mode forwarded to ``utils.write`` (default ``"t"``;
                presumably ``"t"`` text / ``"b"`` binary -- confirm).
            exe: forwarded to ``utils.write`` as ``exe``; presumably marks the
                file executable.
            content_modify_cb: optional callback that receives the section
                content as its first argument (plus ``cb_kwargs``) and returns
                the modified content. It runs after ``replacement`` is applied.
            cb_kwargs: additional keyword arguments passed to the callback.

        Raises:
            RuntimeError: if the template or the section is unavailable.
        """
        section = self.get_template_section(temp_section)
        if replacement:
            section = utils.sh_replace(section, replacement)
        if content_modify_cb:
            section = content_modify_cb(section, **cb_kwargs)
        utils.write(os.path.join(dest_dir, file_name), section, mode, exe=exe)