"""
Provides functions for turning typecaster annotated functions into CLI commands.
"""
from argparse import ArgumentParser, Namespace, HelpFormatter, Action, ONE_OR_MORE
from typing import Callable, Any, Tuple, List, Dict, Optional, Type
import inspect
import re
import yaml
from io import StringIO
from diplomat.processing import ConfigSpec
from diplomat.processing.type_casters import (
TypeCaster,
TypeCasterFunction,
get_typecaster_annotations,
get_type_name,
get_typecaster_kwd_arg_name,
to_metavar,
ConvertibleTypeCaster,
)
class CLIError(Exception):
    """
    Raised when user-supplied command line input cannot be parsed.

    The CLI engine catches this error internally, so parsing failures are reported to the user gracefully
    (message plus usage string) rather than as a traceback.
    """
class Flag(ConvertibleTypeCaster):
    """
    Type caster for a boolean flag argument on the command line.

    The flag is either present or absent; no true/false value is passed. Its Python type is bool. Give flag
    parameters a default of False so that the Python function signature matches the generated CLI signature.
    """

    def __call__(self, arg: Any) -> bool:
        # Plain truthiness conversion.
        return True if arg else False

    def to_type_hint(self) -> Type:
        return bool

    def __repr__(self):
        return self.__class__.__name__


# The class is replaced by its single instance: annotations use the instance (``param: Flag``), and the CLI
# builder detects flag parameters with an identity check against it.
Flag = Flag()
"""
Type caster for a boolean flag argument on the command line (no true/false value needs to be specified). Its
Python type is bool. Give flag parameters a default of False so the Python function signature matches the
corresponding generated CLI signature.
"""
def _yaml_arg_load(str_list: List[str]) -> Any:
    """
    Parse a raw CLI argument value as YAML.

    :param str_list: The value argparse produced for an argument. A list of strings (from ``nargs``) is joined
                     with single spaces and parsed as one YAML document. Any other value (e.g. an argument default
                     that never came from the command line) is returned unchanged.
    :return: The parsed YAML value.
    :raises CLIError: If the joined string cannot be parsed as YAML.
    """
    if not isinstance(str_list, list):
        return str_list

    text = " ".join(str_list)
    try:
        # safe_load only builds plain Python objects, which matters since this is user-supplied input.
        return yaml.safe_load(StringIO(text))
    except Exception as e:
        # Broad on purpose: besides YAMLError, constructors can raise (e.g. ValueError for an invalid date).
        # Chain the cause so the original parser error is preserved in tracebacks.
        raise CLIError(f"Unable to parse argument '{text}' as YAML, because: '{e}'") from e
def _yaml_typecaster(caster: TypeCaster):
    """
    Wrap a type caster into a "corrector" used by the CLI parser.

    The returned callable takes ``(argument name, raw argparse value)``, parses the raw value as YAML and passes
    the result through ``caster``.

    :param caster: Callable applied to the YAML-parsed value.
    :return: The corrector function.
    """

    def checker(name: str, str_list: List[str]):
        # Raises CLIError itself if the value is not valid YAML.
        res = _yaml_arg_load(str_list)
        try:
            return caster(res)
        except Exception as e:
            # Chain the cause so the caster's original error stays visible in tracebacks.
            raise CLIError(f"Failed to parse {name}, because: '{e}'") from e

    return checker
def _func_arg_to_cmd_arg(
    annotation: TypeCaster, default: Any, auto_cast: bool = True
) -> Tuple[dict, Optional[Callable]]:
    """
    Build the ``ArgumentParser.add_argument`` keyword arguments for one typecaster-annotated parameter.

    :param annotation: The parameter's type caster. ``Flag`` becomes a ``store_true`` switch. Anything else is
                       collected as one or more raw strings that are later parsed as YAML by the corrector.
    :param default: The parameter's default value, or ``inspect.Parameter.empty`` if it has none (the argument is
                    then marked as required).
    :param auto_cast: If true, the YAML-parsed value is passed through ``annotation``; otherwise it is passed on
                      unchanged.
    :return: A tuple of (add_argument keyword arguments, corrector). The corrector is None for flags, which need
             no post-processing.
    """
    if annotation is Flag:
        args = dict(action="store_true")
        arg_corrector = None
    else:
        args = dict(nargs="+", type=str, metavar=to_metavar(annotation))
        arg_corrector = _yaml_typecaster(annotation if auto_cast else (lambda a: a))

    # Identity check, not ``==``: defaults with a custom __eq__ (e.g. numpy arrays) would otherwise return a
    # non-bool or raise when used in this condition.
    if default is inspect.Parameter.empty:
        args["required"] = True
    else:
        args["default"] = default
    return args, arg_corrector
class ComplexParsingWrapper:
    """
    Internal: Parses arguments for a single diplomat sub-command.

    Diplomat parses CLI argument values with a YAML parser, so it supports lists, numbers, floats, etc. An instance
    is stored as the ``_func`` default of a sub-command's parser. Calling it with the parsed namespace runs each
    registered corrector over its argument, then calls the wrapped function with the results as keyword arguments.
    """

    # Sentinel default for optional "extra" arguments. Entries that still hold this value were not supplied by
    # the user and are left out of the call to the wrapped function entirely.
    DELETE = object()

    def __init__(
        self,
        run_func: Callable,
        correctors: Dict[str, Callable],
        parser: ArgumentParser,
    ):
        """
        :param run_func: The function to call with the corrected arguments.
        :param correctors: Maps argument destination names to callables ``corrector(name, value)`` that convert the
                           raw argparse value into the value passed to ``run_func``.
        :param parser: The argparse parser of this sub-command.
        """
        self._func = run_func
        self._correctors = correctors
        self._parser = parser

    @property
    def parser(self) -> ArgumentParser:
        """The argparse parser of this sub-command."""
        return self._parser

    @property
    def accepts_extra_flags(self) -> bool:
        """True if the wrapped function was marked with ``allow_arbitrary_flags``."""
        return getattr(self._func, "__allow_arbitrary_flags", False)

    @property
    def correctors(self) -> Dict[str, Callable]:
        """The mutable corrector mapping; callers may register new correctors before parsing."""
        return self._correctors

    def __call__(self, parsed_args: Namespace) -> Any:
        """
        Correct the parsed arguments and run the wrapped function.

        :param parsed_args: The namespace produced by argparse (CLIEngine removes its ``_func`` entry first).
        :return: Whatever the wrapped function returns.
        """
        # Work on a copy so the caller's namespace is not mutated.
        raw = dict(vars(parsed_args))
        omitted = {var for var, value in raw.items() if value is self.DELETE}
        kwargs = {var: value for var, value in raw.items() if var not in omitted}

        # Omitted arguments are skipped rather than deleted from the registry: the registry is reused across
        # calls, and flag-style extra arguments have no corrector at all (deleting one raised KeyError).
        for var, corrector in self._correctors.items():
            if var not in omitted:
                kwargs[var] = corrector(var, kwargs[var])
        return self._func(**kwargs)
def get_summary_from_doc_str(doc_str: str) -> str:
    """
    Extract the summary for a command from a function's doc string.

    The summary is everything before the first ``:param ``, ``:return`` or ``:throw`` field. If none of these
    fields is present, the whole doc string is returned.
    """
    summary, *_ = re.split(":param |:return|:throw", doc_str, maxsplit=1)
    return summary
def func_args_to_config_spec(
    func: TypeCasterFunction, caller_func: TypeCasterFunction
) -> ConfigSpec:
    """
    Convert the typecaster arguments of ``func`` that ``caller_func`` does not take into a ConfigSpec.

    :param func: The function to get parameters from.
    :param caller_func: The calling function. Its arguments are excluded from the result.
    :return: A ConfigSpec (dict) mapping each remaining argument name to a tuple of
             (default value, type caster, help text taken from the ``:param`` entries of func's doc string).
    """
    params = inspect.signature(func).parameters
    cmd_args = get_typecaster_annotations(func)
    caller_args = get_typecaster_annotations(caller_func)

    # A "clean" doc string, when present, takes precedence over the function's regular doc string.
    doc_str = (
        inspect.cleandoc(func.__clean_doc__)
        if hasattr(func, "__clean_doc__")
        else inspect.getdoc(func)
    )
    help_messages = (
        {}
        if doc_str is None
        else dict(re.findall(":param +([a-zA-Z0-9_]+):([^:]*)", doc_str))
    )

    return {
        arg_name: (
            params[arg_name].default,
            arg_caster,
            help_messages.get(arg_name, "").strip(),
        )
        for arg_name, arg_caster in cmd_args.items()
        if arg_name != "return" and arg_name not in caller_args
    }
def func_to_command(
    func: TypeCasterFunction, parser: ArgumentParser, allow_short_form: bool = True
) -> ArgumentParser:
    """
    Convert a typecaster function into an argparse command (CLI command).

    :param func: Type caster function to turn into a CLI command.
    :param parser: The argument parser to add the function CLI command to.
    :param allow_short_form: If true, allow abbreviated versions of arguments to be passed to the CLI.
    :return: The argparse parser with the function added as a command.
    """
    parser.formatter_class = YAMLArgHelpFormatter
    parser.allow_abbrev = False

    signature = inspect.signature(func)
    cmd_args = get_typecaster_annotations(func)
    arg_correctors = {}

    # Extract the summary and per-parameter help text from the doc string...
    if hasattr(func, "__clean_doc__"):
        doc_str = inspect.cleandoc(func.__clean_doc__)
    else:
        doc_str = inspect.getdoc(func)
    if doc_str is None:
        help_messages = {}
    else:
        parser.description = get_summary_from_doc_str(doc_str)
        help_messages = {
            name: info
            for name, info in re.findall(":param +([a-zA-Z0-9_]+):([^:]*)", doc_str)
        }

    # Short option strings already in use. argparse registers "-h" for --help by default, so reserve it up
    # front; otherwise a parameter such as "height" abbreviates to "-h" and add_argument raises a conflict error.
    abbr_set = {"-h"} if parser.add_help else set()

    # Functions accepting arbitrary flags document their wildcard keyword argument in the epilog.
    if getattr(func, "__allow_arbitrary_flags", False):
        name = get_typecaster_kwd_arg_name(func)
        if name is not None and name in help_messages:
            parser.epilog = help_messages[name]

    pos_arg_count = getattr(func, "__pos_cmd_arg_count", 0)

    for name, caster in cmd_args.items():
        if name == "return":
            continue
        default = signature.parameters[name].default
        args, corrector = _func_arg_to_cmd_arg(caster, default)
        if name in help_messages:
            args["help"] = help_messages[name]
        # Short form: first letter of each underscore-separated word, e.g. "num_outputs" -> "-no".
        abbr_cmd = "-" + "".join(s[:1] for s in name.split("_"))

        if pos_arg_count > 0:
            # argparse rejects 'required' for positionals (TypeError); it derives their requiredness from nargs.
            args.pop("required", None)
            if "nargs" in args:
                if pos_arg_count > 1:
                    args["nargs"] = 1
                else:
                    # A default argument for positional arguments only works if the argument is in the last position.
                    no_default = default is inspect.Parameter.empty
                    args["nargs"] = "+" if no_default else "*"
            parser.add_argument(name, **args)
            pos_arg_count -= 1
        elif abbr_cmd in abbr_set or not allow_short_form:
            parser.add_argument("--" + name, **args)
        else:
            parser.add_argument("--" + name, abbr_cmd, **args)
            abbr_set.add(abbr_cmd)

        if corrector is not None:
            arg_correctors[name] = corrector

    # Extra (config) arguments are always optional long flags; ones the user omits keep the DELETE sentinel and
    # are dropped by ComplexParsingWrapper before the function is called.
    extra_args = getattr(func, "__extra_args", {})
    auto_cast = getattr(func, "__auto_cast", True)
    for name, (_default, caster, desc) in extra_args.items():
        args, corrector = _func_arg_to_cmd_arg(
            caster, ComplexParsingWrapper.DELETE, auto_cast=auto_cast
        )
        typecaster_str = getattr(desc, "__typecaster_str__", None)
        args["help"] = str(typecaster_str()) if callable(typecaster_str) else str(desc)
        parser.add_argument("--" + name, **args)
        if corrector is not None:
            arg_correctors[name] = corrector

    parser.set_defaults(_func=ComplexParsingWrapper(func, arg_correctors, parser))
    return parser
class CLIEngine:
    """
    Represents a CLI program.

    Instances are callable: calling one with the command line arguments parses them and executes the selected
    sub-command.
    """

    def __init__(self, parent_parser: ArgumentParser):
        """
        Private: Create a new CLIEngine. Internal, use build_full_parser to create an instance of this class instead.

        :param parent_parser: An argparse ArgumentParser to wrap.
        """
        self._parser = parent_parser

    def _reparse(
        self, args: List[str], extra: List[str], arg_handler: ComplexParsingWrapper
    ) -> Namespace:
        """
        Parse ``args`` again after registering the unrecognized ``--flag`` options from ``extra``.

        Registration only happens if the selected sub-command accepts arbitrary flags. If it does not, the plain
        re-parse lets argparse report the unrecognized arguments.

        :param args: The full argument list.
        :param extra: The arguments argparse could not match on the first pass.
        :param arg_handler: The wrapper of the selected sub-command.
        :return: The namespace from the second parse.
        """
        if not arg_handler.accepts_extra_flags:
            return self._parser.parse_args(args)

        added = set()
        for op in extra:
            if not op.startswith("--"):
                continue
            name = op.split("=")[0]
            # Skip a bare "--" and flags already registered in this pass. A flag given more than once would
            # otherwise be added twice, and argparse raises a "conflicting option string" error.
            if len(name) <= 2 or name in added:
                continue
            added.add(name)
            action = arg_handler.parser.add_argument(
                name, type=str, nargs="+", metavar="Unknown"
            )
            # Key the corrector by argparse's dest ("--my-flag" -> "my_flag") so it matches the namespace entry.
            arg_handler.correctors[action.dest] = _yaml_typecaster(lambda a: a)
        return self._parser.parse_args(args)

    def __call__(self, arg_list: List[str]) -> Any:
        """
        Run the command line interface of the constructed CLI program.

        :param arg_list: A list of arguments passed by the user from the command line, excluding the program name.
                         Equivalent to `sys.argv[1:]`.
        :return: The sub-command's return value, or None if no sub-command was selected or its arguments failed
                 to parse (a CLIError, which is printed together with the usage string).
        """
        try:
            res, extra = self._parser.parse_known_args(arg_list)
        except TypeError as e:
            # Python 3.7 argparse doesn't handle subcommand namespaces correctly when no arguments are passed to them
            # (throws type error), we insert an empty string argument and reparse to get a more helpful error message
            # and force argparse to print the usage string...
            if str(e) != "sequence item 0: expected str instance, NoneType found":
                raise
            res, extra = self._parser.parse_known_args([*arg_list, ""])

        func = getattr(res, "_func", None)
        if func is None:
            return self._parser.print_usage()

        if extra:
            # Attempt to reparse after adding the extra arguments in
            # (if this is a function that accepts arbitrary flags)...
            res = self._reparse(arg_list, extra, func)
        del res._func

        try:
            return func(res)
        except CLIError as e:
            print(e)
            self._parser.print_usage()
            return None
def build_full_parser(
    function_tree: dict, parent_parser: ArgumentParser, name: Optional[str] = None
) -> CLIEngine:
    """
    Build an entire CLI interface with subcommands from a tree of typecaster functions.

    :param function_tree: A nested dictionary of strings to type caster functions. Strings specify sub command words
                          that each type caster function should be referenced by. Keys starting with "_" are not
                          turned into commands. Inside a nested dictionary, keys of the form "__<kwarg>" are passed
                          to ``add_parser`` as keyword arguments (a "__description" also becomes the help text).
    :param parent_parser: The argument parser to add commands to, or parser for the entire program.
    :param name: Name of the program. Defaults to the parent parser's ``prog``.
    :return: A CLIEngine, which represents a command line program.
    """
    if name is None:
        name = parent_parser.prog
    parent_parser.allow_abbrev = False
    sub_commands = parent_parser.add_subparsers(
        title=f"Subcommands and namespaces of '{name}'", required=True
    )

    for command_name, sub_actions in function_tree.items():
        if command_name.startswith("_"):
            continue

        if not isinstance(sub_actions, dict):
            # Leaf: a single typecaster function becomes a command.
            doc_str = inspect.getdoc(sub_actions)
            if doc_str is None:
                sub_parser = sub_commands.add_parser(command_name)
            else:
                summary = get_summary_from_doc_str(doc_str)
                sub_parser = sub_commands.add_parser(
                    command_name, description=summary, help=summary
                )
            func_to_command(sub_actions, sub_parser)
            continue

        # Namespace: recurse into the nested dictionary.
        parser_kwargs = {
            key[2:]: value for key, value in sub_actions.items() if key.startswith("__")
        }
        if "description" in parser_kwargs:
            parser_kwargs["help"] = parser_kwargs["description"]
        sub_parser = sub_commands.add_parser(command_name, **parser_kwargs)
        build_full_parser(sub_actions, sub_parser, " ".join((name, command_name)))

    return CLIEngine(parent_parser)
def clear_extra_cli_args_and_copy(func: Callable):
    """
    Create a copy of a typecaster function, with all CLI settings cleared.

    The copy loses the ``__extra_args``, ``__auto_cast`` and ``__clean_doc__`` attributes. If an original doc
    string was saved in ``__orig_doc__``, it is restored as the copy's ``__doc__``. The passed function itself is
    left untouched.

    :param func: The typecaster function to copy.
    :return: The cleaned copy.
    """
    import copy
    import types

    if isinstance(func, types.FunctionType):
        # copy.copy() returns plain functions unchanged (it treats them as immutable), so the deletes below used to
        # strip the settings from the caller's function too. Build a genuinely new function object instead.
        new_func = types.FunctionType(
            func.__code__,
            func.__globals__,
            func.__name__,
            func.__defaults__,
            func.__closure__,
        )
        new_func.__kwdefaults__ = copy.copy(func.__kwdefaults__)
        new_func.__qualname__ = func.__qualname__
        new_func.__module__ = func.__module__
        new_func.__doc__ = func.__doc__
        new_func.__annotations__ = dict(func.__annotations__)
        # Function attributes (where the CLI settings live) go into the new function's own __dict__, so deleting
        # them below does not affect the original.
        new_func.__dict__.update(func.__dict__)
        func = new_func
    else:
        func = copy.copy(func)

    if hasattr(func, "__extra_args"):
        del func.__extra_args
    if hasattr(func, "__auto_cast"):
        del func.__auto_cast
    if hasattr(func, "__orig_doc__"):
        func.__doc__ = func.__orig_doc__
        del func.__orig_doc__
    if hasattr(func, "__clean_doc__"):
        del func.__clean_doc__
    return func
def allow_arbitrary_flags(func: Callable) -> Callable:
    """
    Decorator: Allow arbitrary CLI flags on a typecaster function.

    Additional CLI flags will be passed to the wildcard keyword argument.

    :param func: The typecaster function to mark.
    :return: The same function object, marked in place.
    """
    setattr(func, "__allow_arbitrary_flags", True)
    return func
def positional_argument_count(amt: int) -> Callable[[Callable], Callable]:
    """
    Decorator factory: Mark the first ``amt`` arguments of a typecaster function as positional.

    Those arguments get no flag and must instead be passed by position to the CLI.

    :param amt: The number of leading function arguments to mark as positional.
    :return: A decorator that records the count on the function and returns the same function object.
    """

    def decorator(target: Callable) -> Callable:
        setattr(target, "__pos_cmd_arg_count", amt)
        return target

    return decorator