diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 0d6b6778a..9dac19234 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,7 +15,7 @@ repos:
       # Run the linter.
       - id: ruff
         files: ^tuplex/python/tuplex.*\.py$
-        args: [ "--fix", "--select", "I" ]
+        args: [ "--fix", "--config", "ruff.toml" ]
         types_or: [ python, pyi ]
       # Run the formatter.
       - id: ruff-format
diff --git a/pyproject.toml b/pyproject.toml
index dc7fe4af5..aefc4e5dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,3 @@ requires = [
   "requests"
 ]
 build-backend = "setuptools.build_meta"
-
-
-[tool.ruff]
-include = ["pyproject.toml", "tuplex/python/tuplex/**/*.py"]
diff --git a/ruff.toml b/ruff.toml
new file mode 100644
index 000000000..9bc42fb7c
--- /dev/null
+++ b/ruff.toml
@@ -0,0 +1,8 @@
+[lint]
+# Add "B", "Q" for flake8 checks.
+select = ["I", "E4", "E7", "E9", "F", "CPY001", "T201", "T203", "ANN001", "ANN002", "ANN003", "ANN201", "ANN202", "ANN204", "ANN205", "ANN206"]
+ignore = []
+
+# Allow fix for all enabled rules (when `--fix` is provided).
+fixable = ["ALL"]
+unfixable = []
diff --git a/tuplex/python/tuplex/__init__.py b/tuplex/python/tuplex/__init__.py
index ad8d14b5e..20aa0a4c1 100644
--- a/tuplex/python/tuplex/__init__.py
+++ b/tuplex/python/tuplex/__init__.py
@@ -10,6 +10,7 @@
 # ----------------------------------------------------------------------------------------------------------------------#
 
 import logging
+from typing import Optional, Union
 
 # expose aws setup for better convenience
 import tuplex.distributed
@@ -23,7 +24,12 @@
 
 
 # for convenience create a dummy function to return a default-configured Lambda context
-def LambdaContext(conf=None, name=None, s3_scratch_dir=None, **kwargs):
+def LambdaContext(
+    conf: Union[None, str, dict] = None,
+    name: Optional[str] = None,
+    s3_scratch_dir: Optional[str] = None,
+    **kwargs: dict,
+) -> Context:
     import uuid
 
     if s3_scratch_dir is None:
diff --git a/tuplex/python/tuplex/context.py b/tuplex/python/tuplex/context.py
index 46763b72d..04e8d2e0c 100644
--- a/tuplex/python/tuplex/context.py
+++ b/tuplex/python/tuplex/context.py
@@ -21,6 +21,7 @@
 import os
 import sys
 import uuid
+from typing import Any, List, Optional, Tuple, Union
 
 from tuplex.utils.common import (
     current_user,
@@ -45,7 +46,9 @@
 
 
 class Context:
-    def __init__(self, conf=None, name="", **kwargs):
+    def __init__(
+        self, conf: Union[None, str, dict] = None, name: str = "", **kwargs: dict
+    ) -> None:
         r"""creates new Context object, the main entry point for all operations with the Tuplex big data framework
 
         Args:
@@ -235,7 +238,13 @@ def __init__(self, conf=None, name="", **kwargs):
         self.metrics = Metrics(python_metrics)
         assert self.metrics
 
-    def parallelize(self, value_list, columns=None, schema=None, auto_unpack=True):
+    def parallelize(
+        self,
+        value_list: List[Any],
+        columns: Optional[List[str]] = None,
+        schema: Optional[Union[Tuple, List]] = None,
+        auto_unpack: bool = True,
+    ) -> "DataSet":
         """passes data to the Tuplex framework. Must be a list of primitive objects (e.g. of type bool, int, float, str) or a list of (nested) tuples of these types.
@@ -273,14 +282,14 @@ def parallelize(self, value_list, columns=None, schema=None, auto_unpack=True):
 
     def csv(
         self,
-        pattern,
-        columns=None,
-        header=None,
-        delimiter=None,
-        quotechar='"',
-        null_values=[""],
-        type_hints={},
-    ):
+        pattern: str,
+        columns: Optional[List[str]] = None,
+        header: Optional[bool] = None,
+        delimiter: Optional[str] = None,
+        quotechar: str = '"',
+        null_values: List[str] = [""],
+        type_hints: dict = {},
+    ) -> "DataSet":
         """reads csv (comma separated values) files. This function may either be provided with
         parameters that help to determine the delimiter, whether a header present or what kind of quote char
         is used. Overall, CSV parsing is done according to the RFC-4180 standard
@@ -350,11 +359,11 @@ def csv(
         )
         return ds
 
-    def text(self, pattern, null_values=None):
+    def text(self, pattern: str, null_values: Optional[List[str]] = None) -> "DataSet":
         """reads text files.
 
         Args:
             pattern (str): a file glob pattern, e.g. /data/file.csv or /data/\*.csv or /\*/\*csv
-            null_values (List[str]): a list of string to interpret as None. When empty list or None, empty lines will be the empty string ''
+            null_values (List[str]): a list of strings to interpret as None. When empty list or None, empty lines will be the empty string ''
 
         Returns:
             tuplex.dataset.DataSet: A Tuplex Dataset object that allows further ETL operations
@@ -372,7 +381,7 @@ def text(self, pattern, null_values=None):
         ds._dataSet = self._context.text(pattern, null_values)
         return ds
 
-    def orc(self, pattern, columns=None):
+    def orc(self, pattern: str, columns: Optional[List[str]] = None) -> "DataSet":
         """reads orc files.
         Args:
             pattern (str): a file glob pattern, e.g. /data/file.csv or /data/\*.csv or /\*/\*csv
@@ -390,7 +399,7 @@ def orc(self, pattern, columns=None):
         ds._dataSet = self._context.orc(pattern, columns)
         return ds
 
-    def options(self, nested=False):
+    def options(self, nested: bool = False) -> dict:
         """retrieves all framework parameters as dictionary
 
         Args:
@@ -411,7 +420,7 @@ def options(self, nested=False):
         else:
             return opt
 
-    def optionsToYAML(self, file_path="config.yaml"):
+    def optionsToYAML(self, file_path: str = "config.yaml") -> None:
         """saves options as yaml file to (local) filepath
 
         Args:
@@ -420,7 +429,7 @@ def optionsToYAML(self, file_path="config.yaml"):
 
         save_conf_yaml(self.options(), file_path)
 
-    def ls(self, pattern):
+    def ls(self, pattern: str) -> List[str]:
         """ return a list of strings of all files found matching the pattern. The same pattern can be supplied to read inputs.
 
         Args:
@@ -433,7 +442,7 @@ def ls(self, pattern):
         assert self._context
         return self._context.ls(pattern)
 
-    def cp(self, pattern, target_uri):
+    def cp(self, pattern: str, target_uri: str) -> None:
         """ copies all files matching the pattern to a target uri. If more than one file is found, a folder is created
         containing all the files relative to the longest shared path prefix.
@@ -448,7 +457,7 @@ def cp(self, pattern, target_uri):
         assert self._context
         return self._context.cp(pattern, target_uri)
 
-    def rm(self, pattern):
+    def rm(self, pattern: str) -> None:
         """ removes all files matching the pattern
 
         Args:
@@ -463,7 +472,7 @@ def rm(self, pattern):
         return self._context.rm(pattern)
 
     @property
-    def uiWebURL(self):
+    def uiWebURL(self) -> str:
         """ retrieve URL of webUI if running
 
         Returns:
diff --git a/tuplex/python/tuplex/dataset.py b/tuplex/python/tuplex/dataset.py
index 27e0d37a7..e86441146 100644
--- a/tuplex/python/tuplex/dataset.py
+++ b/tuplex/python/tuplex/dataset.py
@@ -10,6 +10,7 @@
 # ----------------------------------------------------------------------------------------------------------------------#
 
 import logging
+from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
 
 import cloudpickle
 
@@ -29,10 +30,10 @@
 
 
 class DataSet:
-    def __init__(self):
+    def __init__(self) -> None:
         self._dataSet: _DataSet = None
 
-    def unique(self):
+    def unique(self) -> "DataSet":
         """removes duplicates from Dataset (out-of-order). Equivalent to a DISTINCT clause in a SQL-statement.
         Returns:
             tuplex.dataset.Dataset: A Tuplex Dataset object that allows further ETL operations.
@@ -45,7 +46,7 @@ def unique(self):
         ds._dataSet = self._dataSet.unique()
         return ds
 
-    def map(self, ftor):
+    def map(self, ftor: Callable) -> "DataSet":
         """ performs a map operation using the provided udf function over the dataset and
         returns a dataset for further processing.
 
@@ -79,7 +80,7 @@ def map(self, ftor):
         ds._dataSet = self._dataSet.map(code, cloudpickle.dumps(ftor), g)
         return ds
 
-    def filter(self, ftor):
+    def filter(self, ftor: Callable) -> "DataSet":
         """ performs a map operation using the provided udf function over the dataset and
         returns a dataset for further processing.
 
@@ -109,7 +110,7 @@ def filter(self, ftor):
         ds._dataSet = self._dataSet.filter(code, cloudpickle.dumps(ftor), g)
         return ds
 
-    def collect(self):
+    def collect(self) -> List[Any]:
         """action that generates a physical plan, processes data and collects result then as list of tuples.
 
         Returns:
@@ -121,7 +122,7 @@ def collect(self):
         )
         return self._dataSet.collect()
 
-    def take(self, nrows=5):
+    def take(self, nrows: int = 5) -> List[Any]:
         """action that generates a physical plan, processes data and collects the top results then as list of tuples.
 
         Args:
@@ -140,7 +141,7 @@ def take(self, nrows=5):
 
         return self._dataSet.take(nrows)
 
-    def show(self, nrows=None):
+    def show(self, nrows: Optional[int] = None) -> None:
         """action that generates a physical plan, processes data and prints results
         as nicely formatted ASCII table to stdout.
 
@@ -158,7 +159,7 @@ def show(self, nrows=None):
 
         self._dataSet.show(nrows)
 
-    def resolve(self, eclass, ftor):
+    def resolve(self, eclass: TypeVar, ftor: Callable) -> "DataSet":
         """Adds a resolver operator to the pipeline. The signature of ftor needs to be identical to the one of the preceding operator.
 
         Args:
@@ -197,7 +198,7 @@ def resolve(self, eclass, ftor):
         ds._dataSet = self._dataSet.resolve(ec, code, cloudpickle.dumps(ftor), g)
         return ds
 
-    def withColumn(self, column, ftor):
+    def withColumn(self, column: str, ftor: Callable) -> "DataSet":
         """appends a new column to the dataset by calling ftor over existing tuples
 
         Args:
@@ -227,7 +228,7 @@ def withColumn(self, column, ftor):
         ds._dataSet = self._dataSet.withColumn(column, code, cloudpickle.dumps(ftor), g)
         return ds
 
-    def mapColumn(self, column, ftor):
+    def mapColumn(self, column: Union[int, str], ftor: Callable) -> "DataSet":
         """maps directly one column. UDF takes as argument directly the value of the specified column and will overwrite
         that column with the result. If you need access to multiple columns, use withColumn instead.
         If the column name already exists, it will be overwritten.
 
@@ -258,7 +259,7 @@ def mapColumn(self, column, ftor):
         ds._dataSet = self._dataSet.mapColumn(column, code, cloudpickle.dumps(ftor), g)
         return ds
 
-    def selectColumns(self, columns):
+    def selectColumns(self, columns: List[Union[str, int]]) -> "DataSet":
         """selects a subset of columns as defined through columns which is a list or a single column
 
         Args:
@@ -289,7 +290,7 @@ def selectColumns(self, columns):
         ds._dataSet = self._dataSet.selectColumns(columns)
         return ds
 
-    def renameColumn(self, key, newColumnName):
+    def renameColumn(self, key: Union[str, int], newColumnName: str) -> "DataSet":
         """rename a column in dataset
         Args:
             key: str|int, old column name or (0-indexed) position.
@@ -315,7 +316,7 @@ def renameColumn(self, key, newColumnName):
             raise TypeError("key must be int or str")
         return ds
 
-    def ignore(self, eclass):
+    def ignore(self, eclass: TypeVar) -> "DataSet":
         """ignores exceptions of type eclass caused by previous operator
 
         Args:
@@ -342,7 +343,7 @@ def ignore(self, eclass):
         ds._dataSet = self._dataSet.ignore(ec)
         return ds
 
-    def cache(self, store_specialized=True):
+    def cache(self, store_specialized: bool = True) -> "DataSet":
         """materializes rows in main-memory for reuse with several pipelines. Can be also used to benchmark certain pipeline costs
 
         Args:
@@ -361,7 +362,7 @@ def cache(self, store_specialized=True):
         return ds
 
     @property
-    def columns(self):
+    def columns(self) -> Optional[List[str]]:
         """retrieve names of columns if assigned
 
         Returns:
@@ -371,7 +372,7 @@ def columns(self):
         return cols if len(cols) > 0 else None
 
     @property
-    def types(self):
+    def types(self) -> Optional[List[TypeVar]]:
         """output schema as list of type objects of the dataset. If the dataset has an error, None is returned.
 
         Returns:
@@ -381,8 +382,13 @@ def types(self):
         return types
 
     def join(
-        self, dsRight, leftKeyColumn, rightKeyColumn, prefixes=None, suffixes=None
-    ):
+        self,
+        dsRight: "DataSet",
+        leftKeyColumn: str,
+        rightKeyColumn: str,
+        prefixes: Union[None, Tuple[str, str], List[str]] = None,
+        suffixes: Union[None, Tuple[str, str], List[str]] = None,
+    ) -> "DataSet":
         """ (inner) join with other dataset
 
         Args:
@@ -434,8 +440,13 @@ def join(
         return ds
 
     def leftJoin(
-        self, dsRight, leftKeyColumn, rightKeyColumn, prefixes=None, suffixes=None
-    ):
+        self,
+        dsRight: "DataSet",
+        leftKeyColumn: str,
+        rightKeyColumn: str,
+        prefixes: Union[None, Tuple[str, str], List[str]] = None,
+        suffixes: Union[None, Tuple[str, str], List[str]] = None,
+    ) -> "DataSet":
         """ left (outer) join with other dataset
 
         Args:
@@ -488,18 +499,18 @@ def leftJoin(
 
     def tocsv(
         self,
-        path,
-        part_size=0,
-        num_rows=max_rows,
-        num_parts=0,
-        part_name_generator=None,
-        null_value=None,
-        header=True,
-    ):
+        path: str,
+        part_size: int = 0,
+        num_rows: int = max_rows,
+        num_parts: int = 0,
+        part_name_generator: Optional[Callable] = None,
+        null_value: Optional[Any] = None,
+        header: bool = True,
+    ) -> None:
         """ save dataset to one or more csv files. Triggers execution of pipeline.
 
         Args:
             path: path where to save files to
-            split_size: optional size in bytes for each part to not exceed.
+            part_size: optional size in bytes for each part to not exceed.
             num_rows: limit number of output rows
             num_parts: number of parts to split output into. The last part will be the smallest
             part_name_generator: optional name generator function to the output parts, receives an integer \
@@ -542,12 +553,12 @@ def tocsv(
 
     def toorc(
         self,
-        path,
-        part_size=0,
-        num_rows=max_rows,
-        num_parts=0,
-        part_name_generator=None,
-    ):
+        path: str,
+        part_size: int = 0,
+        num_rows: int = max_rows,
+        num_parts: int = 0,
+        part_name_generator: Optional[Callable] = None,
+    ) -> None:
         """ save dataset to one or more orc files. Triggers execution of pipeline.
 
         Args:
             path: path where to save files to
@@ -579,7 +590,9 @@ def toorc(
 
         self._dataSet.toorc(path, code, code_pickled, num_parts, part_size, num_rows)
 
-    def aggregate(self, combine, aggregate, initial_value):
+    def aggregate(
+        self, combine: Callable, aggregate: Callable, initial_value: Any
+    ) -> "DataSet":
         """ cf. aggregateByKey for details
 
         Args:
@@ -628,14 +641,20 @@ def aggregate(self, combine, aggregate, initial_value):
         )
         return ds
 
-    def aggregateByKey(self, combine, aggregate, initial_value, key_columns):
+    def aggregateByKey(
+        self,
+        combine: Callable,
+        aggregate: Callable,
+        initial_value: Any,
+        key_columns: Sequence[Union[int, str]],
+    ) -> "DataSet":
         """ An experimental aggregateByKey function similar to aggregate. There are several scenarios that do not work with this function yet
         and its performance hasn't been properly optimized either.
 
         Data is grouped by the supplied key_columns. Then, for each group a new aggregate is initialized using the initial_value,
         which can be thought of as a neutral value. The aggregate function is then called for each element and the current aggregate structure.
         It is guaranteed that the combine function is called at least once per group by applying the initial_value to the aggregate.
 
         Args:
-            combine: a UDF to combine two aggregates (results of the aggregate function or the initial_value). E.g., cobmine = lambda agg1, agg2: agg1 + agg2. The initial value should be the neutral element.
- aggregate: a UDF which produces a result by combining a value with the aggregate initialized by initial_value. E.g., aggreagte = lambda agg, value: agg + value sums up values. + combine: a UDF to combine two aggregates (results of the aggregate function or the initial_value). E.g., combine = lambda agg1, agg2: agg1 + agg2. The initial value should be the neutral element. + aggregate: a UDF which produces a result by combining a value with the aggregate initialized by initial_value. E.g., aggregate = lambda agg, value: agg + value sums up values. initial_value: a neutral initial value. key_columns: the columns to group the aggregate by, a sequence of a mix of strings or integers. If specified as a single string or number, aggregation is over a single column. Returns: @@ -685,7 +704,7 @@ def aggregateByKey(self, combine, aggregate, initial_value, key_columns): return ds @property - def exception_counts(self): + def exception_counts(self) -> dict: """ Returns: dictionary of exception class names with integer keys, i.e. the counts. Returns None diff --git a/tuplex/python/tuplex/distributed.py b/tuplex/python/tuplex/distributed.py index 096bf56a3..5e884d8f7 100644 --- a/tuplex/python/tuplex/distributed.py +++ b/tuplex/python/tuplex/distributed.py @@ -11,6 +11,8 @@ try: import boto3 + import botocore + import botocore.client import botocore.exceptions except Exception: # ignore here, because boto3 is optional @@ -24,34 +26,37 @@ import sys import threading import time +from typing import Optional, Tuple # Tuplex specific imports from tuplex.utils.common import current_user, host_name +_logger = logging.getLogger(__name__) -def current_iam_user(): + +def current_iam_user() -> str: iam = boto3.resource("iam") user = iam.CurrentUser() return user.user_name.lower() -def default_lambda_name(): +def default_lambda_name() -> str: return "tuplex-lambda-runner" -def default_lambda_role(): +def default_lambda_role() -> str: return "tuplex-lambda-role" -def default_bucket_name(): +def default_bucket_name() -> str: return "tuplex-" + current_iam_user() -def default_scratch_dir(): +def default_scratch_dir() -> str: return default_bucket_name() + "/scratch" -def current_region(): +def current_region() -> str: session = boto3.session.Session() region = session.region_name @@ -62,7 +67,9 @@ def current_region(): return region -def check_credentials(aws_access_key_id=None, aws_secret_access_key=None): +def check_credentials( + aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None +) -> bool: kwargs = {} if isinstance(aws_access_key_id, str): kwargs["aws_access_key_id"] = aws_access_key_id @@ -81,7 +88,9 @@ def check_credentials(aws_access_key_id=None, aws_secret_access_key=None): return True -def ensure_s3_bucket(s3_client, bucket_name, region): +def ensure_s3_bucket( + s3_client: "botocore.client.S3", bucket_name: str, region: str +) -> None: bucket_names = list(map(lambda b: b["Name"], s3_client.list_buckets()["Buckets"])) if bucket_name not in bucket_names: @@ -105,7 +114,7 @@ def ensure_s3_bucket(s3_client, bucket_name, region): logging.info("Found bucket {}".format(bucket_name)) -def create_lambda_role(iam_client, lambda_role): +def create_lambda_role(iam_client: "botocore.client.IAM", lambda_role: str) -> None: # Roles required for AWS Lambdas trust_policy = '{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"Service":"lambda.amazonaws.com"},"Action":"sts:AssumeRole"}]}' lambda_access_to_s3 = 
'{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:*MultipartUpload*","s3:Get*","s3:ListBucket","s3:Put*"],"Resource":"*"}]}' @@ -139,7 +148,7 @@ def create_lambda_role(iam_client, lambda_role): raise Exception("Failed to create AWS Lambda Role.") -def remove_lambda_role(iam_client, lambda_role): +def remove_lambda_role(iam_client: "botocore.client.IAM", lambda_role: str) -> None: # detach policies... try: iam_client.detach_role_policy( @@ -165,11 +174,12 @@ def remove_lambda_role(iam_client, lambda_role): ) ) - # delete role... iam_client.delete_role(RoleName=lambda_role) -def setup_lambda_role(iam_client, lambda_role, region, overwrite): +def setup_lambda_role( + iam_client: "botocore.client.IAM", lambda_role: str, region: str, overwrite: bool +) -> None: try: response = iam_client.get_role(RoleName=lambda_role) logging.info("Found Lambda role from {}".format(response["Role"]["CreateDate"])) @@ -187,7 +197,7 @@ def setup_lambda_role(iam_client, lambda_role, region, overwrite): create_lambda_role(iam_client, lambda_role) -def sizeof_fmt(num, suffix="B"): +def sizeof_fmt(num: int, suffix: str = "B") -> str: # from https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: if abs(num) < 1024.0: @@ -197,13 +207,13 @@ def sizeof_fmt(num, suffix="B"): class ProgressPercentage(object): - def __init__(self, filename): + def __init__(self, filename: str) -> None: self._filename = filename self._size = float(os.path.getsize(filename)) self._seen_so_far = 0 self._lock = threading.Lock() - def __call__(self, bytes_amount): + def __call__(self, bytes_amount: int) -> None: # To simplify, assume this is hooked up to a single filename with self._lock: self._seen_so_far += bytes_amount @@ -220,7 +230,7 @@ def __call__(self, bytes_amount): sys.stdout.flush() -def s3_split_uri(uri): +def s3_split_uri(uri: str) -> Tuple[str, str]: assert "/" in uri, "at least one / is required!" uri = uri.replace("s3://", "") @@ -230,16 +240,16 @@ def s3_split_uri(uri): def upload_lambda( - iam_client, - lambda_client, - lambda_function_name, - lambda_role, - lambda_zip_file, - overwrite=False, - s3_client=None, - s3_scratch_space=None, - quiet=False, -): + iam_client: Optional[str], + lambda_client: Optional[str], + lambda_function_name: Optional[str], + lambda_role: Optional[str], + lambda_zip_file: Optional[str], + overwrite: bool = False, + s3_client: "botocore.client.S3" = None, + s3_scratch_space: Optional[str] = None, + quiet: bool = False, +) -> dict: # AWS only allows 50MB to be uploaded directly via request. Else, requires S3 upload. 
     ZIP_UPLOAD_LIMIT_SIZE = 50000000
@@ -396,7 +406,7 @@ def upload_lambda(
     return response
 
 
-def find_lambda_package():
+def find_lambda_package() -> Optional[str]:
     """
     Check whether a compatible zip file in tuplex/other could be found for auto-upload
     Returns: None or path to lambda zip to upload
@@ -415,17 +425,17 @@
 
 
 def setup_aws(
-    aws_access_key=None,
-    aws_secret_key=None,
-    overwrite=True,
-    iam_user=None,
-    lambda_name=None,
-    lambda_role=None,
-    lambda_file=None,
-    region=None,
-    s3_scratch_uri=None,
-    quiet=False,
-):
+    aws_access_key: Optional[str] = None,
+    aws_secret_key: Optional[str] = None,
+    overwrite: bool = True,
+    iam_user: Optional[str] = None,
+    lambda_name: Optional[str] = None,
+    lambda_role: Optional[str] = None,
+    lambda_file: Optional[str] = None,
+    region: Optional[str] = None,
+    s3_scratch_uri: Optional[str] = None,
+    quiet: bool = False,
+) -> None:
     start_time = time.time()
 
     # detect defaults. Important to do this here, because don't want to always invoke boto3/botocore
@@ -497,4 +507,6 @@ def setup_aws(
 
     # done, print if quiet was not set to False
     if not quiet:
-        print("\nCompleted lambda setup in {:.2f}s".format(time.time() - start_time))
+        _logger.info(
+            "Completed lambda setup in {:.2f}s".format(time.time() - start_time)
+        )
diff --git a/tuplex/python/tuplex/exceptions.py b/tuplex/python/tuplex/exceptions.py
index 0c50fa997..534c10465 100644
--- a/tuplex/python/tuplex/exceptions.py
+++ b/tuplex/python/tuplex/exceptions.py
@@ -8,9 +8,10 @@
 # Created by Leonhard Spiegelberg first on 1/1/2021                                                                    #
 # License: Apache 2.0                                                                                                  #
 # ----------------------------------------------------------------------------------------------------------------------#
+from typing import TypeVar
 
 
-def classToExceptionCode(cls):
+def classToExceptionCode(cls: TypeVar) -> int:
     """
     return C++ enum exception code for class
     Args:
diff --git a/tuplex/python/tuplex/metrics.py b/tuplex/python/tuplex/metrics.py
index 0bcf3fd34..776ee7e46 100644
--- a/tuplex/python/tuplex/metrics.py
+++ b/tuplex/python/tuplex/metrics.py
@@ -30,7 +30,7 @@ class Metrics:
     context object.
""" - def __init__(self, metrics: _Metrics): + def __init__(self, metrics: _Metrics) -> None: """ Creates a Metrics object by using the context object to set its metric parameter and store the resulting @@ -101,7 +101,7 @@ def as_json(self) -> str: assert self._metrics return self._metrics.getJSONString() - def as_dict(self): + def as_dict(self) -> dict: """ all measurements in nested dictionary Returns: diff --git a/tuplex/python/tuplex/repl/__init__.py b/tuplex/python/tuplex/repl/__init__.py index ef1fb8e83..243ca4e20 100644 --- a/tuplex/python/tuplex/repl/__init__.py +++ b/tuplex/python/tuplex/repl/__init__.py @@ -9,6 +9,7 @@ # License: Apache 2.0 # # ----------------------------------------------------------------------------------------------------------------------# +import logging import os import sys @@ -23,8 +24,10 @@ except (ImportError, NameError): __version__ = "dev" +_logger = logging.getLogger(__name__) -def TuplexBanner(): + +def TuplexBanner() -> str: banner = """Welcome to\n _____ _ |_ _| _ _ __ | | _____ __ @@ -55,4 +58,4 @@ def TuplexBanner(): shell.init(locals=_locals) shell.interact(banner=TuplexBanner() + "\n Interactive Shell mode") else: - print(TuplexBanner()) + _logger.info(TuplexBanner()) diff --git a/tuplex/python/tuplex/utils/common.py b/tuplex/python/tuplex/utils/common.py index e0708fc3d..32d6fb47f 100644 --- a/tuplex/python/tuplex/utils/common.py +++ b/tuplex/python/tuplex/utils/common.py @@ -25,6 +25,7 @@ import time import urllib.request from datetime import datetime +from typing import Callable import iso8601 import psutil @@ -42,8 +43,12 @@ except ImportError: __version__ = "dev" +from typing import Any, Optional, Union -def cmd_exists(cmd): +_logger = logging.getLogger(__name__) + + +def cmd_exists(cmd: str) -> bool: """ checks whether command `cmd` exists or not Args: @@ -55,7 +60,7 @@ def cmd_exists(cmd): return shutil.which(cmd) is not None -def is_shared_lib(path): +def is_shared_lib(path: str) -> bool: """ Args: path: str path to a file @@ -74,7 +79,7 @@ def is_shared_lib(path): ) -def current_timestamp(): +def current_timestamp() -> str: """ get current time as isoformatted string Returns: isoformatted current time (utc) @@ -83,7 +88,7 @@ def current_timestamp(): return str(datetime.now().isoformat()) -def current_user(): +def current_user() -> str: """ retrieve current user name Returns: username as string @@ -95,7 +100,7 @@ def current_user(): return getpass.getuser() -def host_name(): +def host_name() -> str: """ retrieve host name to identify machine Returns: some hostname as string @@ -107,7 +112,7 @@ def host_name(): return socket.gethostbyaddr(socket.gethostname())[0] -def post_json(url, data): +def post_json(url: str, data: dict) -> dict: """ perform a post request to a REST endpoint with JSON Args: @@ -126,12 +131,12 @@ def post_json(url, data): return json.loads(response.read()) -def get_json(url, timeout=10): +def get_json(url: str, timeout: float = 10) -> dict: """ perform a GET request to given URL Args: url: hostname & port - + timeout: timeout in s Returns: python dictionary of decoded json """ @@ -141,7 +146,7 @@ def get_json(url, timeout=10): return json.loads(response.read()) -def in_jupyter_notebook(): +def in_jupyter_notebook() -> bool: """check whether frameworks runs in jupyter notebook. 
 
     Returns:
         ``True`` if the module is running in IPython kernel,
@@ -163,7 +168,7 @@
     return False  # Probably standard Python interpreter
 
 
-def in_google_colab():
+def in_google_colab() -> bool:
     """ check whether framework runs in Google Colab environment
 
     Returns:
@@ -179,7 +184,7 @@ def in_google_colab():
     return shell_name_matching
 
 
-def is_in_interactive_mode():
+def is_in_interactive_mode() -> bool:
     """checks whether the module is loaded in an interactive shell session or not
 
     Returns:
         True when in interactive mode. Note that Jupyter notebook also returns True here.
@@ -189,7 +194,7 @@
     return bool(getattr(sys, "ps1", sys.flags.interactive))
 
 
-def flatten_dict(d, sep=".", parent_key=""):
+def flatten_dict(d: dict, sep: str = ".", parent_key: str = "") -> dict:
     """ flattens a nested dictionary into a flat dictionary by concatenating keys with the separator.
     Args:
         d (dict): The dictionary to flatten
@@ -210,7 +215,7 @@ def flatten_dict(d, sep=".", parent_key=""):
     return dict(items)
 
 
-def unflatten_dict(dictionary, sep="."):
+def unflatten_dict(dictionary: dict, sep: str = ".") -> dict:
     """ unflattens a dictionary into a nested dictionary according to sep
 
     Args:
@@ -236,7 +241,7 @@
     return resultDict
 
 
-def save_conf_yaml(conf, file_path):
+def save_conf_yaml(conf: dict, file_path: str) -> None:
     """saves a dictionary holding the configuration options to Tuplex Yaml format. \
     Dict can be either flattened or not.
 
     Args:
         conf:
         file_path:
     """
 
-    def beautify_nesting(d):
+    def beautify_nesting(d: Union[dict, Any]) -> Any:
         # i.e. make lists out of dicts
         if isinstance(d, dict):
             items = d.items()
@@ -265,7 +270,7 @@ def beautify_nesting(d):
     f.write(out)
 
 
-def pythonize_options(options):
+def pythonize_options(options: dict) -> dict:
     """ convert string based options into python objects/types
 
     Args:
@@ -275,7 +280,7 @@
         dict with python types
     """
 
-    def parse_string(item):
+    def parse_string(item: str) -> Any:
         """ check what kind of variable string represents and convert accordingly
         Args:
@@ -313,7 +318,7 @@ def parse_string(item):
     return {k: parse_string(v) for k, v in options.items()}
 
 
-def load_conf_yaml(file_path):
+def load_conf_yaml(file_path: str) -> dict:
     """loads yaml file and converts contents to nested dictionary
 
     Args:
@@ -322,7 +327,7 @@
     """
 
     # helper function to get correct nesting from yaml file!
-    def to_nested_dict(obj):
+    def to_nested_dict(obj: Any) -> dict:
         resultDict = dict()
         if isinstance(obj, list):
             for item in obj:
@@ -349,7 +354,7 @@ def to_nested_dict(obj):
     return to_nested_dict(d)
 
 
-def stringify_dict(d):
+def stringify_dict(d: dict) -> dict:
     """convert keys and vals into strings
     Args:
         d (dict): dictionary
@@ -361,7 +366,7 @@
     return {str(key): str(val) for key, val in d.items()}
 
 
-def registerLoggingCallback(callback):
+def registerLoggingCallback(callback: Callable) -> None:
     """ register a custom logging callback function with tuplex
 
     Args:
@@ -373,7 +378,7 @@
     from ..libexec.tuplex import registerLoggingCallback as ccRegister
 
     # create a wrapper to capture exceptions properly and avoid crashing
-    def wrapper(level, time_info, logger_name, msg):
+    def wrapper(level: int, time_info: str, logger_name: str, msg: str) -> None:
         args = (level, time_info, logger_name, msg)
 
         try:
@@ -384,7 +389,7 @@ def wrapper(level, time_info, logger_name, msg):
 
     ccRegister(wrapper)
 
 
-def logging_callback(level, time_info, logger_name, msg):
+def logging_callback(level: int, time_info: str, logger_name: str, msg: str) -> None:
     """ this is a callback function which can be used to redirect C++ logging to python logging.
     :param level: logging level as integer, for values cf. PythonCommon.h
@@ -441,7 +446,7 @@
 
 
 # register at exit function to take care of exit handlers
-def auto_shutdown_all():
+def auto_shutdown_all() -> None:
     """ helper function to automatially shutdown whatever is in the global exit handler array. Resets global variable.
 
     Returns:
@@ -456,13 +461,15 @@
             if msg:
                 logging.info(msg)
             func(args)
-            logging.info("Shutdown {} successfully".format(name))
+            logging.debug("Shutdown {} successfully".format(name))
         except Exception:
             logging.error("Failed to shutdown {}".format(name))
     __exit_handlers__ = []
 
 
-def register_auto_shutdown(name, func, args, msg=None):
+def register_auto_shutdown(
+    name: str, func: Callable, args: tuple, msg: Optional[str] = None
+) -> None:
     global __exit_handlers__
     __exit_handlers__.append((name, func, args, msg))
 
@@ -470,7 +477,7 @@
     atexit.register(auto_shutdown_all)
 
 
-def is_process_running(name):
+def is_process_running(name: str) -> bool:
     """ helper function to check if a process is running on the local machine
 
     Args:
@@ -490,7 +497,9 @@
     return False
 
 
-def mongodb_uri(mongodb_url, mongodb_port, db_name="tuplex-history"):
+def mongodb_uri(
+    mongodb_url: str, mongodb_port: int, db_name: str = "tuplex-history"
+) -> str:
     """ constructs a fully qualified MongoDB URI
 
     Args:
@@ -505,8 +514,11 @@
 
 def check_mongodb_connection(
-    mongodb_url, mongodb_port, db_name="tuplex-history", timeout=10.0
-):
+    mongodb_url: str,
+    mongodb_port: int,
+    db_name: str = "tuplex-history",
+    timeout: float = 10.0,
+) -> None:
     """ connects to a MongoDB database instance, raises exception if connection fails
 
     Args:
@@ -564,7 +576,7 @@
     logging.debug("Connection test to MongoDB succeeded")
 
 
-def shutdown_process_via_kill(pid):
+def shutdown_process_via_kill(pid: int) -> None:
     """ issues a KILL signals to a process with pid
 
     Args:
@@ -578,12 +590,12 @@
 
 def find_or_start_mongodb(
-    mongodb_url,
-    mongodb_port,
-    mongodb_datapath,
-    mongodb_logpath,
mongodb_logpath, - db_name="tuplex-history", -): + mongodb_url: str, + mongodb_port: int, + mongodb_datapath: str, + mongodb_logpath: str, + db_name: str = "tuplex-history", +) -> None: """ attempts to connect to a MongoDB database. If no running local MongoDB is found, will auto-start a mongodb database. R aises exception when fails. @@ -717,7 +729,7 @@ def find_or_start_mongodb( check_mongodb_connection(mongodb_url, mongodb_port, db_name) -def log_gunicorn_errors(logpath): +def log_gunicorn_errors(logpath: str) -> None: """ uses logging module to print out gunicorn errors if something went wrong Args: @@ -739,7 +751,9 @@ def log_gunicorn_errors(logpath): logging.error("Gunicorn error log:\n {}".format("".join(lines[first_idx:]))) -def find_or_start_webui(mongo_uri, hostname, port, web_logfile): +def find_or_start_webui( + mongo_uri: str, hostname: str, port: int, web_logfile: str +) -> None: """ tries to connect to Tuplex WebUI. If local uri is specified, autostarts WebUI. Args: @@ -950,7 +964,7 @@ def find_or_start_webui(mongo_uri, hostname, port, web_logfile): "Adding auto-shutdown of process with PID={} (WebUI)".format(ui_pid) ) - def shutdown_gunicorn(pid): + def shutdown_gunicorn(pid: int) -> None: pids_to_kill = [] # iterate over all gunicorn processes and kill them all @@ -991,7 +1005,7 @@ def shutdown_gunicorn(pid): return version_info -def ensure_webui(options): +def ensure_webui(options: dict) -> None: """ Helper function to ensure WebUI/MongoDB is auto-started when webui is specified Args: @@ -1054,7 +1068,7 @@ def ensure_webui(options): webui_uri = webui_url + ":" + str(webui_port) if not webui_uri.startswith("http"): webui_uri = "http://" + webui_uri - print("Tuplex WebUI can be accessed under {}".format(webui_uri)) + _logger.info("Tuplex WebUI can be accessed under {}".format(webui_uri)) except Exception as e: logging.error( "Failed to start or connect to Tuplex WebUI. Details: {}".format(e) diff --git a/tuplex/python/tuplex/utils/globs.py b/tuplex/python/tuplex/utils/globs.py index f938e5035..83c31cb6d 100644 --- a/tuplex/python/tuplex/utils/globs.py +++ b/tuplex/python/tuplex/utils/globs.py @@ -15,6 +15,8 @@ import sys import types import weakref +from types import CodeType +from typing import Any, Callable, List, Tuple # ALWAYS import cloudpickle before dill, b.c. of https://github.com/uqfoundation/dill/issues/383 from cloudpickle.cloudpickle import _get_cell_contents @@ -31,7 +33,7 @@ EXTENDED_ARG = dis.EXTENDED_ARG -def _extract_code_globals(co): +def _extract_code_globals(co: CodeType) -> dict: """ Find all globals names read or written to by codeblock co """ @@ -55,7 +57,7 @@ def _extract_code_globals(co): return out_names -def _find_imported_submodules(code, top_level_dependencies): +def _find_imported_submodules(code: CodeType, top_level_dependencies: List[Any]) -> Any: """ Find currently imported submodules used by a function. Submodules used by a function need to be detected and referenced for the @@ -103,7 +105,7 @@ def func(): return subimports -def _walk_global_ops(code): +def _walk_global_ops(code: Any) -> Any: """ Yield (opcode, argument number) tuples for all global-referencing instructions in *code*. @@ -114,7 +116,7 @@ def _walk_global_ops(code): yield instr.arg, instr.argval -def _function_getstate(func): +def _function_getstate(func: Callable) -> Tuple[dict, dict]: # - Put func's dynamic attributes (stored in func.__dict__) in state. 
     # attributes will be restored at unpickling time using
     # f.__dict__.update(state)
@@ -163,7 +165,7 @@
 # end from cloudpickle
 
 
-def get_globals(func):
+def get_globals(func: Callable) -> dict:
     _, d = _function_getstate(func)
 
     func_globals = d["__globals__"]
diff --git a/tuplex/python/tuplex/utils/interactive_shell.py b/tuplex/python/tuplex/utils/interactive_shell.py
index 56a929b02..e45e9a5e9 100644
--- a/tuplex/python/tuplex/utils/interactive_shell.py
+++ b/tuplex/python/tuplex/utils/interactive_shell.py
@@ -17,6 +17,7 @@
 import sys
 from code import InteractiveConsole
 from types import FunctionType, LambdaType
+from typing import Callable, Optional
 
 from prompt_toolkit.history import InMemoryHistory
 
@@ -42,11 +43,11 @@
 # the idea is basically, we can't simply call 'import tuplex' because this would
 # lead to a circular import. Yet, for user convenience, simply exposing tuplex.Context should be sufficient!
 class TuplexModuleHelper:
-    def __init__(self, context_cls):
+    def __init__(self, context_cls: "Context") -> None:
         self._context_cls = context_cls
 
     @property
-    def Context(self):
+    def Context(self) -> "Context":
         return self._context_cls
 
 
@@ -56,15 +57,15 @@ class TuplexShell(InteractiveConsole):
     # use BORG design pattern to make class singleton alike
     __shared_state = {}
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.__dict__ = self.__shared_state
 
     def init(
         self,
-        locals=None,
-        filename="<console>",
-        histfile=os.path.expanduser("~/.console_history"),
-    ):
+        locals: Optional[dict] = None,
+        filename: str = "<console>",
+        histfile: str = os.path.expanduser("~/.console_history"),
+    ) -> None:
         # add dummy helper for context
         if locals is not None and "Context" in locals.keys():
             locals["tuplex"] = TuplexModuleHelper(locals["Context"])
@@ -76,7 +77,7 @@ def init(
         self._lastLine = ""
         self.historyDict = {}
 
-    def push(self, line):
+    def push(self, line: str) -> bool:
         """Push a line to the interpreter.
 
         The line should not have a trailing newline; it may have
         internal newlines.  The line is appended to a buffer and the
@@ -111,7 +112,7 @@ def push(self, line):
 
         return more
 
-    def get_lambda_source(self, f):
+    def get_lambda_source(self, f: Callable) -> str:
         # Won't this work for functions as well?
 
         assert self.initialized, "must call init on TuplexShell object first"
@@ -135,7 +136,7 @@ def get_lambda_source(self, f):
         vault.extractAndPutAllLambdas(src_info, f_filename, f_lineno, f_colno, f_globs)
         return vault.get(f, f_filename, f_lineno, f_colno, f_globs)
 
-    def get_function_source(self, f):
+    def get_function_source(self, f: Callable) -> str:
         assert self.initialized, "must call init on TuplexShell object first"
 
         assert isinstance(f, FunctionType) and f.__code__.co_name != "<lambda>", (
@@ -166,13 +167,15 @@
             logging.error(
                 'Could not find function "{}" in source'.format(function_name)
             )
-            return None
+            return ""
 
         return source
 
     # taken from Lib/code.py
     # overwritten to customize behaviour
-    def interact(self, banner=None, exitmsg=None):
+    def interact(
+        self, banner: Optional[str] = None, exitmsg: Optional[str] = None
+    ) -> None:
         """Closely emulate the interactive Python console.
 
         The optional banner argument specifies the banner to print
         before the first interaction; by default it prints a banner
diff --git a/tuplex/python/tuplex/utils/jedi_completer.py b/tuplex/python/tuplex/utils/jedi_completer.py
index deecf8517..613b5d7c9 100644
--- a/tuplex/python/tuplex/utils/jedi_completer.py
+++ b/tuplex/python/tuplex/utils/jedi_completer.py
@@ -9,20 +9,25 @@
 # License: Apache 2.0                                                                                                  #
 # ----------------------------------------------------------------------------------------------------------------------#
 
+from typing import Any, List
+
 from jedi import Interpreter, settings
-from prompt_toolkit.completion import Completer, Completion
+from prompt_toolkit.completion import CompleteEvent, Completer, Completion
+from prompt_toolkit.document import Document
 
 
 class JediCompleter(Completer):
     """REPL Completer using jedi"""
 
-    def __init__(self, get_locals):
+    def __init__(self, get_locals: Any) -> None:
         # per default jedi is case insensitive, however we want it to be case sensitive
         settings.case_insensitive_completion = False
         self.get_locals = get_locals
 
-    def get_completions(self, document, complete_event):
+    def get_completions(
+        self, document: Document, complete_event: CompleteEvent
+    ) -> List[Completion]:
         _locals = self.get_locals()
         interpreter = Interpreter(document.text, [_locals])
diff --git a/tuplex/python/tuplex/utils/jupyter.py b/tuplex/python/tuplex/utils/jupyter.py
index 40fa34f70..59272c0a3 100644
--- a/tuplex/python/tuplex/utils/jupyter.py
+++ b/tuplex/python/tuplex/utils/jupyter.py
@@ -19,14 +19,14 @@
 from notebook.notebookapp import list_running_servers
 
 
-def get_jupyter_notebook_info():
+def get_jupyter_notebook_info() -> dict:
     """ retrieve infos about the currently running jupyter notebook if possible
 
     Returns:
         dict with several info attributes. If info for current notebook could not be retrieved, returns empty dict
     """
 
-    def get(url):
+    def get(url: str) -> dict:
         req = urllib.request.Request(url, headers={"content-type": "application/json"})
         response = urllib.request.urlopen(req)
         return json.loads(response.read())
diff --git a/tuplex/python/tuplex/utils/reflection.py b/tuplex/python/tuplex/utils/reflection.py
index 258da0b27..6397ab8f2 100644
--- a/tuplex/python/tuplex/utils/reflection.py
+++ b/tuplex/python/tuplex/utils/reflection.py
@@ -11,8 +11,10 @@
 
 import ast
 import inspect
+import logging
 import re
 import types
+from typing import Callable, List, Tuple, Union
 
 # ALWAYS import cloudpickle before dill, b.c. of https://github.com/uqfoundation/dill/issues/383
 import dill
@@ -29,8 +31,10 @@
 # only export get_source function, rest shall be private.
 __all__ = ["get_source", "get_globals", "supports_lambda_closure"]
 
+_logger = logging.getLogger(__name__)
 
-def get_jupyter_raw_code(function_name):
+
+def get_jupyter_raw_code(function_name: str) -> str:
     # Ignore here unresolved reference, get_ipython() works in jupyter notebook.
     history_manager = get_ipython().history_manager  # noqa: F821
     hist = history_manager.get_range()
@@ -55,18 +59,20 @@
     return matched_cells[-1][2]
 
 
-def extractFunctionByName(code, func_name, return_linenos=False):
+def extractFunctionByName(
+    code: str, func_name: str, return_linenos: bool = False
+) -> Union[str, Tuple[str, int, int]]:
     class FunctionVisitor(ast.NodeVisitor):
-        def __init__(self):
-            self.lastStmtLineno = 0
-            self.funcInfo = []
+        def __init__(self) -> None:
+            self.lastStmtLineno: int = 0
+            self.funcInfo: List[dict] = []
 
-        def visit_FunctionDef(self, node):
-            print(self.lastStmtLineno)
+        def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
+            _logger.debug(self.lastStmtLineno)
             self.generic_visit(node)
-            print(self.lastStmtLineno)
+            _logger.debug(self.lastStmtLineno)
 
-        def visit(self, node):
+        def visit(self, node: ast.AST) -> None:
             funcStartLineno = -1
             if hasattr(node, "lineno"):
                 self.lastStmtLineno = node.lineno
@@ -89,7 +95,7 @@ def visit(self, node):
     # find function with name
     candidates = filter(lambda x: x["name"] == func_name, fv.funcInfo)
 
-    def indent(s):
+    def indent(s: str) -> int:
         return len(s) - len(s.lstrip(" \t"))
 
     lines = code.split("\n")
@@ -106,9 +112,9 @@ def indent(s):
 
     return func_code
 
 
-def extract_function_code(function_name, raw_code):
+def extract_function_code(function_name: str, raw_code: str) -> str:
     # remove greedily up to num_tabs and num_spaces
-    def remove_tabs_and_spaces(line, num_tabs, num_spaces):
+    def remove_tabs_and_spaces(line: str, num_tabs: int, num_spaces: int) -> str:
         t = 0
         s = 0
         pos = 0
@@ -147,7 +153,7 @@ def remove_tabs_and_spaces(line, num_tabs, num_spaces):
 
     return extractFunctionByName(out, function_name)
 
 
-def get_function_code(f):
+def get_function_code(f: Callable) -> str:
     """jupyter notebook, retrieve function history"""
     assert isinstance(f, types.FunctionType)
     function_name = f.__code__.co_name
@@ -175,7 +181,7 @@
 
 vault = SourceVault()
 
 
-def get_source(f):
+def get_source(f: Callable) -> str:
     """Jupyter notebook code reflection"""
 
     if isinstance(f, types.FunctionType):
diff --git a/tuplex/python/tuplex/utils/source_vault.py b/tuplex/python/tuplex/utils/source_vault.py
index 6e25e8d60..f95e04c5b 100644
--- a/tuplex/python/tuplex/utils/source_vault.py
+++ b/tuplex/python/tuplex/utils/source_vault.py
@@ -14,11 +14,12 @@
 import os
 import sys
 from types import CodeType, LambdaType
+from typing import Callable, List, Optional, Tuple
 
 import astor
 
 
-def supports_lambda_closure():
+def supports_lambda_closure() -> bool:
     """
     source code of lambdas can't be extracted, because there's no column information available in code objects.
     This can be achieved by patching 4 lines in the cpython source code.
@@ -31,11 +32,11 @@
     return hasattr(f.__code__, "co_firstcolno")
 
 
-def extract_all_lambdas(tree):
+def extract_all_lambdas(tree: ast.AST) -> List[ast.Lambda]:
     lambdas = []
 
     class Visitor(ast.NodeVisitor):
-        def visit_Lambda(self, node):
+        def visit_Lambda(self, node: ast.Lambda) -> None:
             lambdas.append(node)
 
     Visitor().visit(tree)
@@ -45,11 +46,11 @@ def visit_Lambda(self, node):
 
 
 # extract for lambda incl. default values
 # annotations are not possible with the current syntax...
-def args_for_lambda_ast(lam):
-    return [getattr(n, "arg") for n in lam.args.args]
+def args_for_lambda_ast(lam: ast.Lambda) -> List[str]:
+    return [n.arg for n in lam.args.args]
 
 
-def gen_code_for_lambda(lam):
+def gen_code_for_lambda(lam: ast.Lambda) -> str:
     # surround in try except if user provided malformed lambdas
     try:
         s = astor.to_source(lam)
@@ -80,7 +81,7 @@
 
     return ""
 
 
-def hash_code_object(code):
+def hash_code_object(code: CodeType) -> bytes:
     # can't take the full object because this includes memory addresses
     # need to hash contents
     # for this use bytecode, varnames & constants
@@ -97,8 +98,7 @@
 
     return ret + b")"
 
 
-# join lines and remove stupid \\n
-def remove_line_breaks(source_lines):
+def remove_line_breaks(source_lines: List[str]) -> str:
     """ expressions may be defined over multiple line using \ in python. This function removes this and joins lines.
 
     Args:
@@ -130,22 +130,21 @@
 class SourceVault:
     # borg pattern
     __shared_state = {}
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.__dict__ = self.__shared_state
 
         self.lambdaDict = {}
 
         # new: lookup via filename, lineno and colno
         self.lambdaFileDict = {}
 
-    # def get(self, obj):
-    #     """
-    #     returns source code for given object
-    #     :param codeboj:
-    #     :return:
-    #     """
-    #     assert isinstance(obj, LambdaType), 'object needs to be a lambda object'
-    #     return self.lambdaDict[hash_code_object(obj.__code__)]
-    def get(self, ftor, filename, lineno, colno, globs):
+    def get(
+        self,
+        ftor: Callable,
+        filename: str,
+        lineno: int,
+        colno: Optional[int],
+        globs: dict,
+    ) -> str:
         assert isinstance(ftor, LambdaType), "object needs to be a lambda object"
 
         # perform multiway lookup for code
@@ -184,7 +183,14 @@
         else:
             raise KeyError("could not find lambda function")
 
-    def extractAndPutAllLambdas(self, src_info, filename, lineno, colno, globals):
+    def extractAndPutAllLambdas(
+        self,
+        src_info: Tuple[List[str], int],
+        filename: str,
+        lineno: int,
+        colno: Optional[int],
+        globals: dict,
+    ) -> None:
         """
         extracts the source code from all lambda functions and stores them in the source vault
         :param source:
diff --git a/tuplex/python/tuplex/utils/tracebacks.py b/tuplex/python/tuplex/utils/tracebacks.py
index eb5ba3aed..6b7e82789 100644
--- a/tuplex/python/tuplex/utils/tracebacks.py
+++ b/tuplex/python/tuplex/utils/tracebacks.py
@@ -12,21 +12,23 @@
 import linecache
 import re
 import traceback
+from types import TracebackType
+from typing import Any, Callable
 
 from .reflection import get_source
 
 __all__ = ["traceback_from_udf"]
 
 
-def format_traceback(tb, function_name):
+def format_traceback(tb: TracebackType, function_name: str) -> str:
     """ helper function to format a traceback object with line numbers relative to function definition
 
     Args:
-        tb:
-        function_name:
+        tb: traceback object
+        function_name: name of function to add to traceback
 
     Returns:
-
+        formatted traceback string
     """
 
     fnames = set()
@@ -67,7 +69,7 @@
 
 # get traceback from sample
-def traceback_from_udf(udf, x):
+def traceback_from_udf(udf: Callable, x: Any) -> str:
     """ get a formatted traceback as string by executing a udf over a sample
 
     Args: