# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from typing import List, Optional, Union
import yaml
import nvflare.lighter as prov
from nvflare.lighter import utils
from nvflare.lighter.utils import load_yaml
from .constants import CtxKey, PropKey, ProvisionMode
from .entity import Entity, Project
[docs]
class ProvisionContext(dict):
def __init__(self, workspace_root_dir: str, project: Project):
super().__init__()
self[CtxKey.WORKSPACE] = workspace_root_dir
wip_dir = os.path.join(workspace_root_dir, "wip")
state_dir = os.path.join(workspace_root_dir, "state")
self.update({CtxKey.WIP: wip_dir, CtxKey.STATE: state_dir})
dirs = [workspace_root_dir, wip_dir, state_dir]
utils.make_dirs(dirs)
# set commonly used data into ctx
self[CtxKey.PROJECT] = project
server = project.get_server()
fed_learn_port = server.get_prop(PropKey.FED_LEARN_PORT, 8002)
admin_port = server.get_prop(PropKey.ADMIN_PORT, fed_learn_port)
self[CtxKey.ADMIN_PORT] = admin_port
self[CtxKey.FED_LEARN_PORT] = fed_learn_port
self[CtxKey.SERVER_NAME] = server.name
self[CtxKey.TEMP_FILES_LOADED] = []
self[CtxKey.TEMPLATE] = {}
[docs]
def get_project(self) -> Project:
return self.get(CtxKey.PROJECT)
[docs]
def load_templates(self, temp_files: Union[str, List[str]]):
if not temp_files:
return
if isinstance(temp_files, str):
temp_files = [temp_files]
elif not isinstance(temp_files, list):
raise ValueError(f"temp_files must be str or List[str] but got {type(temp_files)}")
prov_folder = os.path.dirname(prov.__file__)
temp_folder = os.path.join(prov_folder, "templates")
loaded = self[CtxKey.TEMP_FILES_LOADED]
template = self[CtxKey.TEMPLATE]
for f in temp_files:
if f not in loaded:
template.update(load_yaml(os.path.join(temp_folder, f)))
loaded.append(f)
[docs]
def get_template_section(self, section_key: str):
template = self.get(CtxKey.TEMPLATE)
if not template:
raise RuntimeError("template is not available")
section = template.get(section_key)
if not section:
raise RuntimeError(f"missing section {section_key} in template")
return section
[docs]
def set_provision_mode(self, mode: str):
valid_modes = [ProvisionMode.POC, ProvisionMode.NORMAL]
if mode not in valid_modes:
raise ValueError(f"invalid provision mode {mode}: must be one of {valid_modes}")
self[CtxKey.PROVISION_MODE] = mode
[docs]
def get_provision_mode(self):
return self.get(CtxKey.PROVISION_MODE)
[docs]
def set_logger(self, logger):
self[CtxKey.LOGGER] = logger
[docs]
def get_logger(self):
return self.get(CtxKey.LOGGER)
[docs]
def get_wip_dir(self):
return self.get(CtxKey.WIP)
[docs]
def get_ws_dir(self, entity: Entity):
return os.path.join(self.get_wip_dir(), entity.name)
[docs]
def get_kit_dir(self, entity: Entity):
return os.path.join(self.get_ws_dir(entity), "startup")
[docs]
def get_transfer_dir(self, entity: Entity):
return os.path.join(self.get_ws_dir(entity), "transfer")
[docs]
def get_local_dir(self, entity: Entity):
return os.path.join(self.get_ws_dir(entity), "local")
[docs]
def get_state_dir(self):
return self.get(CtxKey.STATE)
[docs]
def get_workspace(self):
return self.get(CtxKey.WORKSPACE)
[docs]
def yaml_load_template_section(self, section_key: str, replacement=None):
section = self.build_section_from_template(section_key, replacement)
return yaml.safe_load(section)
[docs]
def json_load_template_section(self, section_key: str, replacement=None):
section = self.build_section_from_template(section_key, replacement)
return json.loads(section)
[docs]
def build_from_template(
self,
dest_dir: str,
temp_section: Union[str, List[str]],
file_name,
replacement=None,
mode="t",
exe=False,
content_modify_cb=None,
**cb_kwargs,
):
"""Build a file from a template section and writes it to the specified location.
Args:
dest_dir: destination directory
temp_section: template section key
file_name: file name
replacement: replacement dict
mode: file mode
exe: executable
content_modify_cb: content modification callback. If specified, it takes the section content as the
first argument and returns the modified content
cb_kwargs: additional keyword arguments for the callback
"""
section = self.build_section_from_template(temp_section, replacement, content_modify_cb, **cb_kwargs)
utils.write(os.path.join(dest_dir, file_name), section, mode, exe=exe)
[docs]
def build_section_from_template(
self,
temp_section: Union[str, List[str]],
replacement=None,
content_modify_cb=None,
**cb_kwargs,
):
if isinstance(temp_section, str):
temp_section = [temp_section]
elif not isinstance(temp_section, list):
raise ValueError(f"temp_section must be str or List[str] but got {type(temp_section)}")
section = ""
for s in temp_section:
section += self.get_template_section(s)
if replacement:
section = utils.sh_replace(section, replacement)
if content_modify_cb:
section = content_modify_cb(section, **cb_kwargs)
return section
[docs]
def info(self, msg: str):
logger = self.get_logger()
if logger:
logger.info(msg)
else:
print(f"INFO: {msg}")
[docs]
def error(self, msg: str):
logger = self.get_logger()
if logger:
logger.error(msg)
else:
print(f"ERROR: {msg}")
[docs]
def debug(self, msg: str):
logger = self.get_logger()
if logger:
logger.debug(msg)
else:
print(f"DEBUG: {msg}")
[docs]
def warning(self, msg: str):
logger = self.get_logger()
if logger:
logger.warning(msg)
else:
print(f"WARNING: {msg}")
[docs]
def get_result_location(self) -> Optional[str]:
"""Get the directory of the provision result.
This should be called after the provision is done.
Returns: the name of the directory that holds the provisioned result.
"""
return self.get(CtxKey.CURRENT_PROD_DIR)