diff --git a/mycli/clistyle.py b/mycli/clistyle.py
index 11ae5948..8c89ddf8 100644
--- a/mycli/clistyle.py
+++ b/mycli/clistyle.py
@@ -1,18 +1,18 @@
-# type: ignore
+from __future__ import annotations
 
 import logging
 
-from prompt_toolkit.styles import Style, merge_styles
+from prompt_toolkit.styles import BaseStyle, Style, merge_styles
 from prompt_toolkit.styles.pygments import style_from_pygments_cls
 from pygments.style import Style as PygmentsStyle
 import pygments.styles
-from pygments.token import Token, string_to_tokentype
+from pygments.token import Token, _TokenType, string_to_tokentype
 from pygments.util import ClassNotFound
 
 logger = logging.getLogger(__name__)
 
 # map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
-TOKEN_TO_PROMPT_STYLE = {
+TOKEN_TO_PROMPT_STYLE: dict[_TokenType, str] = {
     Token.Menu.Completions.Completion.Current: "completion-menu.completion.current",
     Token.Menu.Completions.Completion: "completion-menu.completion",
     Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current",
@@ -42,10 +42,10 @@
 }
 
 # reverse dict for cli_helpers, because they still expect Pygments tokens.
-PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()}
+PROMPT_STYLE_TO_TOKEN: dict[str, _TokenType] = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()}
 
 # all tokens that the Pygments MySQL lexer can produce
-OVERRIDE_STYLE_TO_TOKEN = {
+OVERRIDE_STYLE_TO_TOKEN: dict[str, _TokenType] = {
     "sql.comment": Token.Comment,
     "sql.comment.multi-line": Token.Comment.Multiline,
     "sql.comment.single-line": Token.Comment.Single,
@@ -76,7 +76,11 @@
 }
 
 
-def parse_pygments_style(token_name, style_object, style_dict):
+def parse_pygments_style(
+    token_name: str,
+    style_object: type[PygmentsStyle] | dict[_TokenType, str],
+    style_dict: dict[str, str],
+) -> tuple[_TokenType, str]:
     """Parse token type and style string.
 
     :param token_name: str name of Pygments token. Example: "Token.String"
@@ -85,20 +89,23 @@ def parse_pygments_style(token_name, style_object, style_dict):
 
     """
     token_type = string_to_tokentype(token_name)
-    try:
+    if isinstance(style_object, type) and issubclass(style_object, PygmentsStyle):
+        # get_style_by_name() returns a Style *class*, not an instance, so an
+        # isinstance(style_object, PygmentsStyle) test would never match here.
         other_token_type = string_to_tokentype(style_dict[token_name])
         return token_type, style_object.styles[other_token_type]
-    except AttributeError:
-        return token_type, style_dict[token_name]
+    # Anything without a Style class's "styles" mapping (e.g. a plain styles
+    # dict) uses the configured value verbatim.
+    return token_type, style_dict[token_name]
 
 
-def style_factory(name, cli_style):
+def style_factory(name: str, cli_style: dict[str, str]) -> BaseStyle:
     try:
-        style = pygments.styles.get_style_by_name(name)
+        style: type[PygmentsStyle] = pygments.styles.get_style_by_name(name)
     except ClassNotFound:
         style = pygments.styles.get_style_by_name("native")
 
-    prompt_styles = []
+    prompt_styles: list[tuple[str, str]] = []
     # prompt-toolkit used pygments tokens for styling before, switched to style
     # names in 2.0. Convert old token types to new style names, for backwards compatibility.
     for token in cli_style:
@@ -120,9 +127,9 @@ def style_factory(name, cli_style):
     return merge_styles([style_from_pygments_cls(style), override_style, Style(prompt_styles)])
 
 
-def style_factory_output(name, cli_style):
+def style_factory_output(name: str, cli_style: dict[str, str]) -> type[PygmentsStyle]:
     try:
-        style = pygments.styles.get_style_by_name(name).styles
+        style: dict[_TokenType, str] = pygments.styles.get_style_by_name(name).styles
     except ClassNotFound:
         style = pygments.styles.get_style_by_name("native").styles
 