# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import io
import os
import re
import shlex
from typing import List
from nvflare.apis.utils.format_check import type_pattern_mapping
def split_to_args(line: str) -> List[str]:
    """Split an admin command line into its arguments.

    If the line contains a double quote, it is tokenized with shell-like
    rules (``shlex.split``) so quoted segments remain a single argument.
    Otherwise runs of spaces are collapsed to one space and the line is split
    on single spaces; leading/trailing spaces therefore yield empty-string
    arguments at the ends.
    """
    if '"' not in line:
        collapsed = re.sub(" +", " ", line)
        return collapsed.split(" ")
    return shlex.split(line)
def join_args(segs: List[str]) -> str:
    """Join argument segments back into a single command line.

    Each segment is whitespace-split:
      * one word  -> emitted as that word (surrounding whitespace dropped);
      * several   -> wrapped in double quotes so it stays one argument;
      * none (empty or whitespace-only) -> emitted as ``""``, an explicit
        empty argument, instead of failing with IndexError.

    Segments are separated by a single space.
    """
    pieces = []
    for seg in segs:
        words = seg.split()
        if not words:
            pieces.append('""')
        elif len(words) == 1:
            pieces.append(words[0])
        else:
            pieces.append('"' + seg + '"')
    return " ".join(pieces)
class ArgValidator(argparse.ArgumentParser):
    """ArgumentParser that records errors instead of printing and exiting."""

    def __init__(self, name):
        """Validator for admin shell commands that uses argparse to check arguments and get usage through print_help.

        Args:
            name: name of the program to pass to ArgumentParser
        """
        argparse.ArgumentParser.__init__(self, prog=name, add_help=False)
        # Most recent error message reported through error(); "" means none.
        self.err = ""

    def error(self, message):
        """Record ``message`` instead of printing usage and exiting.

        The stock ArgumentParser.error exits the process; this override just
        stores the message so the admin shell keeps running. Because it
        returns, argparse carries on after calling it and may record a later
        error or raise; validate() handles both.
        """
        self.err = message

    def validate(self, args):
        """Parse ``args`` and report the outcome.

        Returns:
            ``(err, namespace)``. ``err`` is "" when no error was recorded
            during this call. If parsing raises, returns a generic error
            message and ``None``.
        """
        # Clear any error left over from a previous call on this instance,
        # otherwise a later valid command would be reported as failing.
        self.err = ""
        try:
            result = self.parse_args(args)
            return self.err, result
        except Exception:
            return 'argument error; try "? cmdName to show supported usage for a command"', None

    def get_usage(self) -> str:
        """Return the help text with its first ("usage: ...") line removed."""
        with io.StringIO() as buffer:
            self.print_help(buffer)
            help_text = buffer.getvalue()
        return help_text.split("\n", 1)[1]
def process_targets_into_str(targets: List[str]) -> str:
    """Validate a list of target names and join them with single spaces.

    Every element is type-checked before any element's content is validated.

    Raises:
        SyntaxError: if ``targets`` is not a list, if any element is not a
            str, or if any element is rejected by
            validate_required_target_string.
    """
    if not isinstance(targets, list):
        raise SyntaxError("targets is not a list.")
    if any(not isinstance(item, str) for item in targets):
        raise SyntaxError("all targets in the list of targets must be strings.")
    for item in targets:
        try:
            validate_required_target_string(item)
        except SyntaxError:
            raise SyntaxError(f"invalid target {item}")
    return " ".join(targets)
def validate_required_target_string(target: str) -> str:
    """Returns the target string if it exists and is valid.

    A valid target is a non-empty str made only of ASCII letters, digits,
    ".", "_" and "-".

    Raises:
        SyntaxError: if ``target`` is empty/falsy, is not a str, or contains
            any other character (including a trailing newline).
    """
    if not target:
        raise SyntaxError("target is required but not specified.")
    if not isinstance(target, str):
        raise SyntaxError("target is not str.")
    # fullmatch instead of match + "$": "$" also matches just before a
    # trailing "\n", which would let e.g. "site-1\n" pass validation.
    if not re.fullmatch("[A-Za-z0-9._-]+", target):
        raise SyntaxError("target must be a string of only valid characters and no spaces.")
    return target
def validate_options_string(options: str) -> str:
    """Returns the options string if it is valid.

    Valid options consist only of ASCII letters, digits, "-" and spaces; the
    empty string is accepted.

    Raises:
        SyntaxError: if ``options`` is not a str or contains any other
            character (including a trailing newline).
    """
    if not isinstance(options, str):
        raise SyntaxError("options is not str.")
    # fullmatch instead of match + "$", which would accept a trailing "\n".
    if not re.fullmatch("[A-Za-z0-9 -]*", options):
        raise SyntaxError("options must be a string of only valid characters.")
    return options
def validate_path_string(path: str) -> str:
    """Returns the path string if it is valid.

    A valid path is a str that:
      * consists only of ASCII letters, digits, "-", ".", "_" and "/";
      * is relative (not absolute and not starting with "/");
      * has no ".." component.

    Raises:
        SyntaxError: if any of the above conditions is violated.
    """
    if not isinstance(path, str):
        raise SyntaxError("path is not str.")
    # fullmatch instead of match + "$": "$" also matches before a trailing "\n".
    if not re.fullmatch("[A-Za-z0-9._/-]*", path):
        raise SyntaxError("unsupported characters in path {}".format(path))
    # The explicit "/" check covers platforms where isabs() does not treat a
    # rooted path without a drive as absolute (e.g. Windows on Python 3.13+).
    if path.startswith("/") or os.path.isabs(path):
        raise SyntaxError("absolute path is not allowed")
    # "/" is the only separator the character check admits, so split on it
    # directly; os.path.sep is "\\" on Windows and would miss "a/../b".
    if ".." in path.split("/"):
        raise SyntaxError(".. in path name is not allowed")
    return path
def validate_file_string(file: str) -> str:
    """Returns the file string if it is a valid path with a permitted extension.

    The path must first pass validate_path_string. The extension, as returned
    by ``os.path.splitext`` and compared case-sensitively, must be one of
    .txt, .log, .json, .csv, .sh, .config or .py.

    Raises:
        SyntaxError: if the path is invalid or the extension is not permitted.
    """
    validate_path_string(file)
    extension = os.path.splitext(file)[1]
    permitted = (".txt", ".log", ".json", ".csv", ".sh", ".config", ".py")
    if extension in permitted:
        return file
    raise SyntaxError(
        "this command cannot be applied to file {}. Only files with the following extensions are "
        "permitted: .txt, .log, .json, .csv, .sh, .config, .py".format(file)
    )
def validate_sp_string(sp_string) -> str:
    """Returns ``sp_string`` if it matches the "sp_end_point" pattern.

    The pattern comes from ``type_pattern_mapping["sp_end_point"]``; per the
    error message, the expected shape is like ``example.com:8002:8003``.

    Raises:
        SyntaxError: if ``sp_string`` is not a str or does not match the
            pattern over its whole length.
    """
    # Reject non-str up front so callers see SyntaxError, as with the other
    # validators, rather than a TypeError from the re module.
    if not isinstance(sp_string, str):
        raise SyntaxError("sp_string is not str.")
    pattern = type_pattern_mapping.get("sp_end_point")
    # fullmatch: the whole string must match (a "$" used with re.match would
    # also accept a trailing newline).
    if re.fullmatch(pattern, sp_string):
        return sp_string
    raise SyntaxError("sp_string must be of the format example.com:8002:8003")