Uploaded Test files
This commit is contained in:
parent
f584ad9d97
commit
2e81cb7d99
16627 changed files with 2065359 additions and 102444 deletions
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
from .system import SystemCompleter
|
||||
|
||||
__all__ = ["SystemCompleter"]
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,62 @@
|
|||
from prompt_toolkit.completion.filesystem import ExecutableCompleter, PathCompleter
|
||||
from prompt_toolkit.contrib.regular_languages.compiler import compile
|
||||
from prompt_toolkit.contrib.regular_languages.completion import GrammarCompleter
|
||||
|
||||
__all__ = [
|
||||
"SystemCompleter",
|
||||
]
|
||||
|
||||
|
||||
class SystemCompleter(GrammarCompleter):
|
||||
"""
|
||||
Completer for system commands.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Compile grammar.
|
||||
g = compile(
|
||||
r"""
|
||||
# First we have an executable.
|
||||
(?P<executable>[^\s]+)
|
||||
|
||||
# Ignore literals in between.
|
||||
(
|
||||
\s+
|
||||
("[^"]*" | '[^']*' | [^'"]+ )
|
||||
)*
|
||||
|
||||
\s+
|
||||
|
||||
# Filename as parameters.
|
||||
(
|
||||
(?P<filename>[^\s]+) |
|
||||
"(?P<double_quoted_filename>[^\s]+)" |
|
||||
'(?P<single_quoted_filename>[^\s]+)'
|
||||
)
|
||||
""",
|
||||
escape_funcs={
|
||||
"double_quoted_filename": (lambda string: string.replace('"', '\\"')),
|
||||
"single_quoted_filename": (lambda string: string.replace("'", "\\'")),
|
||||
},
|
||||
unescape_funcs={
|
||||
"double_quoted_filename": (
|
||||
lambda string: string.replace('\\"', '"')
|
||||
), # XXX: not entirely correct.
|
||||
"single_quoted_filename": (lambda string: string.replace("\\'", "'")),
|
||||
},
|
||||
)
|
||||
|
||||
# Create GrammarCompleter
|
||||
super().__init__(
|
||||
g,
|
||||
{
|
||||
"executable": ExecutableCompleter(),
|
||||
"filename": PathCompleter(only_directories=False, expanduser=True),
|
||||
"double_quoted_filename": PathCompleter(
|
||||
only_directories=False, expanduser=True
|
||||
),
|
||||
"single_quoted_filename": PathCompleter(
|
||||
only_directories=False, expanduser=True
|
||||
),
|
||||
},
|
||||
)
|
|
@ -0,0 +1,77 @@
|
|||
r"""
|
||||
Tool for expressing the grammar of an input as a regular language.
|
||||
==================================================================
|
||||
|
||||
The grammar for the input of many simple command line interfaces can be
|
||||
expressed by a regular language. Examples are PDB (the Python debugger); a
|
||||
simple (bash-like) shell with "pwd", "cd", "cat" and "ls" commands; arguments
|
||||
that you can pass to an executable; etc. It is possible to use regular
|
||||
expressions for validation and parsing of such a grammar. (More about regular
|
||||
languages: http://en.wikipedia.org/wiki/Regular_language)
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
Let's take the pwd/cd/cat/ls example. We want to have a shell that accepts
|
||||
these three commands. "cd" is followed by a quoted directory name and "cat" is
|
||||
followed by a quoted file name. (We allow quotes inside the filename when
|
||||
they're escaped with a backslash.) We could define the grammar using the
|
||||
following regular expression::
|
||||
|
||||
grammar = \s* (
|
||||
pwd |
|
||||
ls |
|
||||
(cd \s+ " ([^"]|\.)+ ") |
|
||||
(cat \s+ " ([^"]|\.)+ ")
|
||||
) \s*
|
||||
|
||||
|
||||
What can we do with this grammar?
|
||||
---------------------------------
|
||||
|
||||
- Syntax highlighting: We could use this for instance to give file names
|
||||
different colour.
|
||||
- Parse the result: .. We can extract the file names and commands by using a
|
||||
regular expression with named groups.
|
||||
- Input validation: .. Don't accept anything that does not match this grammar.
|
||||
When combined with a parser, we can also recursively do
|
||||
filename validation (and accept only existing files.)
|
||||
- Autocompletion: .... Each part of the grammar can have its own autocompleter.
|
||||
"cat" has to be completed using file names, while "cd"
|
||||
has to be completed using directory names.
|
||||
|
||||
How does it work?
|
||||
-----------------
|
||||
|
||||
As a user of this library, you have to define the grammar of the input as a
|
||||
regular expression. The parts of this grammar where autocompletion, validation
|
||||
or any other processing is required need to be marked using a regex named
|
||||
group. Like ``(?P<varname>...)`` for instance.
|
||||
|
||||
When the input is processed for validation (for instance), the regex will
|
||||
execute, the named group is captured, and the validator associated with this
|
||||
named group will test the captured string.
|
||||
|
||||
There is one tricky bit:
|
||||
|
||||
Often we operate on incomplete input (this is by definition the case for
|
||||
autocompletion) and we have to decide for the cursor position in which
|
||||
possible state the grammar it could be and in which way variables could be
|
||||
matched up to that point.
|
||||
|
||||
To solve this problem, the compiler takes the original regular expression and
|
||||
translates it into a set of other regular expressions which each match certain
|
||||
prefixes of the original regular expression. We generate one prefix regular
|
||||
expression for every named variable (with this variable being the end of that
|
||||
expression).
|
||||
|
||||
|
||||
TODO: some examples of:
|
||||
- How to create a highlighter from this grammar.
|
||||
- How to create a validator from this grammar.
|
||||
- How to create an autocompleter from this grammar.
|
||||
- How to create a parser from this grammar.
|
||||
"""
|
||||
from .compiler import compile
|
||||
|
||||
__all__ = ["compile"]
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,573 @@
|
|||
r"""
|
||||
Compiler for a regular grammar.
|
||||
|
||||
Example usage::
|
||||
|
||||
# Create and compile grammar.
|
||||
p = compile('add \s+ (?P<var1>[^\s]+) \s+ (?P<var2>[^\s]+)')
|
||||
|
||||
# Match input string.
|
||||
m = p.match('add 23 432')
|
||||
|
||||
# Get variables.
|
||||
m.variables().get('var1') # Returns "23"
|
||||
m.variables().get('var2') # Returns "432"
|
||||
|
||||
|
||||
Partial matches are possible::
|
||||
|
||||
# Create and compile grammar.
|
||||
p = compile('''
|
||||
# Operators with two arguments.
|
||||
((?P<operator1>[^\s]+) \s+ (?P<var1>[^\s]+) \s+ (?P<var2>[^\s]+)) |
|
||||
|
||||
# Operators with only one arguments.
|
||||
((?P<operator2>[^\s]+) \s+ (?P<var1>[^\s]+))
|
||||
''')
|
||||
|
||||
# Match partial input string.
|
||||
m = p.match_prefix('add 23')
|
||||
|
||||
# Get variables. (Notice that both operator1 and operator2 contain the
|
||||
# value "add".) This is because our input is incomplete, and we don't know
|
||||
# yet in which rule of the regex we we'll end up. It could also be that
|
||||
# `operator1` and `operator2` have a different autocompleter and we want to
|
||||
# call all possible autocompleters that would result in valid input.)
|
||||
m.variables().get('var1') # Returns "23"
|
||||
m.variables().get('operator1') # Returns "add"
|
||||
m.variables().get('operator2') # Returns "add"
|
||||
|
||||
"""
|
||||
import re
|
||||
from typing import Callable, Dict, Iterable, Iterator, List
|
||||
from typing import Match as RegexMatch
|
||||
from typing import Optional, Pattern, Tuple, cast
|
||||
|
||||
from .regex_parser import (
|
||||
AnyNode,
|
||||
Lookahead,
|
||||
Node,
|
||||
NodeSequence,
|
||||
Regex,
|
||||
Repeat,
|
||||
Variable,
|
||||
parse_regex,
|
||||
tokenize_regex,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"compile",
|
||||
]
|
||||
|
||||
|
||||
# Name of the named group in the regex, matching trailing input.
|
||||
# (Trailing input is when the input contains characters after the end of the
|
||||
# expression has been matched.)
|
||||
_INVALID_TRAILING_INPUT = "invalid_trailing"
|
||||
|
||||
EscapeFuncDict = Dict[str, Callable[[str], str]]
|
||||
|
||||
|
||||
class _CompiledGrammar:
|
||||
"""
|
||||
Compiles a grammar. This will take the parse tree of a regular expression
|
||||
and compile the grammar.
|
||||
|
||||
:param root_node: :class~`.regex_parser.Node` instance.
|
||||
:param escape_funcs: `dict` mapping variable names to escape callables.
|
||||
:param unescape_funcs: `dict` mapping variable names to unescape callables.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_node: Node,
|
||||
escape_funcs: Optional[EscapeFuncDict] = None,
|
||||
unescape_funcs: Optional[EscapeFuncDict] = None,
|
||||
) -> None:
|
||||
|
||||
self.root_node = root_node
|
||||
self.escape_funcs = escape_funcs or {}
|
||||
self.unescape_funcs = unescape_funcs or {}
|
||||
|
||||
#: Dictionary that will map the regex names to Node instances.
|
||||
self._group_names_to_nodes: Dict[
|
||||
str, str
|
||||
] = {} # Maps regex group names to varnames.
|
||||
counter = [0]
|
||||
|
||||
def create_group_func(node: Variable) -> str:
|
||||
name = "n%s" % counter[0]
|
||||
self._group_names_to_nodes[name] = node.varname
|
||||
counter[0] += 1
|
||||
return name
|
||||
|
||||
# Compile regex strings.
|
||||
self._re_pattern = "^%s$" % self._transform(root_node, create_group_func)
|
||||
self._re_prefix_patterns = list(
|
||||
self._transform_prefix(root_node, create_group_func)
|
||||
)
|
||||
|
||||
# Compile the regex itself.
|
||||
flags = re.DOTALL # Note that we don't need re.MULTILINE! (^ and $
|
||||
# still represent the start and end of input text.)
|
||||
self._re = re.compile(self._re_pattern, flags)
|
||||
self._re_prefix = [re.compile(t, flags) for t in self._re_prefix_patterns]
|
||||
|
||||
# We compile one more set of regexes, similar to `_re_prefix`, but accept any trailing
|
||||
# input. This will ensure that we can still highlight the input correctly, even when the
|
||||
# input contains some additional characters at the end that don't match the grammar.)
|
||||
self._re_prefix_with_trailing_input = [
|
||||
re.compile(
|
||||
r"(?:%s)(?P<%s>.*?)$" % (t.rstrip("$"), _INVALID_TRAILING_INPUT), flags
|
||||
)
|
||||
for t in self._re_prefix_patterns
|
||||
]
|
||||
|
||||
def escape(self, varname: str, value: str) -> str:
|
||||
"""
|
||||
Escape `value` to fit in the place of this variable into the grammar.
|
||||
"""
|
||||
f = self.escape_funcs.get(varname)
|
||||
return f(value) if f else value
|
||||
|
||||
def unescape(self, varname: str, value: str) -> str:
|
||||
"""
|
||||
Unescape `value`.
|
||||
"""
|
||||
f = self.unescape_funcs.get(varname)
|
||||
return f(value) if f else value
|
||||
|
||||
@classmethod
|
||||
def _transform(
|
||||
cls, root_node: Node, create_group_func: Callable[[Variable], str]
|
||||
) -> str:
|
||||
"""
|
||||
Turn a :class:`Node` object into a regular expression.
|
||||
|
||||
:param root_node: The :class:`Node` instance for which we generate the grammar.
|
||||
:param create_group_func: A callable which takes a `Node` and returns the next
|
||||
free name for this node.
|
||||
"""
|
||||
|
||||
def transform(node: Node) -> str:
|
||||
# Turn `AnyNode` into an OR.
|
||||
if isinstance(node, AnyNode):
|
||||
return "(?:%s)" % "|".join(transform(c) for c in node.children)
|
||||
|
||||
# Concatenate a `NodeSequence`
|
||||
elif isinstance(node, NodeSequence):
|
||||
return "".join(transform(c) for c in node.children)
|
||||
|
||||
# For Regex and Lookahead nodes, just insert them literally.
|
||||
elif isinstance(node, Regex):
|
||||
return node.regex
|
||||
|
||||
elif isinstance(node, Lookahead):
|
||||
before = "(?!" if node.negative else "(="
|
||||
return before + transform(node.childnode) + ")"
|
||||
|
||||
# A `Variable` wraps the children into a named group.
|
||||
elif isinstance(node, Variable):
|
||||
return "(?P<%s>%s)" % (
|
||||
create_group_func(node),
|
||||
transform(node.childnode),
|
||||
)
|
||||
|
||||
# `Repeat`.
|
||||
elif isinstance(node, Repeat):
|
||||
if node.max_repeat is None:
|
||||
if node.min_repeat == 0:
|
||||
repeat_sign = "*"
|
||||
elif node.min_repeat == 1:
|
||||
repeat_sign = "+"
|
||||
else:
|
||||
repeat_sign = "{%i,%s}" % (
|
||||
node.min_repeat,
|
||||
("" if node.max_repeat is None else str(node.max_repeat)),
|
||||
)
|
||||
|
||||
return "(?:%s)%s%s" % (
|
||||
transform(node.childnode),
|
||||
repeat_sign,
|
||||
("" if node.greedy else "?"),
|
||||
)
|
||||
else:
|
||||
raise TypeError("Got %r" % (node,))
|
||||
|
||||
return transform(root_node)
|
||||
|
||||
@classmethod
|
||||
def _transform_prefix(
|
||||
cls, root_node: Node, create_group_func: Callable[[Variable], str]
|
||||
) -> Iterable[str]:
|
||||
"""
|
||||
Yield all the regular expressions matching a prefix of the grammar
|
||||
defined by the `Node` instance.
|
||||
|
||||
For each `Variable`, one regex pattern will be generated, with this
|
||||
named group at the end. This is required because a regex engine will
|
||||
terminate once a match is found. For autocompletion however, we need
|
||||
the matches for all possible paths, so that we can provide completions
|
||||
for each `Variable`.
|
||||
|
||||
- So, in the case of an `Any` (`A|B|C)', we generate a pattern for each
|
||||
clause. This is one for `A`, one for `B` and one for `C`. Unless some
|
||||
groups don't contain a `Variable`, then these can be merged together.
|
||||
- In the case of a `NodeSequence` (`ABC`), we generate a pattern for
|
||||
each prefix that ends with a variable, and one pattern for the whole
|
||||
sequence. So, that's one for `A`, one for `AB` and one for `ABC`.
|
||||
|
||||
:param root_node: The :class:`Node` instance for which we generate the grammar.
|
||||
:param create_group_func: A callable which takes a `Node` and returns the next
|
||||
free name for this node.
|
||||
"""
|
||||
|
||||
def contains_variable(node: Node) -> bool:
|
||||
if isinstance(node, Regex):
|
||||
return False
|
||||
elif isinstance(node, Variable):
|
||||
return True
|
||||
elif isinstance(node, (Lookahead, Repeat)):
|
||||
return contains_variable(node.childnode)
|
||||
elif isinstance(node, (NodeSequence, AnyNode)):
|
||||
return any(contains_variable(child) for child in node.children)
|
||||
|
||||
return False
|
||||
|
||||
def transform(node: Node) -> Iterable[str]:
|
||||
# Generate separate pattern for all terms that contain variables
|
||||
# within this OR. Terms that don't contain a variable can be merged
|
||||
# together in one pattern.
|
||||
if isinstance(node, AnyNode):
|
||||
# If we have a definition like:
|
||||
# (?P<name> .*) | (?P<city> .*)
|
||||
# Then we want to be able to generate completions for both the
|
||||
# name as well as the city. We do this by yielding two
|
||||
# different regular expressions, because the engine won't
|
||||
# follow multiple paths, if multiple are possible.
|
||||
children_with_variable = []
|
||||
children_without_variable = []
|
||||
for c in node.children:
|
||||
if contains_variable(c):
|
||||
children_with_variable.append(c)
|
||||
else:
|
||||
children_without_variable.append(c)
|
||||
|
||||
for c in children_with_variable:
|
||||
yield from transform(c)
|
||||
|
||||
# Merge options without variable together.
|
||||
if children_without_variable:
|
||||
yield "|".join(
|
||||
r for c in children_without_variable for r in transform(c)
|
||||
)
|
||||
|
||||
# For a sequence, generate a pattern for each prefix that ends with
|
||||
# a variable + one pattern of the complete sequence.
|
||||
# (This is because, for autocompletion, we match the text before
|
||||
# the cursor, and completions are given for the variable that we
|
||||
# match right before the cursor.)
|
||||
elif isinstance(node, NodeSequence):
|
||||
# For all components in the sequence, compute prefix patterns,
|
||||
# as well as full patterns.
|
||||
complete = [cls._transform(c, create_group_func) for c in node.children]
|
||||
prefixes = [list(transform(c)) for c in node.children]
|
||||
variable_nodes = [contains_variable(c) for c in node.children]
|
||||
|
||||
# If any child is contains a variable, we should yield a
|
||||
# pattern up to that point, so that we are sure this will be
|
||||
# matched.
|
||||
for i in range(len(node.children)):
|
||||
if variable_nodes[i]:
|
||||
for c_str in prefixes[i]:
|
||||
yield "".join(complete[:i]) + c_str
|
||||
|
||||
# If there are non-variable nodes, merge all the prefixes into
|
||||
# one pattern. If the input is: "[part1] [part2] [part3]", then
|
||||
# this gets compiled into:
|
||||
# (complete1 + (complete2 + (complete3 | partial3) | partial2) | partial1 )
|
||||
# For nodes that contain a variable, we skip the "|partial"
|
||||
# part here, because thees are matched with the previous
|
||||
# patterns.
|
||||
if not all(variable_nodes):
|
||||
result = []
|
||||
|
||||
# Start with complete patterns.
|
||||
for i in range(len(node.children)):
|
||||
result.append("(?:")
|
||||
result.append(complete[i])
|
||||
|
||||
# Add prefix patterns.
|
||||
for i in range(len(node.children) - 1, -1, -1):
|
||||
if variable_nodes[i]:
|
||||
# No need to yield a prefix for this one, we did
|
||||
# the variable prefixes earlier.
|
||||
result.append(")")
|
||||
else:
|
||||
result.append("|(?:")
|
||||
# If this yields multiple, we should yield all combinations.
|
||||
assert len(prefixes[i]) == 1
|
||||
result.append(prefixes[i][0])
|
||||
result.append("))")
|
||||
|
||||
yield "".join(result)
|
||||
|
||||
elif isinstance(node, Regex):
|
||||
yield "(?:%s)?" % node.regex
|
||||
|
||||
elif isinstance(node, Lookahead):
|
||||
if node.negative:
|
||||
yield "(?!%s)" % cls._transform(node.childnode, create_group_func)
|
||||
else:
|
||||
# Not sure what the correct semantics are in this case.
|
||||
# (Probably it's not worth implementing this.)
|
||||
raise Exception("Positive lookahead not yet supported.")
|
||||
|
||||
elif isinstance(node, Variable):
|
||||
# (Note that we should not append a '?' here. the 'transform'
|
||||
# method will already recursively do that.)
|
||||
for c_str in transform(node.childnode):
|
||||
yield "(?P<%s>%s)" % (create_group_func(node), c_str)
|
||||
|
||||
elif isinstance(node, Repeat):
|
||||
# If we have a repetition of 8 times. That would mean that the
|
||||
# current input could have for instance 7 times a complete
|
||||
# match, followed by a partial match.
|
||||
prefix = cls._transform(node.childnode, create_group_func)
|
||||
|
||||
if node.max_repeat == 1:
|
||||
yield from transform(node.childnode)
|
||||
else:
|
||||
for c_str in transform(node.childnode):
|
||||
if node.max_repeat:
|
||||
repeat_sign = "{,%i}" % (node.max_repeat - 1)
|
||||
else:
|
||||
repeat_sign = "*"
|
||||
yield "(?:%s)%s%s%s" % (
|
||||
prefix,
|
||||
repeat_sign,
|
||||
("" if node.greedy else "?"),
|
||||
c_str,
|
||||
)
|
||||
|
||||
else:
|
||||
raise TypeError("Got %r" % node)
|
||||
|
||||
for r in transform(root_node):
|
||||
yield "^(?:%s)$" % r
|
||||
|
||||
def match(self, string: str) -> Optional["Match"]:
|
||||
"""
|
||||
Match the string with the grammar.
|
||||
Returns a :class:`Match` instance or `None` when the input doesn't match the grammar.
|
||||
|
||||
:param string: The input string.
|
||||
"""
|
||||
m = self._re.match(string)
|
||||
|
||||
if m:
|
||||
return Match(
|
||||
string, [(self._re, m)], self._group_names_to_nodes, self.unescape_funcs
|
||||
)
|
||||
return None
|
||||
|
||||
def match_prefix(self, string: str) -> Optional["Match"]:
|
||||
"""
|
||||
Do a partial match of the string with the grammar. The returned
|
||||
:class:`Match` instance can contain multiple representations of the
|
||||
match. This will never return `None`. If it doesn't match at all, the "trailing input"
|
||||
part will capture all of the input.
|
||||
|
||||
:param string: The input string.
|
||||
"""
|
||||
# First try to match using `_re_prefix`. If nothing is found, use the patterns that
|
||||
# also accept trailing characters.
|
||||
for patterns in [self._re_prefix, self._re_prefix_with_trailing_input]:
|
||||
matches = [(r, r.match(string)) for r in patterns]
|
||||
matches2 = [(r, m) for r, m in matches if m]
|
||||
|
||||
if matches2 != []:
|
||||
return Match(
|
||||
string, matches2, self._group_names_to_nodes, self.unescape_funcs
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class Match:
|
||||
"""
|
||||
:param string: The input string.
|
||||
:param re_matches: List of (compiled_re_pattern, re_match) tuples.
|
||||
:param group_names_to_nodes: Dictionary mapping all the re group names to the matching Node instances.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
string: str,
|
||||
re_matches: List[Tuple[Pattern[str], RegexMatch[str]]],
|
||||
group_names_to_nodes: Dict[str, str],
|
||||
unescape_funcs: Dict[str, Callable[[str], str]],
|
||||
):
|
||||
self.string = string
|
||||
self._re_matches = re_matches
|
||||
self._group_names_to_nodes = group_names_to_nodes
|
||||
self._unescape_funcs = unescape_funcs
|
||||
|
||||
def _nodes_to_regs(self) -> List[Tuple[str, Tuple[int, int]]]:
|
||||
"""
|
||||
Return a list of (varname, reg) tuples.
|
||||
"""
|
||||
|
||||
def get_tuples() -> Iterable[Tuple[str, Tuple[int, int]]]:
|
||||
for r, re_match in self._re_matches:
|
||||
for group_name, group_index in r.groupindex.items():
|
||||
if group_name != _INVALID_TRAILING_INPUT:
|
||||
regs = cast(Tuple[Tuple[int, int], ...], re_match.regs)
|
||||
reg = regs[group_index]
|
||||
node = self._group_names_to_nodes[group_name]
|
||||
yield (node, reg)
|
||||
|
||||
return list(get_tuples())
|
||||
|
||||
def _nodes_to_values(self) -> List[Tuple[str, str, Tuple[int, int]]]:
|
||||
"""
|
||||
Returns list of (Node, string_value) tuples.
|
||||
"""
|
||||
|
||||
def is_none(sl: Tuple[int, int]) -> bool:
|
||||
return sl[0] == -1 and sl[1] == -1
|
||||
|
||||
def get(sl: Tuple[int, int]) -> str:
|
||||
return self.string[sl[0] : sl[1]]
|
||||
|
||||
return [
|
||||
(varname, get(slice), slice)
|
||||
for varname, slice in self._nodes_to_regs()
|
||||
if not is_none(slice)
|
||||
]
|
||||
|
||||
def _unescape(self, varname: str, value: str) -> str:
|
||||
unwrapper = self._unescape_funcs.get(varname)
|
||||
return unwrapper(value) if unwrapper else value
|
||||
|
||||
def variables(self) -> "Variables":
|
||||
"""
|
||||
Returns :class:`Variables` instance.
|
||||
"""
|
||||
return Variables(
|
||||
[(k, self._unescape(k, v), sl) for k, v, sl in self._nodes_to_values()]
|
||||
)
|
||||
|
||||
def trailing_input(self) -> Optional["MatchVariable"]:
|
||||
"""
|
||||
Get the `MatchVariable` instance, representing trailing input, if there is any.
|
||||
"Trailing input" is input at the end that does not match the grammar anymore, but
|
||||
when this is removed from the end of the input, the input would be a valid string.
|
||||
"""
|
||||
slices: List[Tuple[int, int]] = []
|
||||
|
||||
# Find all regex group for the name _INVALID_TRAILING_INPUT.
|
||||
for r, re_match in self._re_matches:
|
||||
for group_name, group_index in r.groupindex.items():
|
||||
if group_name == _INVALID_TRAILING_INPUT:
|
||||
slices.append(re_match.regs[group_index])
|
||||
|
||||
# Take the smallest part. (Smaller trailing text means that a larger input has
|
||||
# been matched, so that is better.)
|
||||
if slices:
|
||||
slice = (max(i[0] for i in slices), max(i[1] for i in slices))
|
||||
value = self.string[slice[0] : slice[1]]
|
||||
return MatchVariable("<trailing_input>", value, slice)
|
||||
return None
|
||||
|
||||
def end_nodes(self) -> Iterable["MatchVariable"]:
|
||||
"""
|
||||
Yields `MatchVariable` instances for all the nodes having their end
|
||||
position at the end of the input string.
|
||||
"""
|
||||
for varname, reg in self._nodes_to_regs():
|
||||
# If this part goes until the end of the input string.
|
||||
if reg[1] == len(self.string):
|
||||
value = self._unescape(varname, self.string[reg[0] : reg[1]])
|
||||
yield MatchVariable(varname, value, (reg[0], reg[1]))
|
||||
|
||||
|
||||
class Variables:
|
||||
def __init__(self, tuples: List[Tuple[str, str, Tuple[int, int]]]) -> None:
|
||||
#: List of (varname, value, slice) tuples.
|
||||
self._tuples = tuples
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "%s(%s)" % (
|
||||
self.__class__.__name__,
|
||||
", ".join("%s=%r" % (k, v) for k, v, _ in self._tuples),
|
||||
)
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
items = self.getall(key)
|
||||
return items[0] if items else default
|
||||
|
||||
def getall(self, key: str) -> List[str]:
|
||||
return [v for k, v, _ in self._tuples if k == key]
|
||||
|
||||
def __getitem__(self, key: str) -> Optional[str]:
|
||||
return self.get(key)
|
||||
|
||||
def __iter__(self) -> Iterator["MatchVariable"]:
|
||||
"""
|
||||
Yield `MatchVariable` instances.
|
||||
"""
|
||||
for varname, value, slice in self._tuples:
|
||||
yield MatchVariable(varname, value, slice)
|
||||
|
||||
|
||||
class MatchVariable:
|
||||
"""
|
||||
Represents a match of a variable in the grammar.
|
||||
|
||||
:param varname: (string) Name of the variable.
|
||||
:param value: (string) Value of this variable.
|
||||
:param slice: (start, stop) tuple, indicating the position of this variable
|
||||
in the input string.
|
||||
"""
|
||||
|
||||
def __init__(self, varname: str, value: str, slice: Tuple[int, int]) -> None:
|
||||
self.varname = varname
|
||||
self.value = value
|
||||
self.slice = slice
|
||||
|
||||
self.start = self.slice[0]
|
||||
self.stop = self.slice[1]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "%s(%r, %r)" % (self.__class__.__name__, self.varname, self.value)
|
||||
|
||||
|
||||
def compile(
|
||||
expression: str,
|
||||
escape_funcs: Optional[EscapeFuncDict] = None,
|
||||
unescape_funcs: Optional[EscapeFuncDict] = None,
|
||||
) -> _CompiledGrammar:
|
||||
"""
|
||||
Compile grammar (given as regex string), returning a `CompiledGrammar`
|
||||
instance.
|
||||
"""
|
||||
return _compile_from_parse_tree(
|
||||
parse_regex(tokenize_regex(expression)),
|
||||
escape_funcs=escape_funcs,
|
||||
unescape_funcs=unescape_funcs,
|
||||
)
|
||||
|
||||
|
||||
def _compile_from_parse_tree(
|
||||
root_node: Node,
|
||||
escape_funcs: Optional[EscapeFuncDict] = None,
|
||||
unescape_funcs: Optional[EscapeFuncDict] = None,
|
||||
) -> _CompiledGrammar:
|
||||
"""
|
||||
Compile grammar (given as parse tree), returning a `CompiledGrammar`
|
||||
instance.
|
||||
"""
|
||||
return _CompiledGrammar(
|
||||
root_node, escape_funcs=escape_funcs, unescape_funcs=unescape_funcs
|
||||
)
|
|
@ -0,0 +1,94 @@
|
|||
"""
|
||||
Completer for a regular grammar.
|
||||
"""
|
||||
from typing import Dict, Iterable, List
|
||||
|
||||
from prompt_toolkit.completion import CompleteEvent, Completer, Completion
|
||||
from prompt_toolkit.document import Document
|
||||
|
||||
from .compiler import Match, _CompiledGrammar
|
||||
|
||||
__all__ = [
|
||||
"GrammarCompleter",
|
||||
]
|
||||
|
||||
|
||||
class GrammarCompleter(Completer):
|
||||
"""
|
||||
Completer which can be used for autocompletion according to variables in
|
||||
the grammar. Each variable can have a different autocompleter.
|
||||
|
||||
:param compiled_grammar: `GrammarCompleter` instance.
|
||||
:param completers: `dict` mapping variable names of the grammar to the
|
||||
`Completer` instances to be used for each variable.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, compiled_grammar: _CompiledGrammar, completers: Dict[str, Completer]
|
||||
) -> None:
|
||||
|
||||
self.compiled_grammar = compiled_grammar
|
||||
self.completers = completers
|
||||
|
||||
def get_completions(
|
||||
self, document: Document, complete_event: CompleteEvent
|
||||
) -> Iterable[Completion]:
|
||||
m = self.compiled_grammar.match_prefix(document.text_before_cursor)
|
||||
|
||||
if m:
|
||||
completions = self._remove_duplicates(
|
||||
self._get_completions_for_match(m, complete_event)
|
||||
)
|
||||
|
||||
for c in completions:
|
||||
yield c
|
||||
|
||||
def _get_completions_for_match(
|
||||
self, match: Match, complete_event: CompleteEvent
|
||||
) -> Iterable[Completion]:
|
||||
"""
|
||||
Yield all the possible completions for this input string.
|
||||
(The completer assumes that the cursor position was at the end of the
|
||||
input string.)
|
||||
"""
|
||||
for match_variable in match.end_nodes():
|
||||
varname = match_variable.varname
|
||||
start = match_variable.start
|
||||
|
||||
completer = self.completers.get(varname)
|
||||
|
||||
if completer:
|
||||
text = match_variable.value
|
||||
|
||||
# Unwrap text.
|
||||
unwrapped_text = self.compiled_grammar.unescape(varname, text)
|
||||
|
||||
# Create a document, for the completions API (text/cursor_position)
|
||||
document = Document(unwrapped_text, len(unwrapped_text))
|
||||
|
||||
# Call completer
|
||||
for completion in completer.get_completions(document, complete_event):
|
||||
new_text = (
|
||||
unwrapped_text[: len(text) + completion.start_position]
|
||||
+ completion.text
|
||||
)
|
||||
|
||||
# Wrap again.
|
||||
yield Completion(
|
||||
text=self.compiled_grammar.escape(varname, new_text),
|
||||
start_position=start - len(match.string),
|
||||
display=completion.display,
|
||||
display_meta=completion.display_meta,
|
||||
)
|
||||
|
||||
def _remove_duplicates(self, items: Iterable[Completion]) -> List[Completion]:
|
||||
"""
|
||||
Remove duplicates, while keeping the order.
|
||||
(Sometimes we have duplicates, because the there several matches of the
|
||||
same grammar, each yielding similar completions.)
|
||||
"""
|
||||
result: List[Completion] = []
|
||||
for i in items:
|
||||
if i not in result:
|
||||
result.append(i)
|
||||
return result
|
|
@ -0,0 +1,92 @@
|
|||
"""
|
||||
`GrammarLexer` is compatible with other lexers and can be used to highlight
|
||||
the input using a regular grammar with annotations.
|
||||
"""
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from prompt_toolkit.document import Document
|
||||
from prompt_toolkit.formatted_text.base import StyleAndTextTuples
|
||||
from prompt_toolkit.formatted_text.utils import split_lines
|
||||
from prompt_toolkit.lexers import Lexer
|
||||
|
||||
from .compiler import _CompiledGrammar
|
||||
|
||||
__all__ = [
|
||||
"GrammarLexer",
|
||||
]
|
||||
|
||||
|
||||
class GrammarLexer(Lexer):
|
||||
"""
|
||||
Lexer which can be used for highlighting of fragments according to variables in the grammar.
|
||||
|
||||
(It does not actual lexing of the string, but it exposes an API, compatible
|
||||
with the Pygments lexer class.)
|
||||
|
||||
:param compiled_grammar: Grammar as returned by the `compile()` function.
|
||||
:param lexers: Dictionary mapping variable names of the regular grammar to
|
||||
the lexers that should be used for this part. (This can
|
||||
call other lexers recursively.) If you wish a part of the
|
||||
grammar to just get one fragment, use a
|
||||
`prompt_toolkit.lexers.SimpleLexer`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
compiled_grammar: _CompiledGrammar,
|
||||
default_style: str = "",
|
||||
lexers: Optional[Dict[str, Lexer]] = None,
|
||||
) -> None:
|
||||
|
||||
self.compiled_grammar = compiled_grammar
|
||||
self.default_style = default_style
|
||||
self.lexers = lexers or {}
|
||||
|
||||
def _get_text_fragments(self, text: str) -> StyleAndTextTuples:
|
||||
m = self.compiled_grammar.match_prefix(text)
|
||||
|
||||
if m:
|
||||
characters: StyleAndTextTuples = [(self.default_style, c) for c in text]
|
||||
|
||||
for v in m.variables():
|
||||
# If we have a `Lexer` instance for this part of the input.
|
||||
# Tokenize recursively and apply tokens.
|
||||
lexer = self.lexers.get(v.varname)
|
||||
|
||||
if lexer:
|
||||
document = Document(text[v.start : v.stop])
|
||||
lexer_tokens_for_line = lexer.lex_document(document)
|
||||
text_fragments: StyleAndTextTuples = []
|
||||
for i in range(len(document.lines)):
|
||||
text_fragments.extend(lexer_tokens_for_line(i))
|
||||
text_fragments.append(("", "\n"))
|
||||
if text_fragments:
|
||||
text_fragments.pop()
|
||||
|
||||
i = v.start
|
||||
for t, s, *_ in text_fragments:
|
||||
for c in s:
|
||||
if characters[i][0] == self.default_style:
|
||||
characters[i] = (t, characters[i][1])
|
||||
i += 1
|
||||
|
||||
# Highlight trailing input.
|
||||
trailing_input = m.trailing_input()
|
||||
if trailing_input:
|
||||
for i in range(trailing_input.start, trailing_input.stop):
|
||||
characters[i] = ("class:trailing-input", characters[i][1])
|
||||
|
||||
return characters
|
||||
else:
|
||||
return [("", text)]
|
||||
|
||||
def lex_document(self, document: Document) -> Callable[[int], StyleAndTextTuples]:
|
||||
lines = list(split_lines(self._get_text_fragments(document.text)))
|
||||
|
||||
def get_line(lineno: int) -> StyleAndTextTuples:
|
||||
try:
|
||||
return lines[lineno]
|
||||
except IndexError:
|
||||
return []
|
||||
|
||||
return get_line
|
|
@ -0,0 +1,281 @@
|
|||
"""
|
||||
Parser for parsing a regular expression.
|
||||
Take a string representing a regular expression and return the root node of its
|
||||
parse tree.
|
||||
|
||||
usage::
|
||||
|
||||
root_node = parse_regex('(hello|world)')
|
||||
|
||||
Remarks:
|
||||
- The regex parser processes multiline, it ignores all whitespace and supports
|
||||
multiple named groups with the same name and #-style comments.
|
||||
|
||||
Limitations:
|
||||
- Lookahead is not supported.
|
||||
"""
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
__all__ = [
|
||||
"Repeat",
|
||||
"Variable",
|
||||
"Regex",
|
||||
"Lookahead",
|
||||
"tokenize_regex",
|
||||
"parse_regex",
|
||||
]
|
||||
|
||||
|
||||
class Node:
|
||||
"""
|
||||
Base class for all the grammar nodes.
|
||||
(You don't initialize this one.)
|
||||
"""
|
||||
|
||||
def __add__(self, other_node: "Node") -> "NodeSequence":
|
||||
return NodeSequence([self, other_node])
|
||||
|
||||
def __or__(self, other_node: "Node") -> "AnyNode":
|
||||
return AnyNode([self, other_node])
|
||||
|
||||
|
||||
class AnyNode(Node):
|
||||
"""
|
||||
Union operation (OR operation) between several grammars. You don't
|
||||
initialize this yourself, but it's a result of a "Grammar1 | Grammar2"
|
||||
operation.
|
||||
"""
|
||||
|
||||
def __init__(self, children: List[Node]) -> None:
|
||||
self.children = children
|
||||
|
||||
def __or__(self, other_node: Node) -> "AnyNode":
|
||||
return AnyNode(self.children + [other_node])
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "%s(%r)" % (self.__class__.__name__, self.children)
|
||||
|
||||
|
||||
class NodeSequence(Node):
|
||||
"""
|
||||
Concatenation operation of several grammars. You don't initialize this
|
||||
yourself, but it's a result of a "Grammar1 + Grammar2" operation.
|
||||
"""
|
||||
|
||||
def __init__(self, children: List[Node]) -> None:
|
||||
self.children = children
|
||||
|
||||
def __add__(self, other_node: Node) -> "NodeSequence":
|
||||
return NodeSequence(self.children + [other_node])
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "%s(%r)" % (self.__class__.__name__, self.children)
|
||||
|
||||
|
||||
class Regex(Node):
|
||||
"""
|
||||
Regular expression.
|
||||
"""
|
||||
|
||||
def __init__(self, regex: str) -> None:
|
||||
re.compile(regex) # Validate
|
||||
|
||||
self.regex = regex
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "%s(/%s/)" % (self.__class__.__name__, self.regex)
|
||||
|
||||
|
||||
class Lookahead(Node):
|
||||
"""
|
||||
Lookahead expression.
|
||||
"""
|
||||
|
||||
def __init__(self, childnode: Node, negative: bool = False) -> None:
|
||||
self.childnode = childnode
|
||||
self.negative = negative
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "%s(%r)" % (self.__class__.__name__, self.childnode)
|
||||
|
||||
|
||||
class Variable(Node):
|
||||
"""
|
||||
Mark a variable in the regular grammar. This will be translated into a
|
||||
named group. Each variable can have his own completer, validator, etc..
|
||||
|
||||
:param childnode: The grammar which is wrapped inside this variable.
|
||||
:param varname: String.
|
||||
"""
|
||||
|
||||
def __init__(self, childnode: Node, varname: str = "") -> None:
|
||||
self.childnode = childnode
|
||||
self.varname = varname
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "%s(childnode=%r, varname=%r)" % (
|
||||
self.__class__.__name__,
|
||||
self.childnode,
|
||||
self.varname,
|
||||
)
|
||||
|
||||
|
||||
class Repeat(Node):
|
||||
def __init__(
|
||||
self,
|
||||
childnode: Node,
|
||||
min_repeat: int = 0,
|
||||
max_repeat: Optional[int] = None,
|
||||
greedy: bool = True,
|
||||
) -> None:
|
||||
self.childnode = childnode
|
||||
self.min_repeat = min_repeat
|
||||
self.max_repeat = max_repeat
|
||||
self.greedy = greedy
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "%s(childnode=%r)" % (self.__class__.__name__, self.childnode)
|
||||
|
||||
|
||||
def tokenize_regex(input: str) -> List[str]:
|
||||
"""
|
||||
Takes a string, representing a regular expression as input, and tokenizes
|
||||
it.
|
||||
|
||||
:param input: string, representing a regular expression.
|
||||
:returns: List of tokens.
|
||||
"""
|
||||
# Regular expression for tokenizing other regular expressions.
|
||||
p = re.compile(
|
||||
r"""^(
|
||||
\(\?P\<[a-zA-Z0-9_-]+\> | # Start of named group.
|
||||
\(\?#[^)]*\) | # Comment
|
||||
\(\?= | # Start of lookahead assertion
|
||||
\(\?! | # Start of negative lookahead assertion
|
||||
\(\?<= | # If preceded by.
|
||||
\(\?< | # If not preceded by.
|
||||
\(?: | # Start of group. (non capturing.)
|
||||
\( | # Start of group.
|
||||
\(?[iLmsux] | # Flags.
|
||||
\(?P=[a-zA-Z]+\) | # Back reference to named group
|
||||
\) | # End of group.
|
||||
\{[^{}]*\} | # Repetition
|
||||
\*\? | \+\? | \?\?\ | # Non greedy repetition.
|
||||
\* | \+ | \? | # Repetition
|
||||
\#.*\n | # Comment
|
||||
\\. |
|
||||
|
||||
# Character group.
|
||||
\[
|
||||
( [^\]\\] | \\.)*
|
||||
\] |
|
||||
|
||||
[^(){}] |
|
||||
.
|
||||
)""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
|
||||
tokens = []
|
||||
|
||||
while input:
|
||||
m = p.match(input)
|
||||
if m:
|
||||
token, input = input[: m.end()], input[m.end() :]
|
||||
if not token.isspace():
|
||||
tokens.append(token)
|
||||
else:
|
||||
raise Exception("Could not tokenize input regex.")
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
def parse_regex(regex_tokens: List[str]) -> Node:
|
||||
"""
|
||||
Takes a list of tokens from the tokenizer, and returns a parse tree.
|
||||
"""
|
||||
# We add a closing brace because that represents the final pop of the stack.
|
||||
tokens: List[str] = [")"] + regex_tokens[::-1]
|
||||
|
||||
def wrap(lst: List[Node]) -> Node:
|
||||
""" Turn list into sequence when it contains several items. """
|
||||
if len(lst) == 1:
|
||||
return lst[0]
|
||||
else:
|
||||
return NodeSequence(lst)
|
||||
|
||||
def _parse() -> Node:
|
||||
or_list: List[List[Node]] = []
|
||||
result: List[Node] = []
|
||||
|
||||
def wrapped_result() -> Node:
|
||||
if or_list == []:
|
||||
return wrap(result)
|
||||
else:
|
||||
or_list.append(result)
|
||||
return AnyNode([wrap(i) for i in or_list])
|
||||
|
||||
while tokens:
|
||||
t = tokens.pop()
|
||||
|
||||
if t.startswith("(?P<"):
|
||||
variable = Variable(_parse(), varname=t[4:-1])
|
||||
result.append(variable)
|
||||
|
||||
elif t in ("*", "*?"):
|
||||
greedy = t == "*"
|
||||
result[-1] = Repeat(result[-1], greedy=greedy)
|
||||
|
||||
elif t in ("+", "+?"):
|
||||
greedy = t == "+"
|
||||
result[-1] = Repeat(result[-1], min_repeat=1, greedy=greedy)
|
||||
|
||||
elif t in ("?", "??"):
|
||||
if result == []:
|
||||
raise Exception("Nothing to repeat." + repr(tokens))
|
||||
else:
|
||||
greedy = t == "?"
|
||||
result[-1] = Repeat(
|
||||
result[-1], min_repeat=0, max_repeat=1, greedy=greedy
|
||||
)
|
||||
|
||||
elif t == "|":
|
||||
or_list.append(result)
|
||||
result = []
|
||||
|
||||
elif t in ("(", "(?:"):
|
||||
result.append(_parse())
|
||||
|
||||
elif t == "(?!":
|
||||
result.append(Lookahead(_parse(), negative=True))
|
||||
|
||||
elif t == "(?=":
|
||||
result.append(Lookahead(_parse(), negative=False))
|
||||
|
||||
elif t == ")":
|
||||
return wrapped_result()
|
||||
|
||||
elif t.startswith("#"):
|
||||
pass
|
||||
|
||||
elif t.startswith("{"):
|
||||
# TODO: implement!
|
||||
raise Exception("{}-style repetition not yet supported".format(t))
|
||||
|
||||
elif t.startswith("(?"):
|
||||
raise Exception("%r not supported" % t)
|
||||
|
||||
elif t.isspace():
|
||||
pass
|
||||
else:
|
||||
result.append(Regex(t))
|
||||
|
||||
raise Exception("Expecting ')' token")
|
||||
|
||||
result = _parse()
|
||||
|
||||
if len(tokens) != 0:
|
||||
raise Exception("Unmatched parentheses.")
|
||||
else:
|
||||
return result
|
|
@ -0,0 +1,60 @@
|
|||
"""
|
||||
Validator for a regular language.
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from prompt_toolkit.document import Document
|
||||
from prompt_toolkit.validation import ValidationError, Validator
|
||||
|
||||
from .compiler import _CompiledGrammar
|
||||
|
||||
__all__ = [
|
||||
"GrammarValidator",
|
||||
]
|
||||
|
||||
|
||||
class GrammarValidator(Validator):
|
||||
"""
|
||||
Validator which can be used for validation according to variables in
|
||||
the grammar. Each variable can have its own validator.
|
||||
|
||||
:param compiled_grammar: `GrammarCompleter` instance.
|
||||
:param validators: `dict` mapping variable names of the grammar to the
|
||||
`Validator` instances to be used for each variable.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, compiled_grammar: _CompiledGrammar, validators: Dict[str, Validator]
|
||||
) -> None:
|
||||
|
||||
self.compiled_grammar = compiled_grammar
|
||||
self.validators = validators
|
||||
|
||||
def validate(self, document: Document) -> None:
|
||||
# Parse input document.
|
||||
# We use `match`, not `match_prefix`, because for validation, we want
|
||||
# the actual, unambiguous interpretation of the input.
|
||||
m = self.compiled_grammar.match(document.text)
|
||||
|
||||
if m:
|
||||
for v in m.variables():
|
||||
validator = self.validators.get(v.varname)
|
||||
|
||||
if validator:
|
||||
# Unescape text.
|
||||
unwrapped_text = self.compiled_grammar.unescape(v.varname, v.value)
|
||||
|
||||
# Create a document, for the completions API (text/cursor_position)
|
||||
inner_document = Document(unwrapped_text, len(unwrapped_text))
|
||||
|
||||
try:
|
||||
validator.validate(inner_document)
|
||||
except ValidationError as e:
|
||||
raise ValidationError(
|
||||
cursor_position=v.start + e.cursor_position,
|
||||
message=e.message,
|
||||
) from e
|
||||
else:
|
||||
raise ValidationError(
|
||||
cursor_position=len(document.text), message="Invalid command"
|
||||
)
|
|
@ -0,0 +1,6 @@
|
|||
from .server import PromptToolkitSSHServer, PromptToolkitSSHSession
|
||||
|
||||
__all__ = [
|
||||
"PromptToolkitSSHSession",
|
||||
"PromptToolkitSSHServer",
|
||||
]
|
Binary file not shown.
Binary file not shown.
151
venv/Lib/site-packages/prompt_toolkit/contrib/ssh/server.py
Normal file
151
venv/Lib/site-packages/prompt_toolkit/contrib/ssh/server.py
Normal file
|
@ -0,0 +1,151 @@
|
|||
"""
|
||||
Utility for running a prompt_toolkit application in an asyncssh server.
|
||||
"""
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Awaitable, Callable, Optional, TextIO, cast
|
||||
|
||||
import asyncssh
|
||||
|
||||
from prompt_toolkit.application.current import AppSession, create_app_session
|
||||
from prompt_toolkit.data_structures import Size
|
||||
from prompt_toolkit.input import create_pipe_input
|
||||
from prompt_toolkit.output.vt100 import Vt100_Output
|
||||
|
||||
__all__ = ["PromptToolkitSSHSession", "PromptToolkitSSHServer"]
|
||||
|
||||
|
||||
class PromptToolkitSSHSession(asyncssh.SSHServerSession):
|
||||
def __init__(
|
||||
self, interact: Callable[["PromptToolkitSSHSession"], Awaitable[None]]
|
||||
) -> None:
|
||||
self.interact = interact
|
||||
self.interact_task: Optional[asyncio.Task[None]] = None
|
||||
self._chan = None
|
||||
self.app_session: Optional[AppSession] = None
|
||||
|
||||
# PipInput object, for sending input in the CLI.
|
||||
# (This is something that we can use in the prompt_toolkit event loop,
|
||||
# but still write date in manually.)
|
||||
self._input = create_pipe_input()
|
||||
self._output = None
|
||||
|
||||
# Output object. Don't render to the real stdout, but write everything
|
||||
# in the SSH channel.
|
||||
class Stdout:
|
||||
def write(s, data):
|
||||
try:
|
||||
if self._chan is not None:
|
||||
self._chan.write(data.replace("\n", "\r\n"))
|
||||
except BrokenPipeError:
|
||||
pass # Channel not open for sending.
|
||||
|
||||
def isatty(s) -> bool:
|
||||
return True
|
||||
|
||||
def flush(s):
|
||||
pass
|
||||
|
||||
@property
|
||||
def encoding(s):
|
||||
return self._chan._orig_chan.get_encoding()[0]
|
||||
|
||||
self.stdout = cast(TextIO, Stdout())
|
||||
|
||||
def _get_size(self) -> Size:
|
||||
"""
|
||||
Callable that returns the current `Size`, required by Vt100_Output.
|
||||
"""
|
||||
if self._chan is None:
|
||||
return Size(rows=20, columns=79)
|
||||
else:
|
||||
width, height, pixwidth, pixheight = self._chan.get_terminal_size()
|
||||
return Size(rows=height, columns=width)
|
||||
|
||||
def connection_made(self, chan):
|
||||
self._chan = chan
|
||||
|
||||
def shell_requested(self) -> bool:
|
||||
return True
|
||||
|
||||
def session_started(self) -> None:
|
||||
self.interact_task = asyncio.get_event_loop().create_task(self._interact())
|
||||
|
||||
async def _interact(self) -> None:
|
||||
if self._chan is None:
|
||||
# Should not happen.
|
||||
raise Exception("`_interact` called before `connection_made`.")
|
||||
|
||||
if hasattr(self._chan, "set_line_mode") and self._chan._editor is not None:
|
||||
# Disable the line editing provided by asyncssh. Prompt_toolkit
|
||||
# provides the line editing.
|
||||
self._chan.set_line_mode(False)
|
||||
|
||||
term = self._chan.get_terminal_type()
|
||||
|
||||
self._output = Vt100_Output(
|
||||
self.stdout, self._get_size, term=term, write_binary=False
|
||||
)
|
||||
with create_app_session(input=self._input, output=self._output) as session:
|
||||
self.app_session = session
|
||||
try:
|
||||
await self.interact(self)
|
||||
except BaseException:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Close the connection.
|
||||
self._chan.close()
|
||||
|
||||
def terminal_size_changed(self, width, height, pixwidth, pixheight):
|
||||
# Send resize event to the current application.
|
||||
if self.app_session and self.app_session.app:
|
||||
self.app_session.app._on_resize()
|
||||
|
||||
def data_received(self, data, datatype):
|
||||
self._input.send_text(data)
|
||||
|
||||
|
||||
class PromptToolkitSSHServer(asyncssh.SSHServer):
|
||||
"""
|
||||
Run a prompt_toolkit application over an asyncssh server.
|
||||
|
||||
This takes one argument, an `interact` function, which is called for each
|
||||
connection. This should be an asynchronous function that runs the
|
||||
prompt_toolkit applications. This function runs in an `AppSession`, which
|
||||
means that we can have multiple UI interactions concurrently.
|
||||
|
||||
Example usage:
|
||||
|
||||
.. code:: python
|
||||
|
||||
async def interact(ssh_session: PromptToolkitSSHSession) -> None:
|
||||
await yes_no_dialog("my title", "my text").run_async()
|
||||
|
||||
prompt_session = PromptSession()
|
||||
text = await prompt_session.prompt_async("Type something: ")
|
||||
print_formatted_text('You said: ', text)
|
||||
|
||||
server = PromptToolkitSSHServer(interact=interact)
|
||||
loop = get_event_loop()
|
||||
loop.run_until_complete(
|
||||
asyncssh.create_server(
|
||||
lambda: MySSHServer(interact),
|
||||
"",
|
||||
port,
|
||||
server_host_keys=["/etc/ssh/..."],
|
||||
)
|
||||
)
|
||||
loop.run_forever()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, interact: Callable[[PromptToolkitSSHSession], Awaitable[None]]
|
||||
) -> None:
|
||||
self.interact = interact
|
||||
|
||||
def begin_auth(self, username):
|
||||
# No authentication.
|
||||
return False
|
||||
|
||||
def session_requested(self) -> PromptToolkitSSHSession:
|
||||
return PromptToolkitSSHSession(self.interact)
|
|
@ -0,0 +1,5 @@
|
|||
from .server import TelnetServer
|
||||
|
||||
__all__ = [
|
||||
"TelnetServer",
|
||||
]
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
10
venv/Lib/site-packages/prompt_toolkit/contrib/telnet/log.py
Normal file
10
venv/Lib/site-packages/prompt_toolkit/contrib/telnet/log.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
"""
|
||||
Python logger for the telnet server.
|
||||
"""
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__package__)
|
||||
|
||||
__all__ = [
|
||||
"logger",
|
||||
]
|
207
venv/Lib/site-packages/prompt_toolkit/contrib/telnet/protocol.py
Normal file
207
venv/Lib/site-packages/prompt_toolkit/contrib/telnet/protocol.py
Normal file
|
@ -0,0 +1,207 @@
|
|||
"""
|
||||
Parser for the Telnet protocol. (Not a complete implementation of the telnet
|
||||
specification, but sufficient for a command line interface.)
|
||||
|
||||
Inspired by `Twisted.conch.telnet`.
|
||||
"""
|
||||
import struct
|
||||
from typing import Callable, Generator
|
||||
|
||||
from .log import logger
|
||||
|
||||
__all__ = [
|
||||
"TelnetProtocolParser",
|
||||
]
|
||||
|
||||
|
||||
def int2byte(number: int) -> bytes:
|
||||
return bytes((number,))
|
||||
|
||||
|
||||
# Telnet constants.
|
||||
NOP = int2byte(0)
|
||||
SGA = int2byte(3)
|
||||
|
||||
IAC = int2byte(255)
|
||||
DO = int2byte(253)
|
||||
DONT = int2byte(254)
|
||||
LINEMODE = int2byte(34)
|
||||
SB = int2byte(250)
|
||||
WILL = int2byte(251)
|
||||
WONT = int2byte(252)
|
||||
MODE = int2byte(1)
|
||||
SE = int2byte(240)
|
||||
ECHO = int2byte(1)
|
||||
NAWS = int2byte(31)
|
||||
LINEMODE = int2byte(34)
|
||||
SUPPRESS_GO_AHEAD = int2byte(3)
|
||||
|
||||
TTYPE = int2byte(24)
|
||||
SEND = int2byte(1)
|
||||
IS = int2byte(0)
|
||||
|
||||
DM = int2byte(242)
|
||||
BRK = int2byte(243)
|
||||
IP = int2byte(244)
|
||||
AO = int2byte(245)
|
||||
AYT = int2byte(246)
|
||||
EC = int2byte(247)
|
||||
EL = int2byte(248)
|
||||
GA = int2byte(249)
|
||||
|
||||
|
||||
class TelnetProtocolParser:
|
||||
"""
|
||||
Parser for the Telnet protocol.
|
||||
Usage::
|
||||
|
||||
def data_received(data):
|
||||
print(data)
|
||||
|
||||
def size_received(rows, columns):
|
||||
print(rows, columns)
|
||||
|
||||
p = TelnetProtocolParser(data_received, size_received)
|
||||
p.feed(binary_data)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_received_callback: Callable[[bytes], None],
|
||||
size_received_callback: Callable[[int, int], None],
|
||||
ttype_received_callback: Callable[[str], None],
|
||||
) -> None:
|
||||
|
||||
self.data_received_callback = data_received_callback
|
||||
self.size_received_callback = size_received_callback
|
||||
self.ttype_received_callback = ttype_received_callback
|
||||
|
||||
self._parser = self._parse_coroutine()
|
||||
self._parser.send(None) # type: ignore
|
||||
|
||||
def received_data(self, data: bytes) -> None:
|
||||
self.data_received_callback(data)
|
||||
|
||||
def do_received(self, data: bytes) -> None:
|
||||
""" Received telnet DO command. """
|
||||
logger.info("DO %r", data)
|
||||
|
||||
def dont_received(self, data: bytes) -> None:
|
||||
""" Received telnet DONT command. """
|
||||
logger.info("DONT %r", data)
|
||||
|
||||
def will_received(self, data: bytes) -> None:
|
||||
""" Received telnet WILL command. """
|
||||
logger.info("WILL %r", data)
|
||||
|
||||
def wont_received(self, data: bytes) -> None:
|
||||
""" Received telnet WONT command. """
|
||||
logger.info("WONT %r", data)
|
||||
|
||||
def command_received(self, command: bytes, data: bytes) -> None:
|
||||
if command == DO:
|
||||
self.do_received(data)
|
||||
|
||||
elif command == DONT:
|
||||
self.dont_received(data)
|
||||
|
||||
elif command == WILL:
|
||||
self.will_received(data)
|
||||
|
||||
elif command == WONT:
|
||||
self.wont_received(data)
|
||||
|
||||
else:
|
||||
logger.info("command received %r %r", command, data)
|
||||
|
||||
def naws(self, data: bytes) -> None:
|
||||
"""
|
||||
Received NAWS. (Window dimensions.)
|
||||
"""
|
||||
if len(data) == 4:
|
||||
# NOTE: the first parameter of struct.unpack should be
|
||||
# a 'str' object. Both on Py2/py3. This crashes on OSX
|
||||
# otherwise.
|
||||
columns, rows = struct.unpack(str("!HH"), data)
|
||||
self.size_received_callback(rows, columns)
|
||||
else:
|
||||
logger.warning("Wrong number of NAWS bytes")
|
||||
|
||||
def ttype(self, data: bytes) -> None:
|
||||
"""
|
||||
Received terminal type.
|
||||
"""
|
||||
subcmd, data = data[0:1], data[1:]
|
||||
if subcmd == IS:
|
||||
ttype = data.decode("ascii")
|
||||
self.ttype_received_callback(ttype)
|
||||
else:
|
||||
logger.warning("Received a non-IS terminal type Subnegotiation")
|
||||
|
||||
def negotiate(self, data: bytes) -> None:
|
||||
"""
|
||||
Got negotiate data.
|
||||
"""
|
||||
command, payload = data[0:1], data[1:]
|
||||
|
||||
if command == NAWS:
|
||||
self.naws(payload)
|
||||
elif command == TTYPE:
|
||||
self.ttype(payload)
|
||||
else:
|
||||
logger.info("Negotiate (%r got bytes)", len(data))
|
||||
|
||||
def _parse_coroutine(self) -> Generator[None, bytes, None]:
|
||||
"""
|
||||
Parser state machine.
|
||||
Every 'yield' expression returns the next byte.
|
||||
"""
|
||||
while True:
|
||||
d = yield
|
||||
|
||||
if d == int2byte(0):
|
||||
pass # NOP
|
||||
|
||||
# Go to state escaped.
|
||||
elif d == IAC:
|
||||
d2 = yield
|
||||
|
||||
if d2 == IAC:
|
||||
self.received_data(d2)
|
||||
|
||||
# Handle simple commands.
|
||||
elif d2 in (NOP, DM, BRK, IP, AO, AYT, EC, EL, GA):
|
||||
self.command_received(d2, b"")
|
||||
|
||||
# Handle IAC-[DO/DONT/WILL/WONT] commands.
|
||||
elif d2 in (DO, DONT, WILL, WONT):
|
||||
d3 = yield
|
||||
self.command_received(d2, d3)
|
||||
|
||||
# Subnegotiation
|
||||
elif d2 == SB:
|
||||
# Consume everything until next IAC-SE
|
||||
data = []
|
||||
|
||||
while True:
|
||||
d3 = yield
|
||||
|
||||
if d3 == IAC:
|
||||
d4 = yield
|
||||
if d4 == SE:
|
||||
break
|
||||
else:
|
||||
data.append(d4)
|
||||
else:
|
||||
data.append(d3)
|
||||
|
||||
self.negotiate(b"".join(data))
|
||||
else:
|
||||
self.received_data(d)
|
||||
|
||||
def feed(self, data: bytes) -> None:
|
||||
"""
|
||||
Feed data to the parser.
|
||||
"""
|
||||
for b in data:
|
||||
self._parser.send(int2byte(b))
|
349
venv/Lib/site-packages/prompt_toolkit/contrib/telnet/server.py
Normal file
349
venv/Lib/site-packages/prompt_toolkit/contrib/telnet/server.py
Normal file
|
@ -0,0 +1,349 @@
|
|||
"""
|
||||
Telnet server.
|
||||
"""
|
||||
import asyncio
|
||||
import contextvars # Requires Python3.7!
|
||||
import socket
|
||||
from asyncio import get_event_loop
|
||||
from typing import Awaitable, Callable, List, Optional, Set, TextIO, Tuple, cast
|
||||
|
||||
from prompt_toolkit.application.current import create_app_session, get_app
|
||||
from prompt_toolkit.application.run_in_terminal import run_in_terminal
|
||||
from prompt_toolkit.data_structures import Size
|
||||
from prompt_toolkit.formatted_text import AnyFormattedText, to_formatted_text
|
||||
from prompt_toolkit.input import create_pipe_input
|
||||
from prompt_toolkit.output.vt100 import Vt100_Output
|
||||
from prompt_toolkit.renderer import print_formatted_text as print_formatted_text
|
||||
from prompt_toolkit.styles import BaseStyle, DummyStyle
|
||||
|
||||
from .log import logger
|
||||
from .protocol import (
|
||||
DO,
|
||||
ECHO,
|
||||
IAC,
|
||||
LINEMODE,
|
||||
MODE,
|
||||
NAWS,
|
||||
SB,
|
||||
SE,
|
||||
SEND,
|
||||
SUPPRESS_GO_AHEAD,
|
||||
TTYPE,
|
||||
WILL,
|
||||
TelnetProtocolParser,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TelnetServer",
|
||||
]
|
||||
|
||||
|
||||
def int2byte(number: int) -> bytes:
|
||||
return bytes((number,))
|
||||
|
||||
|
||||
def _initialize_telnet(connection: socket.socket) -> None:
|
||||
logger.info("Initializing telnet connection")
|
||||
|
||||
# Iac Do Linemode
|
||||
connection.send(IAC + DO + LINEMODE)
|
||||
|
||||
# Suppress Go Ahead. (This seems important for Putty to do correct echoing.)
|
||||
# This will allow bi-directional operation.
|
||||
connection.send(IAC + WILL + SUPPRESS_GO_AHEAD)
|
||||
|
||||
# Iac sb
|
||||
connection.send(IAC + SB + LINEMODE + MODE + int2byte(0) + IAC + SE)
|
||||
|
||||
# IAC Will Echo
|
||||
connection.send(IAC + WILL + ECHO)
|
||||
|
||||
# Negotiate window size
|
||||
connection.send(IAC + DO + NAWS)
|
||||
|
||||
# Negotiate terminal type
|
||||
# Assume the client will accept the negociation with `IAC + WILL + TTYPE`
|
||||
connection.send(IAC + DO + TTYPE)
|
||||
|
||||
# We can then select the first terminal type supported by the client,
|
||||
# which is generally the best type the client supports
|
||||
# The client should reply with a `IAC + SB + TTYPE + IS + ttype + IAC + SE`
|
||||
connection.send(IAC + SB + TTYPE + SEND + IAC + SE)
|
||||
|
||||
|
||||
class _ConnectionStdout:
|
||||
"""
|
||||
Wrapper around socket which provides `write` and `flush` methods for the
|
||||
Vt100_Output output.
|
||||
"""
|
||||
|
||||
def __init__(self, connection: socket.socket, encoding: str) -> None:
|
||||
self._encoding = encoding
|
||||
self._connection = connection
|
||||
self._errors = "strict"
|
||||
self._buffer: List[bytes] = []
|
||||
|
||||
def write(self, data: str) -> None:
|
||||
data = data.replace("\n", "\r\n")
|
||||
self._buffer.append(data.encode(self._encoding, errors=self._errors))
|
||||
self.flush()
|
||||
|
||||
def isatty(self) -> bool:
|
||||
return True
|
||||
|
||||
def flush(self) -> None:
|
||||
try:
|
||||
self._connection.send(b"".join(self._buffer))
|
||||
except socket.error as e:
|
||||
logger.warning("Couldn't send data over socket: %s" % e)
|
||||
|
||||
self._buffer = []
|
||||
|
||||
@property
|
||||
def encoding(self) -> str:
|
||||
return self._encoding
|
||||
|
||||
@property
|
||||
def errors(self) -> str:
|
||||
return self._errors
|
||||
|
||||
|
||||
class TelnetConnection:
|
||||
"""
|
||||
Class that represents one Telnet connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn: socket.socket,
|
||||
addr: Tuple[str, int],
|
||||
interact: Callable[["TelnetConnection"], Awaitable[None]],
|
||||
server: "TelnetServer",
|
||||
encoding: str,
|
||||
style: Optional[BaseStyle],
|
||||
) -> None:
|
||||
|
||||
self.conn = conn
|
||||
self.addr = addr
|
||||
self.interact = interact
|
||||
self.server = server
|
||||
self.encoding = encoding
|
||||
self.style = style
|
||||
self._closed = False
|
||||
self._ready = asyncio.Event()
|
||||
self.vt100_output = None
|
||||
|
||||
# Create "Output" object.
|
||||
self.size = Size(rows=40, columns=79)
|
||||
|
||||
# Initialize.
|
||||
_initialize_telnet(conn)
|
||||
|
||||
# Create input.
|
||||
self.vt100_input = create_pipe_input()
|
||||
|
||||
# Create output.
|
||||
def get_size() -> Size:
|
||||
return self.size
|
||||
|
||||
self.stdout = cast(TextIO, _ConnectionStdout(conn, encoding=encoding))
|
||||
|
||||
def data_received(data: bytes) -> None:
|
||||
""" TelnetProtocolParser 'data_received' callback """
|
||||
self.vt100_input.send_bytes(data)
|
||||
|
||||
def size_received(rows: int, columns: int) -> None:
|
||||
""" TelnetProtocolParser 'size_received' callback """
|
||||
self.size = Size(rows=rows, columns=columns)
|
||||
if self.vt100_output is not None:
|
||||
get_app()._on_resize()
|
||||
|
||||
def ttype_received(ttype: str) -> None:
|
||||
""" TelnetProtocolParser 'ttype_received' callback """
|
||||
self.vt100_output = Vt100_Output(
|
||||
self.stdout, get_size, term=ttype, write_binary=False
|
||||
)
|
||||
self._ready.set()
|
||||
|
||||
self.parser = TelnetProtocolParser(data_received, size_received, ttype_received)
|
||||
self.context: Optional[contextvars.Context] = None
|
||||
|
||||
async def run_application(self) -> None:
|
||||
"""
|
||||
Run application.
|
||||
"""
|
||||
|
||||
def handle_incoming_data() -> None:
|
||||
data = self.conn.recv(1024)
|
||||
if data:
|
||||
self.feed(data)
|
||||
else:
|
||||
# Connection closed by client.
|
||||
logger.info("Connection closed by client. %r %r" % self.addr)
|
||||
self.close()
|
||||
|
||||
# Add reader.
|
||||
loop = get_event_loop()
|
||||
loop.add_reader(self.conn, handle_incoming_data)
|
||||
|
||||
try:
|
||||
# Wait for v100_output to be properly instantiated
|
||||
await self._ready.wait()
|
||||
with create_app_session(input=self.vt100_input, output=self.vt100_output):
|
||||
self.context = contextvars.copy_context()
|
||||
await self.interact(self)
|
||||
except Exception as e:
|
||||
print("Got %s" % type(e).__name__, e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
def feed(self, data: bytes) -> None:
|
||||
"""
|
||||
Handler for incoming data. (Called by TelnetServer.)
|
||||
"""
|
||||
self.parser.feed(data)
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Closed by client.
|
||||
"""
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
|
||||
self.vt100_input.close()
|
||||
get_event_loop().remove_reader(self.conn)
|
||||
self.conn.close()
|
||||
|
||||
def send(self, formatted_text: AnyFormattedText) -> None:
|
||||
"""
|
||||
Send text to the client.
|
||||
"""
|
||||
if self.vt100_output is None:
|
||||
return
|
||||
formatted_text = to_formatted_text(formatted_text)
|
||||
print_formatted_text(
|
||||
self.vt100_output, formatted_text, self.style or DummyStyle()
|
||||
)
|
||||
|
||||
def send_above_prompt(self, formatted_text: AnyFormattedText) -> None:
|
||||
"""
|
||||
Send text to the client.
|
||||
This is asynchronous, returns a `Future`.
|
||||
"""
|
||||
formatted_text = to_formatted_text(formatted_text)
|
||||
return self._run_in_terminal(lambda: self.send(formatted_text))
|
||||
|
||||
def _run_in_terminal(self, func: Callable[[], None]) -> None:
|
||||
# Make sure that when an application was active for this connection,
|
||||
# that we print the text above the application.
|
||||
if self.context:
|
||||
self.context.run(run_in_terminal, func)
|
||||
else:
|
||||
raise RuntimeError("Called _run_in_terminal outside `run_application`.")
|
||||
|
||||
def erase_screen(self) -> None:
|
||||
"""
|
||||
Erase the screen and move the cursor to the top.
|
||||
"""
|
||||
if self.vt100_output is None:
|
||||
return
|
||||
self.vt100_output.erase_screen()
|
||||
self.vt100_output.cursor_goto(0, 0)
|
||||
self.vt100_output.flush()
|
||||
|
||||
|
||||
async def _dummy_interact(connection: TelnetConnection) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class TelnetServer:
|
||||
"""
|
||||
Telnet server implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 23,
|
||||
interact: Callable[[TelnetConnection], Awaitable[None]] = _dummy_interact,
|
||||
encoding: str = "utf-8",
|
||||
style: Optional[BaseStyle] = None,
|
||||
) -> None:
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.interact = interact
|
||||
self.encoding = encoding
|
||||
self.style = style
|
||||
self._application_tasks: List[asyncio.Task] = []
|
||||
|
||||
self.connections: Set[TelnetConnection] = set()
|
||||
self._listen_socket: Optional[socket.socket] = None
|
||||
|
||||
@classmethod
|
||||
def _create_socket(cls, host: str, port: int) -> socket.socket:
|
||||
# Create and bind socket
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
s.bind((host, port))
|
||||
|
||||
s.listen(4)
|
||||
return s
|
||||
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start the telnet server.
|
||||
Don't forget to call `loop.run_forever()` after doing this.
|
||||
"""
|
||||
self._listen_socket = self._create_socket(self.host, self.port)
|
||||
logger.info(
|
||||
"Listening for telnet connections on %s port %r", self.host, self.port
|
||||
)
|
||||
|
||||
get_event_loop().add_reader(self._listen_socket, self._accept)
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._listen_socket:
|
||||
get_event_loop().remove_reader(self._listen_socket)
|
||||
self._listen_socket.close()
|
||||
|
||||
# Wait for all applications to finish.
|
||||
for t in self._application_tasks:
|
||||
t.cancel()
|
||||
|
||||
for t in self._application_tasks:
|
||||
await t
|
||||
|
||||
def _accept(self) -> None:
|
||||
"""
|
||||
Accept new incoming connection.
|
||||
"""
|
||||
if self._listen_socket is None:
|
||||
return # Should not happen. `_accept` is called after `start`.
|
||||
|
||||
conn, addr = self._listen_socket.accept()
|
||||
logger.info("New connection %r %r", *addr)
|
||||
|
||||
connection = TelnetConnection(
|
||||
conn, addr, self.interact, self, encoding=self.encoding, style=self.style
|
||||
)
|
||||
self.connections.add(connection)
|
||||
|
||||
# Run application for this connection.
|
||||
async def run() -> None:
|
||||
logger.info("Starting interaction %r %r", *addr)
|
||||
try:
|
||||
await connection.run_application()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
finally:
|
||||
self.connections.remove(connection)
|
||||
self._application_tasks.remove(task)
|
||||
logger.info("Stopping interaction %r %r", *addr)
|
||||
|
||||
task = get_event_loop().create_task(run())
|
||||
self._application_tasks.append(task)
|
Loading…
Add table
Add a link
Reference in a new issue