From 0d18bcb197809bac574fd29fc93e16a00c3ea3c4 Mon Sep 17 00:00:00 2001 From: Leonhard Spiegelberg Date: Thu, 6 Mar 2025 23:25:44 -0800 Subject: [PATCH 1/5] adding ruff to code --- .pre-commit-config.yaml | 23 + pyproject.toml | 4 + tuplex/python/tuplex/__init__.py | 21 +- tuplex/python/tuplex/context.py | 223 +++++--- tuplex/python/tuplex/dataset.py | 343 ++++++++----- tuplex/python/tuplex/distributed.py | 366 ++++++++----- tuplex/python/tuplex/exceptions.py | 118 ++--- tuplex/python/tuplex/libexec/__init__.py | 4 +- tuplex/python/tuplex/libexec/_tuplex.py | 4 +- tuplex/python/tuplex/metrics.py | 6 +- tuplex/python/tuplex/repl/__init__.py | 30 +- tuplex/python/tuplex/utils/__init__.py | 4 +- tuplex/python/tuplex/utils/common.py | 481 +++++++++++------- tuplex/python/tuplex/utils/errors.py | 8 +- tuplex/python/tuplex/utils/framework.py | 9 +- tuplex/python/tuplex/utils/globs.py | 61 ++- .../python/tuplex/utils/interactive_shell.py | 95 ++-- tuplex/python/tuplex/utils/jedi_completer.py | 27 +- tuplex/python/tuplex/utils/jupyter.py | 41 +- tuplex/python/tuplex/utils/reflection.py | 117 +++-- tuplex/python/tuplex/utils/source_vault.py | 165 +++--- tuplex/python/tuplex/utils/tracebacks.py | 37 +- tuplex/python/tuplex/utils/version.py | 2 +- 23 files changed, 1348 insertions(+), 841 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..7231bef27 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +repos: +#- repo: https://github.com/pre-commit/pre-commit-hooks +# rev: v2.3.0 +# hooks: +# - id: check-yaml +# exclude: ["tuplex/test/resources"] +# - id: end-of-file-fixer +# exclude: ["tuplex/test/resources"] +# - id: trailing-whitespace +# exclude: ["tuplex/test/resources"] +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.9.9 + hooks: + # Run the linter. + - id: ruff + files: ^tuplex/python/tuplex.*\.py$ + args: [ --fix ] + types_or: [ python, pyi ] + # Run the formatter. 
+ - id: ruff-format + files: ^tuplex/python/tuplex.*\.py$ + types_or: [ python, pyi ] \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index aefc4e5dc..dc7fe4af5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,3 +12,7 @@ requires = [ "requests" ] build-backend = "setuptools.build_meta" + + +[tool.ruff] +include = ["pyproject.toml", "tuplex/python/tuplex/**/*.py"] diff --git a/tuplex/python/tuplex/__init__.py b/tuplex/python/tuplex/__init__.py index ee06cd764..a8f0aa3d1 100644 --- a/tuplex/python/tuplex/__init__.py +++ b/tuplex/python/tuplex/__init__.py @@ -20,18 +20,23 @@ from tuplex.utils.version import __version__ as __version__ + # for convenience create a dummy function to return a default-configured Lambda context def LambdaContext(conf=None, name=None, s3_scratch_dir=None, **kwargs): import uuid if s3_scratch_dir is None: s3_scratch_dir = tuplex.distributed.default_scratch_dir() - logging.debug('Detected default S3 scratch dir for this user as {}'.format(s3_scratch_dir)) + logging.debug( + "Detected default S3 scratch dir for this user as {}".format(s3_scratch_dir) + ) - lambda_conf = {'backend': 'lambda', - 'partitionSize': '1MB', - 'aws.scratchDir': s3_scratch_dir, - 'aws.requesterPay': True} + lambda_conf = { + "backend": "lambda", + "partitionSize": "1MB", + "aws.scratchDir": s3_scratch_dir, + "aws.requesterPay": True, + } if conf: lambda_conf.update(conf) @@ -40,13 +45,13 @@ def LambdaContext(conf=None, name=None, s3_scratch_dir=None, **kwargs): for k, v in kwargs.items(): if k in conf.keys(): lambda_conf[k] = v - elif 'tuplex.' + k in conf.keys(): - lambda_conf['tuplex.' + k] = v + elif "tuplex." + k in conf.keys(): + lambda_conf["tuplex." + k] = v else: lambda_conf[k] = v if name is None: - name = 'AWSLambdaContext-' + str(uuid.uuid4())[:8] + name = "AWSLambdaContext-" + str(uuid.uuid4())[:8] # There's currently a bug in the Lambda backend when transferring local data to S3: The full partition # gets transferred, not just what is needed. 
diff --git a/tuplex/python/tuplex/context.py b/tuplex/python/tuplex/context.py index f92a5ddee..0750dc643 100644 --- a/tuplex/python/tuplex/context.py +++ b/tuplex/python/tuplex/context.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,7 +7,7 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# import logging @@ -20,14 +20,29 @@ import os import glob import sys -import cloudpickle -from tuplex.utils.common import flatten_dict, load_conf_yaml, stringify_dict, unflatten_dict, save_conf_yaml, in_jupyter_notebook, in_google_colab, is_in_interactive_mode, current_user, is_shared_lib, host_name, ensure_webui, pythonize_options, logging_callback, registerLoggingCallback +from tuplex.utils.common import ( + flatten_dict, + load_conf_yaml, + stringify_dict, + unflatten_dict, + save_conf_yaml, + in_jupyter_notebook, + in_google_colab, + is_in_interactive_mode, + current_user, + is_shared_lib, + host_name, + ensure_webui, + pythonize_options, + logging_callback, + registerLoggingCallback, +) import uuid import json from .metrics import Metrics -class Context: +class Context: def __init__(self, conf=None, name="", **kwargs): r"""creates new Context object, the main entry point for all operations with the Tuplex big data framework @@ -81,8 +96,10 @@ def __init__(self, conf=None, name="", **kwargs): only serializes data that is required within the pipeline. """ - runtime_path = os.path.join(os.path.dirname(__file__), 'libexec', 'tuplex_runtime') - paths = glob.glob(runtime_path + '*') + runtime_path = os.path.join( + os.path.dirname(__file__), "libexec", "tuplex_runtime" + ) + paths = glob.glob(runtime_path + "*") if len(paths) != 1: # filter based on type (runtime must be shared object!) @@ -90,9 +107,15 @@ def __init__(self, conf=None, name="", **kwargs): if len(paths) != 1: if len(paths) == 0: - logging.error("found no tuplex runtime (tuplex_runtime.so). Faulty installation?") + logging.error( + "found no tuplex runtime (tuplex_runtime.so). Faulty installation?" + ) else: - logging.error('found following candidates for tuplex runtime:\n{}, please specify which to use.'.format(paths)) + logging.error( + "found following candidates for tuplex runtime:\n{}, please specify which to use.".format( + paths + ) + ) sys.exit(1) # pass configuration options @@ -102,17 +125,17 @@ def __init__(self, conf=None, name="", **kwargs): # put meaningful defaults for special environments... # per default disable webui - options['tuplex.webui.enable'] = False + options["tuplex.webui.enable"] = False if in_google_colab(): - logging.debug('Detected Google Colab environment, adjusting options...') + logging.debug("Detected Google Colab environment, adjusting options...") # do not use a lot of memory, restrict... 
- options['tuplex.driverMemory'] = '64MB' - options['tuplex.executorMemory'] = '64MB' - options['tuplex.inputSplitSize'] = '16MB' - options['tuplex.partitionSize'] = '4MB' - options['tuplex.runTimeMemory'] = '16MB' - options['tuplex.webui.enable'] = False + options["tuplex.driverMemory"] = "64MB" + options["tuplex.executorMemory"] = "64MB" + options["tuplex.inputSplitSize"] = "16MB" + options["tuplex.partitionSize"] = "4MB" + options["tuplex.runTimeMemory"] = "16MB" + options["tuplex.webui.enable"] = False if conf: if isinstance(conf, str): @@ -129,55 +152,59 @@ def __init__(self, conf=None, name="", **kwargs): options = stringify_dict(options) user = current_user() - name = name if len(name) > 0 else 'context' + str(uuid.uuid4())[:8] - mode = 'file' + name = name if len(name) > 0 else "context" + str(uuid.uuid4())[:8] + mode = "file" if is_in_interactive_mode(): - mode = 'shell' + mode = "shell" if in_jupyter_notebook(): - mode = 'jupyter' + mode = "jupyter" if in_google_colab(): - mode = 'colab' + mode = "colab" host = host_name() # pass above options as env.user, ... # also pass runtime path like that - options['tuplex.env.user'] = str(user) - options['tuplex.env.hostname'] = str(host) - options['tuplex.env.mode'] = str(mode) + options["tuplex.env.user"] = str(user) + options["tuplex.env.hostname"] = str(host) + options["tuplex.env.mode"] = str(mode) # update runtime path according to user - if 'tuplex.runTimeLibrary' in options: - runtime_path = options['tuplex.runTimeLibrary'] + if "tuplex.runTimeLibrary" in options: + runtime_path = options["tuplex.runTimeLibrary"] # normalize keys to be of format tuplex. supported_keys = json.loads(getDefaultOptionsAsJSON()).keys() key_set = set(options.keys()) for k in key_set: - if k not in supported_keys and 'tuplex.' + k in supported_keys: - options['tuplex.' + k] = options[k] + if k not in supported_keys and "tuplex." + k in supported_keys: + options["tuplex." + k] = options[k] # check if redirect to python logging module should happen or not - if 'tuplex.redirectToPythonLogging' in options.keys(): + if "tuplex.redirectToPythonLogging" in options.keys(): py_opts = pythonize_options(options) - if py_opts['tuplex.redirectToPythonLogging']: - logging.info('Redirecting C++ logging to Python') + if py_opts["tuplex.redirectToPythonLogging"]: + logging.info("Redirecting C++ logging to Python") registerLoggingCallback(logging_callback) else: # check what default options say defaults = pythonize_options(json.loads(getDefaultOptionsAsJSON())) - if defaults['tuplex.redirectToPythonLogging']: - logging.info('Redirecting C++ logging to Python') + if defaults["tuplex.redirectToPythonLogging"]: + logging.info("Redirecting C++ logging to Python") registerLoggingCallback(logging_callback) # autostart mongodb & history server if they are not running yet... # deactivate webui for google colab per default - if 'tuplex.webui.enable' not in options: + if "tuplex.webui.enable" not in options: # for google colab env, disable webui per default. if in_google_colab(): - options['tuplex.webui.enable'] = False + options["tuplex.webui.enable"] = False # fetch default options for webui ... - webui_options = {k: v for k, v in json.loads(getDefaultOptionsAsJSON()).items() if 'webui' in k or 'scratch' in k} + webui_options = { + k: v + for k, v in json.loads(getDefaultOptionsAsJSON()).items() + if "webui" in k or "scratch" in k + } # update only non-existing options! 
for k, v in webui_options.items(): @@ -187,27 +214,27 @@ def __init__(self, conf=None, name="", **kwargs): # pythonize options = pythonize_options(options) - if options['tuplex.webui.enable']: + if options["tuplex.webui.enable"]: ensure_webui(options) # last arg are the options as json string serialized b.c. of boost python problems # because webui=False/True is convenient, pass it as well to tuplex options - if 'tuplex.webui' in options.keys(): - options['tuplex.webui.enable'] = options['tuplex.webui'] - del options['tuplex.webui'] - if 'webui' in options.keys(): - options['tuplex.webui.enable'] = options['webui'] - del options['webui'] + if "tuplex.webui" in options.keys(): + options["tuplex.webui.enable"] = options["tuplex.webui"] + del options["tuplex.webui"] + if "webui" in options.keys(): + options["tuplex.webui.enable"] = options["webui"] + del options["webui"] # last arg are the options as json string serialized b.c. of boost python problems self._context = _Context(name, runtime_path, json.dumps(options)) python_metrics = self._context.getMetrics() - assert python_metrics, 'internal error: metrics object should be valid' + assert python_metrics, "internal error: metrics object should be valid" self.metrics = Metrics(python_metrics) assert self.metrics def parallelize(self, value_list, columns=None, schema=None, auto_unpack=True): - """ passes data to the Tuplex framework. Must be a list of primitive objects (e.g. of type bool, int, float, str) or + """passes data to the Tuplex framework. Must be a list of primitive objects (e.g. of type bool, int, float, str) or a list of (nested) tuples of these types. Args: @@ -229,20 +256,30 @@ def parallelize(self, value_list, columns=None, schema=None, auto_unpack=True): num_cols = 1 if isinstance(value_list[0], (list, tuple)): num_cols = len(value_list[0]) - cols = ['column{}'.format(i) for i in range(num_cols)] + cols = ["column{}".format(i) for i in range(num_cols)] else: cols = columns for col in cols: - assert isinstance(col, str), 'element {} must be a string'.format(col) - + assert isinstance(col, str), "element {} must be a string".format(col) ds = DataSet() - ds._dataSet = self._context.parallelize(value_list, columns, schema, auto_unpack) + ds._dataSet = self._context.parallelize( + value_list, columns, schema, auto_unpack + ) return ds - def csv(self, pattern, columns=None, header=None, delimiter=None, quotechar='"', null_values=[''], type_hints={}): - """ reads csv (comma separated values) files. This function may either be provided with + def csv( + self, + pattern, + columns=None, + header=None, + delimiter=None, + quotechar='"', + null_values=[""], + type_hints={}, + ): + """reads csv (comma separated values) files. This function may either be provided with parameters that help to determine the delimiter, whether a header present or what kind of quote char is used. Overall, CSV parsing is done according to the RFC-4180 standard (cf. 
https://tools.ietf.org/html/rfc4180) @@ -274,27 +311,41 @@ def csv(self, pattern, columns=None, header=None, delimiter=None, quotechar='"', if not null_values: null_values = [] - assert isinstance(pattern, str), 'file pattern must be given as str' - assert isinstance(columns, list) or columns is None, 'columns must be a list or None' - assert isinstance(delimiter, str) or delimiter is None, 'delimiter must be given as , or None for auto detection' - assert isinstance(header, bool) or header is None, 'header must be given as bool or None for auto detection' - assert isinstance(quotechar, str), 'quote char must be given as str' - assert isinstance(null_values, list), 'null_values must be a list of strings representing null values' - assert isinstance(type_hints, dict), 'type_hints must be a dictionary mapping index to type hint' # TODO: update with other options + assert isinstance(pattern, str), "file pattern must be given as str" + assert isinstance(columns, list) or columns is None, ( + "columns must be a list or None" + ) + assert isinstance(delimiter, str) or delimiter is None, ( + "delimiter must be given as , or None for auto detection" + ) + assert isinstance(header, bool) or header is None, ( + "header must be given as bool or None for auto detection" + ) + assert isinstance(quotechar, str), "quote char must be given as str" + assert isinstance(null_values, list), ( + "null_values must be a list of strings representing null values" + ) + assert isinstance(type_hints, dict), ( + "type_hints must be a dictionary mapping index to type hint" + ) # TODO: update with other options if delimiter: - assert len(delimiter) == 1, 'delimiter can only exist out of a single character' - assert len(quotechar) == 1, 'quotechar can only be a single character' + assert len(delimiter) == 1, ( + "delimiter can only exist out of a single character" + ) + assert len(quotechar) == 1, "quotechar can only be a single character" ds = DataSet() - ds._dataSet = self._context.csv(pattern, - columns, - header is None, - header if header is not None else False, - '' if delimiter is None else delimiter, - quotechar, - null_values, - type_hints) + ds._dataSet = self._context.csv( + pattern, + columns, + header is None, + header if header is not None else False, + "" if delimiter is None else delimiter, + quotechar, + null_values, + type_hints, + ) return ds def text(self, pattern, null_values=None): @@ -310,15 +361,17 @@ def text(self, pattern, null_values=None): if not null_values: null_values = [] - assert isinstance(pattern, str), 'file pattern must be given as str' - assert isinstance(null_values, list), 'null_values must be a list of strings representing null values' + assert isinstance(pattern, str), "file pattern must be given as str" + assert isinstance(null_values, list), ( + "null_values must be a list of strings representing null values" + ) ds = DataSet() ds._dataSet = self._context.text(pattern, null_values) return ds def orc(self, pattern, columns=None): - """ reads orc files. + """reads orc files. Args: pattern (str): a file glob pattern, e.g. /data/file.csv or /data/\*.csv or /\*/\*csv columns (list): optional list of columns, will be used as header for the CSV file. 
@@ -326,15 +379,17 @@ def orc(self, pattern, columns=None): tuplex.dataset.DataSet: A Tuplex Dataset object that allows further ETL operations """ - assert isinstance(pattern, str), 'file pattern must be given as str' - assert isinstance(columns, list) or columns is None, 'columns must be a list or None' + assert isinstance(pattern, str), "file pattern must be given as str" + assert isinstance(columns, list) or columns is None, ( + "columns must be a list or None" + ) ds = DataSet() ds._dataSet = self._context.orc(pattern, columns) return ds def options(self, nested=False): - """ retrieves all framework parameters as dictionary + """retrieves all framework parameters as dictionary Args: nested (bool): When set to true, this will return a nested dictionary. @@ -346,15 +401,15 @@ def options(self, nested=False): opt = self._context.options() # small hack because boost python has problems with nested dicts - opt['tuplex.csv.separators'] = eval(opt['tuplex.csv.separators']) - opt['tuplex.csv.comments'] = eval(opt['tuplex.csv.comments']) + opt["tuplex.csv.separators"] = eval(opt["tuplex.csv.separators"]) + opt["tuplex.csv.comments"] = eval(opt["tuplex.csv.comments"]) if nested: return unflatten_dict(opt) else: return opt - def optionsToYAML(self, file_path='config.yaml'): + def optionsToYAML(self, file_path="config.yaml"): """saves options as yaml file to (local) filepath Args: @@ -413,12 +468,12 @@ def uiWebURL(self): None if webUI was disabled, else URL as string """ options = self.options() - if not options['tuplex.webui.enable']: + if not options["tuplex.webui.enable"]: return None - hostname = options['tuplex.webui.url'] - port = options['tuplex.webui.port'] - url = '{}:{}'.format(hostname, port) - if not url.startswith('http://') or url.startswith('https://'): - url = 'http://' + url + hostname = options["tuplex.webui.url"] + port = options["tuplex.webui.port"] + url = "{}:{}".format(hostname, port) + if not url.startswith("http://") or url.startswith("https://"): + url = "http://" + url return url diff --git a/tuplex/python/tuplex/dataset.py b/tuplex/python/tuplex/dataset.py index 2d7eecc00..7feeeb33f 100644 --- a/tuplex/python/tuplex/dataset.py +++ b/tuplex/python/tuplex/dataset.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,10 +7,9 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# import cloudpickle -import sys import logging try: @@ -20,23 +19,24 @@ from tuplex.utils.reflection import get_source as get_udf_source from tuplex.utils.reflection import get_globals from tuplex.utils.framework import UDFCodeExtractionError -from tuplex.utils.source_vault import SourceVault from .exceptions import classToExceptionCode # signed 64bit limit max_rows = 9223372036854775807 -class DataSet: +class DataSet: def __init__(self): self._dataSet = None def unique(self): - """ removes duplicates from Dataset (out-of-order). Equivalent to a DISTINCT clause in a SQL-statement. 
+ """removes duplicates from Dataset (out-of-order). Equivalent to a DISTINCT clause in a SQL-statement. Returns: tuplex.dataset.Dataset: A Tuplex Dataset object that allows further ETL operations. """ - assert self._dataSet is not None, 'internal API error, datasets must be created via context object' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context object" + ) ds = DataSet() ds._dataSet = self._dataSet.unique() @@ -57,16 +57,18 @@ def map(self, ftor): tuplex.dataset.DataSet: A Tuplex Dataset object that allows further ETL operations """ - assert self._dataSet is not None, 'internal API error, datasets must be created via context object' - assert ftor is not None, 'need to provide valid functor' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context object" + ) + assert ftor is not None, "need to provide valid functor" - code = '' + code = "" # try to get code from vault (only lambdas supported yet!) try: # convert code object to str representation code = get_udf_source(ftor) except UDFCodeExtractionError as e: - logging.warn('Could not extract code for {}. Details:\n{}'.format(ftor, e)) + logging.warn("Could not extract code for {}. Details:\n{}".format(ftor, e)) g = get_globals(ftor) @@ -86,16 +88,18 @@ def filter(self, ftor): tuplex.dataset.DataSet: A Tuplex Dataset object that allows further ETL operations """ - assert self._dataSet is not None, 'internal API error, datasets must be created via context object' - assert ftor is not None, 'need to provide valid functor' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context object" + ) + assert ftor is not None, "need to provide valid functor" - code = '' + code = "" # try to get code from vault (only lambdas supported yet!) try: # convert code object to str representation code = get_udf_source(ftor) except UDFCodeExtractionError as e: - logging.warn('Could not extract code for {}. Details:\n{}'.format(ftor, e)) + logging.warn("Could not extract code for {}. Details:\n{}".format(ftor, e)) g = get_globals(ftor) ds = DataSet() @@ -103,17 +107,19 @@ def filter(self, ftor): return ds def collect(self): - """ action that generates a physical plan, processes data and collects result then as list of tuples. + """action that generates a physical plan, processes data and collects result then as list of tuples. Returns: (list): A list of tuples, or values if the dataset has only one column. """ - assert self._dataSet is not None, 'internal API error, datasets must be created via context objects' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context objects" + ) return self._dataSet.collect() def take(self, nrows=5): - """ action that generates a physical plan, processes data and collects the top results then as list of tuples. + """action that generates a physical plan, processes data and collects the top results then as list of tuples. Args: nrows (int): number of rows to collect. Per default ``5``. 
@@ -122,22 +128,26 @@ def take(self, nrows=5): """ - assert isinstance(nrows, int), 'num rows must be an integer' - assert nrows > 0, 'please specify a number greater than zero' + assert isinstance(nrows, int), "num rows must be an integer" + assert nrows > 0, "please specify a number greater than zero" - assert self._dataSet is not None, 'internal API error, datasets must be created via context objects' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context objects" + ) return self._dataSet.take(nrows) def show(self, nrows=None): - """ action that generates a physical plan, processes data and prints results as nicely formatted + """action that generates a physical plan, processes data and prints results as nicely formatted ASCII table to stdout. Args: nrows (int): number of rows to collect. If ``None`` all rows will be collected """ - assert self._dataSet is not None, 'internal API error, datasets must be created via context objects' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context objects" + ) # if optional value is None or below zero, simply return all rows. Else only up to nrows! if nrows is None or nrows < 0: @@ -146,7 +156,7 @@ def show(self, nrows=None): self._dataSet.show(nrows) def resolve(self, eclass, ftor): - """ Adds a resolver operator to the pipeline. The signature of ftor needs to be identical to the one of the preceding operator. + """Adds a resolver operator to the pipeline. The signature of ftor needs to be identical to the one of the preceding operator. Args: eclass: Which exception to apply resolution for, e.g. ZeroDivisionError @@ -158,22 +168,26 @@ def resolve(self, eclass, ftor): """ # check that predicate is a class for an exception class - assert issubclass(eclass, Exception), 'predicate must be a subclass of Exception' + assert issubclass(eclass, Exception), ( + "predicate must be a subclass of Exception" + ) # translate to C++ exception code enum ec = classToExceptionCode(eclass) - assert self._dataSet is not None, 'internal API error, datasets must be created via context objects' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context objects" + ) - assert ftor is not None, 'need to provide valid functor' + assert ftor is not None, "need to provide valid functor" - code = '' + code = "" # try to get code from vault (only lambdas supported yet!) try: # convert code object to str representation code = get_udf_source(ftor) except UDFCodeExtractionError as e: - logging.warn('Could not extract code for {}. Details:\n{}'.format(ftor, e)) + logging.warn("Could not extract code for {}. Details:\n{}".format(ftor, e)) g = get_globals(ftor) ds = DataSet() @@ -181,7 +195,7 @@ def resolve(self, eclass, ftor): return ds def withColumn(self, column, ftor): - """ appends a new column to the dataset by calling ftor over existing tuples + """appends a new column to the dataset by calling ftor over existing tuples Args: column: name for the new column/variable. 
If column exists, its values will be replaced @@ -192,24 +206,26 @@ def withColumn(self, column, ftor): """ - assert self._dataSet is not None, 'internal API error, datasets must be created via context object' - assert ftor is not None, 'need to provide valid functor' - assert isinstance(column, str), 'column needs to be a string' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context object" + ) + assert ftor is not None, "need to provide valid functor" + assert isinstance(column, str), "column needs to be a string" - code = '' + code = "" # try to get code from vault (only lambdas supported yet!) try: # convert code object to str representation code = get_udf_source(ftor) except UDFCodeExtractionError as e: - logging.warn('Could not extract code for {}. Details:\n{}'.format(ftor, e)) + logging.warn("Could not extract code for {}. Details:\n{}".format(ftor, e)) g = get_globals(ftor) ds = DataSet() ds._dataSet = self._dataSet.withColumn(column, code, cloudpickle.dumps(ftor), g) return ds def mapColumn(self, column, ftor): - """ maps directly one column. UDF takes as argument directly the value of the specified column and will overwrite + """maps directly one column. UDF takes as argument directly the value of the specified column and will overwrite that column with the result. If you need access to multiple columns, use withColumn instead. If the column name already exists, it will be overwritten. @@ -221,24 +237,26 @@ def mapColumn(self, column, ftor): tuplex.dataset.DataSet: A Tuplex Dataset object that allows further ETL operations """ - assert self._dataSet is not None, 'internal API error, datasets must be created via context object' - assert ftor is not None, 'need to provide valid functor' - assert isinstance(column, str), 'column needs to be a string' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context object" + ) + assert ftor is not None, "need to provide valid functor" + assert isinstance(column, str), "column needs to be a string" - code = '' + code = "" # try to get code from vault (only lambdas supported yet!) try: # convert code object to str representation code = get_udf_source(ftor) except UDFCodeExtractionError as e: - logging.warn('Could not extract code for {}. Details:\n{}'.format(ftor, e)) + logging.warn("Could not extract code for {}. Details:\n{}".format(ftor, e)) g = get_globals(ftor) ds = DataSet() ds._dataSet = self._dataSet.mapColumn(column, code, cloudpickle.dumps(ftor), g) return ds def selectColumns(self, columns): - """ selects a subset of columns as defined through columns which is a list or a single column + """selects a subset of columns as defined through columns which is a list or a single column Args: columns: list of strings or integers. A string should reference a column name, whereas as an integer refers to an index. Indices may be negative according to python rules. Order in list determines output order @@ -248,24 +266,28 @@ def selectColumns(self, columns): """ - assert self._dataSet is not None, 'internal API error, datasets must be created via context object' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context object" + ) # syntatic sugar, allow single column, list, tuple, ... 
if isinstance(columns, (str, int)): columns = [columns] if isinstance(columns, tuple): columns = list(columns) - assert(isinstance(columns, list)) + assert isinstance(columns, list) for el in columns: - assert isinstance(el, (str, int)), 'element {} must be a string or int'.format(el) + assert isinstance(el, (str, int)), ( + "element {} must be a string or int".format(el) + ) ds = DataSet() ds._dataSet = self._dataSet.selectColumns(columns) return ds def renameColumn(self, key, newColumnName): - """ rename a column in dataset + """rename a column in dataset Args: key: str|int, old column name or (0-indexed) position. newColumnName: str, new column name @@ -274,10 +296,12 @@ def renameColumn(self, key, newColumnName): Dataset """ - assert self._dataSet is not None, 'internal API error, datasets must be created via context object' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context object" + ) - assert isinstance(key, (str, int)), 'key must be a string or integer' - assert isinstance(newColumnName, str), 'newColumnName must be a string' + assert isinstance(key, (str, int)), "key must be a string or integer" + assert isinstance(newColumnName, str), "newColumnName must be a string" ds = DataSet() if isinstance(key, str): @@ -285,11 +309,11 @@ def renameColumn(self, key, newColumnName): elif isinstance(key, int): ds._dataSet = self._dataSet.renameColumnByPosition(key, newColumnName) else: - raise TypeError('key must be int or str') + raise TypeError("key must be int or str") return ds def ignore(self, eclass): - """ ignores exceptions of type eclass caused by previous operator + """ignores exceptions of type eclass caused by previous operator Args: eclass: exception type/class to ignore @@ -300,19 +324,23 @@ def ignore(self, eclass): """ # check that predicate is a class for an exception class - assert issubclass(eclass, Exception), 'predicate must be a subclass of Exception' + assert issubclass(eclass, Exception), ( + "predicate must be a subclass of Exception" + ) # translate to C++ exception code enum ec = classToExceptionCode(eclass) - assert self._dataSet is not None, 'internal API error, datasets must be created via context objects' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context objects" + ) ds = DataSet() ds._dataSet = self._dataSet.ignore(ec) return ds def cache(self, store_specialized=True): - """ materializes rows in main-memory for reuse with several pipelines. Can be also used to benchmark certain pipeline costs + """materializes rows in main-memory for reuse with several pipelines. Can be also used to benchmark certain pipeline costs Args: store_specialized: bool whether to store normal case and general case separated or merge everything into one normal case. This affects optimizations for operators called on a cached dataset. 
@@ -321,7 +349,9 @@ def cache(self, store_specialized=True): tuplex.dataset.DataSet: A Tuplex Dataset object that allows further ETL operations """ - assert self._dataSet is not None, 'internal API error, datasets must be created via context object' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context object" + ) ds = DataSet() ds._dataSet = self._dataSet.cache(store_specialized) @@ -329,7 +359,7 @@ def cache(self, store_specialized=True): @property def columns(self): - """ retrieve names of columns if assigned + """retrieve names of columns if assigned Returns: None or List[str]: Returns None if columns haven't been named yet or a list of strings representing the column names. @@ -339,7 +369,7 @@ def columns(self): @property def types(self): - """ output schema as list of type objects of the dataset. If the dataset has an error, None is returned. + """output schema as list of type objects of the dataset. If the dataset has an error, None is returned. Returns: detected types (general case) of dataset. Typed according to typing module. @@ -347,7 +377,9 @@ def types(self): types = self._dataSet.types() return types - def join(self, dsRight, leftKeyColumn, rightKeyColumn, prefixes=None, suffixes=None): + def join( + self, dsRight, leftKeyColumn, rightKeyColumn, prefixes=None, suffixes=None + ): """ (inner) join with other dataset Args: @@ -361,33 +393,46 @@ def join(self, dsRight, leftKeyColumn, rightKeyColumn, prefixes=None, suffixes=N """ - assert self._dataSet is not None, 'internal API error, datasets must be created via context objects' - assert dsRight._dataSet is not None, 'internal API error, datasets must be created via context objects' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context objects" + ) + assert dsRight._dataSet is not None, ( + "internal API error, datasets must be created via context objects" + ) # process prefixes/suffixes - leftPrefix = '' - leftSuffix = '' - rightPrefix = '' - rightSuffix = '' + leftPrefix = "" + leftSuffix = "" + rightPrefix = "" + rightSuffix = "" if prefixes: prefixes = tuple(prefixes) - assert len(prefixes) == 2, 'prefixes must be a sequence of 2 elements!' - leftPrefix = prefixes[0] if prefixes[0] else '' - rightPrefix = prefixes[1] if prefixes[1] else '' + assert len(prefixes) == 2, "prefixes must be a sequence of 2 elements!" + leftPrefix = prefixes[0] if prefixes[0] else "" + rightPrefix = prefixes[1] if prefixes[1] else "" if suffixes: suffixes = tuple(suffixes) - assert len(suffixes) == 2, 'prefixes must be a sequence of 2 elements!' - leftSuffix = suffixes[0] if suffixes[0] else '' - rightSuffix = suffixes[1] if suffixes[1] else '' + assert len(suffixes) == 2, "prefixes must be a sequence of 2 elements!" 
+ leftSuffix = suffixes[0] if suffixes[0] else "" + rightSuffix = suffixes[1] if suffixes[1] else "" ds = DataSet() - ds._dataSet = self._dataSet.join(dsRight._dataSet, leftKeyColumn, rightKeyColumn, - leftPrefix, leftSuffix, rightPrefix, rightSuffix) + ds._dataSet = self._dataSet.join( + dsRight._dataSet, + leftKeyColumn, + rightKeyColumn, + leftPrefix, + leftSuffix, + rightPrefix, + rightSuffix, + ) return ds - def leftJoin(self, dsRight, leftKeyColumn, rightKeyColumn, prefixes=None, suffixes=None): + def leftJoin( + self, dsRight, leftKeyColumn, rightKeyColumn, prefixes=None, suffixes=None + ): """ left (outer) join with other dataset Args: @@ -401,34 +446,53 @@ def leftJoin(self, dsRight, leftKeyColumn, rightKeyColumn, prefixes=None, suffix """ - assert self._dataSet is not None, 'internal API error, datasets must be created via context objects' - assert dsRight._dataSet is not None, 'internal API error, datasets must be created via context objects' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context objects" + ) + assert dsRight._dataSet is not None, ( + "internal API error, datasets must be created via context objects" + ) # process prefixes/suffixes - leftPrefix = '' - leftSuffix = '' - rightPrefix = '' - rightSuffix = '' + leftPrefix = "" + leftSuffix = "" + rightPrefix = "" + rightSuffix = "" if prefixes: prefixes = tuple(prefixes) - assert len(prefixes) == 2, 'prefixes must be a sequence of 2 elements!' - leftPrefix = prefixes[0] if prefixes[0] else '' - rightPrefix = prefixes[1] if prefixes[1] else '' + assert len(prefixes) == 2, "prefixes must be a sequence of 2 elements!" + leftPrefix = prefixes[0] if prefixes[0] else "" + rightPrefix = prefixes[1] if prefixes[1] else "" if suffixes: suffixes = tuple(suffixes) - assert len(suffixes) == 2, 'prefixes must be a sequence of 2 elements!' - leftSuffix = suffixes[0] if suffixes[0] else '' - rightSuffix = suffixes[1] if suffixes[1] else '' + assert len(suffixes) == 2, "prefixes must be a sequence of 2 elements!" + leftSuffix = suffixes[0] if suffixes[0] else "" + rightSuffix = suffixes[1] if suffixes[1] else "" ds = DataSet() - ds._dataSet = self._dataSet.leftJoin(dsRight._dataSet, leftKeyColumn, rightKeyColumn, - leftPrefix, leftSuffix, rightPrefix, rightSuffix) + ds._dataSet = self._dataSet.leftJoin( + dsRight._dataSet, + leftKeyColumn, + rightKeyColumn, + leftPrefix, + leftSuffix, + rightPrefix, + rightSuffix, + ) return ds - - def tocsv(self, path, part_size=0, num_rows=max_rows, num_parts=0, part_name_generator=None, null_value=None, header=True): + def tocsv( + self, + path, + part_size=0, + num_rows=max_rows, + num_parts=0, + part_name_generator=None, + null_value=None, + header=True, + ): """ save dataset to one or more csv files. Triggers execution of pipeline. Args: path: path where to save files to @@ -441,10 +505,14 @@ def tocsv(self, path, part_size=0, num_rows=max_rows, num_parts=0, part_name_gen null_value: string to represent null values. None equals empty string. Must provide explicit quoting for this argument. header: bool to indicate whether to write a header or not or a list of strings to specify explicitly a header to write. number of names provided must match the column count. 
""" - assert self._dataSet is not None, 'internal API error, datasets must be created via context objects' - assert isinstance(header, list) or isinstance(header, bool), 'header must be a list of strings, or a boolean' - - code, code_pickled = '', '' + assert self._dataSet is not None, ( + "internal API error, datasets must be created via context objects" + ) + assert isinstance(header, list) or isinstance(header, bool), ( + "header must be a list of strings, or a boolean" + ) + + code, code_pickled = "", "" if part_name_generator is not None: code_pickled = cloudpickle.dumps(part_name_generator) # try to get code from vault (only lambdas supported yet!) @@ -452,18 +520,29 @@ def tocsv(self, path, part_size=0, num_rows=max_rows, num_parts=0, part_name_gen # convert code object to str representation code = get_udf_source(part_name_generator) except UDFCodeExtractionError as e: - logging.warn('Could not extract code for {}. Details:\n{}'.format(ftor, e)) + logging.warn( + "Could not extract code for {}. Details:\n{}".format(ftor, e) + ) # clamp max rows if num_rows > max_rows: - raise Exception('Tuplex supports at most {} rows'.format(max_rows)) + raise Exception("Tuplex supports at most {} rows".format(max_rows)) if null_value is None: - null_value = '' - - self._dataSet.tocsv(path, code, code_pickled, num_parts, part_size, num_rows, null_value, header) - - def toorc(self, path, part_size=0, num_rows=max_rows, num_parts=0, part_name_generator=None): + null_value = "" + + self._dataSet.tocsv( + path, code, code_pickled, num_parts, part_size, num_rows, null_value, header + ) + + def toorc( + self, + path, + part_size=0, + num_rows=max_rows, + num_parts=0, + part_name_generator=None, + ): """ save dataset to one or more orc files. Triggers execution of pipeline. Args: path: path where to save files to @@ -476,7 +555,7 @@ def toorc(self, path, part_size=0, num_rows=max_rows, num_parts=0, part_name_gen """ assert self._dataSet is not None - code, code_pickled = '', '' + code, code_pickled = "", "" if part_name_generator is not None: code_pickled = cloudpickle.dumps(part_name_generator) # try to get code from vault (only lambdas supported yet!) @@ -484,10 +563,12 @@ def toorc(self, path, part_size=0, num_rows=max_rows, num_parts=0, part_name_gen # convert code object to str representation code = get_udf_source(part_name_generator) except UDFCodeExtractionError as e: - logging.warn('Could not extract code for {}. Details:\n{}'.format(ftor, e)) + logging.warn( + "Could not extract code for {}. Details:\n{}".format(ftor, e) + ) if num_rows > max_rows: - raise Exception('Tuplex supports at most {} rows'.format(max_rows)) + raise Exception("Tuplex supports at most {} rows".format(max_rows)) self._dataSet.toorc(path, code, code_pickled, num_parts, part_size, num_rows) @@ -502,7 +583,7 @@ def aggregate(self, combine, aggregate, initial_value): Dataset """ - comb_code, agg_code = '', '' + comb_code, agg_code = "", "" comb_code_pickled = cloudpickle.dumps(combine) agg_code_pickled = cloudpickle.dumps(aggregate) @@ -510,20 +591,34 @@ def aggregate(self, combine, aggregate, initial_value): # convert code object to str representation comb_code = get_udf_source(combine) except UDFCodeExtractionError as e: - logging.warn('Could not extract code for combine UDF {}. Details:\n{}'.format(combine, e)) + logging.warn( + "Could not extract code for combine UDF {}. 
Details:\n{}".format( + combine, e + ) + ) try: # convert code object to str representation agg_code = get_udf_source(aggregate) except UDFCodeExtractionError as e: - logging.warn('Could not extract code for aggregate UDF {}. Details:\n{}'.format(aggregate, e)) + logging.warn( + "Could not extract code for aggregate UDF {}. Details:\n{}".format( + aggregate, e + ) + ) g_comb = get_globals(combine) g_agg = get_globals(aggregate) ds = DataSet() - ds._dataSet = self._dataSet.aggregate(comb_code, comb_code_pickled, - agg_code, agg_code_pickled, - cloudpickle.dumps(initial_value), g_comb, g_agg) + ds._dataSet = self._dataSet.aggregate( + comb_code, + comb_code_pickled, + agg_code, + agg_code_pickled, + cloudpickle.dumps(initial_value), + g_comb, + g_agg, + ) return ds def aggregateByKey(self, combine, aggregate, initial_value, key_columns): @@ -546,30 +641,44 @@ def aggregateByKey(self, combine, aggregate, initial_value, key_columns): if isinstance(key_columns, int): key_columns = [key_columns] - comb_code, comb_code_pickled = '', '' - agg_code, agg_code_pickled = '', '' + comb_code, comb_code_pickled = "", "" + agg_code, agg_code_pickled = "", "" try: # convert code object to str representation comb_code = get_udf_source(combine) comb_code_pickled = cloudpickle.dumps(combine) except UDFCodeExtractionError as e: - logging.warn('Could not extract code for combine UDF {}. Details:\n{}'.format(ftor, e)) + logging.warn( + "Could not extract code for combine UDF {}. Details:\n{}".format( + ftor, e + ) + ) try: # convert code object to str representation agg_code = get_udf_source(aggregate) agg_code_pickled = cloudpickle.dumps(aggregate) except UDFCodeExtractionError as e: - logging.warn('Could not extract code for aggregate UDF {}. Details:\n{}'.format(ftor, e)) + logging.warn( + "Could not extract code for aggregate UDF {}. 
Details:\n{}".format( + ftor, e + ) + ) g_comb = get_globals(combine) g_agg = get_globals(aggregate) ds = DataSet() - ds._dataSet = self._dataSet.aggregateByKey(comb_code, comb_code_pickled, - agg_code, agg_code_pickled, - cloudpickle.dumps(initial_value), key_columns, - g_comb, g_agg) + ds._dataSet = self._dataSet.aggregateByKey( + comb_code, + comb_code_pickled, + agg_code, + agg_code_pickled, + cloudpickle.dumps(initial_value), + key_columns, + g_comb, + g_agg, + ) return ds @property diff --git a/tuplex/python/tuplex/distributed.py b/tuplex/python/tuplex/distributed.py index 4246ef7a5..43a59ca75 100644 --- a/tuplex/python/tuplex/distributed.py +++ b/tuplex/python/tuplex/distributed.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,51 +7,49 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 11/4/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# try: import boto3 import botocore.exceptions -except Exception as e: +except Exception: # ignore here, because boto3 is optional pass - #raise Exception('To use distributed version, please install boto3') + # raise Exception('To use distributed version, please install boto3') -import logging -import tempfile import logging import os import base64 import datetime -import socket -import json import sys import threading import time # Tuplex specific imports -from tuplex.utils.common import in_jupyter_notebook, in_google_colab, is_in_interactive_mode, current_user, host_name +from tuplex.utils.common import current_user, host_name def current_iam_user(): - iam = boto3.resource('iam') + iam = boto3.resource("iam") user = iam.CurrentUser() return user.user_name.lower() def default_lambda_name(): - return 'tuplex-lambda-runner' + return "tuplex-lambda-runner" def default_lambda_role(): - return 'tuplex-lambda-role' + return "tuplex-lambda-role" def default_bucket_name(): - return 'tuplex-' + current_iam_user() + return "tuplex-" + current_iam_user() + def default_scratch_dir(): - return default_bucket_name() + '/scratch' + return default_bucket_name() + "/scratch" + def current_region(): session = boto3.session.Session() @@ -59,41 +57,52 @@ def current_region(): if region is None: # could do fancier auto-detect here... - return 'us-east-1' + return "us-east-1" return region + def check_credentials(aws_access_key_id=None, aws_secret_access_key=None): kwargs = {} if isinstance(aws_access_key_id, str): - kwargs['aws_access_key_id'] = aws_access_key_id + kwargs["aws_access_key_id"] = aws_access_key_id if isinstance(aws_secret_access_key, str): - kwargs['aws_secret_access_key'] = aws_secret_access_key - client = boto3.client('s3', **kwargs) + kwargs["aws_secret_access_key"] = aws_secret_access_key + client = boto3.client("s3", **kwargs) try: client.list_buckets() except botocore.exceptions.NoCredentialsError as e: - logging.error('Could not connect to AWS, Details: {}. 
To configure AWS credentials please confer the guide under https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#configuring-credentials'.format(e)) + logging.error( + "Could not connect to AWS, Details: {}. To configure AWS credentials please confer the guide under https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#configuring-credentials".format( + e + ) + ) return False return True + def ensure_s3_bucket(s3_client, bucket_name, region): - bucket_names = list(map(lambda b: b['Name'], s3_client.list_buckets()['Buckets'])) + bucket_names = list(map(lambda b: b["Name"], s3_client.list_buckets()["Buckets"])) if bucket_name not in bucket_names: - logging.info('Bucket {} not found, creating (private bucket) in {} ...'.format(bucket_name, region)) + logging.info( + "Bucket {} not found, creating (private bucket) in {} ...".format( + bucket_name, region + ) + ) # bug in boto3: if region == current_region(): s3_client.create_bucket(Bucket=bucket_name) - logging.info('Bucket {} created in {}'.format(bucket_name, region)) + logging.info("Bucket {} created in {}".format(bucket_name, region)) else: - location = {'LocationConstraint': region.strip()} - s3_client.create_bucket(Bucket=bucket_name, - CreateBucketConfiguration=location) - logging.info('Bucket {} created in {}'.format(bucket_name, region)) + location = {"LocationConstraint": region.strip()} + s3_client.create_bucket( + Bucket=bucket_name, CreateBucketConfiguration=location + ) + logging.info("Bucket {} created in {}".format(bucket_name, region)) else: - logging.info('Found bucket {}'.format(bucket_name)) + logging.info("Found bucket {}".format(bucket_name)) def create_lambda_role(iam_client, lambda_role): @@ -102,39 +111,59 @@ def create_lambda_role(iam_client, lambda_role): lambda_access_to_s3 = '{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:*MultipartUpload*","s3:Get*","s3:ListBucket","s3:Put*"],"Resource":"*"}]}' lambda_invoke_others = '{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["lambda:InvokeFunction","lambda:InvokeAsync"],"Resource":"*"}]}' - iam_client.create_role(RoleName=lambda_role, - AssumeRolePolicyDocument=trust_policy, - Description='Auto-created Role for Tuplex AWS Lambda runner') - iam_client.attach_role_policy(RoleName=lambda_role, - PolicyArn='arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole') - iam_client.put_role_policy(RoleName=lambda_role, PolicyName='InvokeOtherlambdas', - PolicyDocument=lambda_invoke_others) - iam_client.put_role_policy(RoleName=lambda_role, PolicyName='LambdaAccessForS3', PolicyDocument=lambda_access_to_s3) - logging.info('Created Tuplex AWS Lambda runner role ({})'.format(lambda_role)) + iam_client.create_role( + RoleName=lambda_role, + AssumeRolePolicyDocument=trust_policy, + Description="Auto-created Role for Tuplex AWS Lambda runner", + ) + iam_client.attach_role_policy( + RoleName=lambda_role, + PolicyArn="arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole", + ) + iam_client.put_role_policy( + RoleName=lambda_role, + PolicyName="InvokeOtherlambdas", + PolicyDocument=lambda_invoke_others, + ) + iam_client.put_role_policy( + RoleName=lambda_role, + PolicyName="LambdaAccessForS3", + PolicyDocument=lambda_access_to_s3, + ) + logging.info("Created Tuplex AWS Lambda runner role ({})".format(lambda_role)) # check it exists try: response = iam_client.get_role(RoleName=lambda_role) except: - raise Exception('Failed to create AWS Lambda Role') + raise Exception("Failed to create 
AWS Lambda Role") def remove_lambda_role(iam_client, lambda_role): # detach policies... try: - iam_client.detach_role_policy(RoleName=lambda_role, - PolicyArn='arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole') + iam_client.detach_role_policy( + RoleName=lambda_role, + PolicyArn="arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole", + ) except Exception as e: logging.error( - 'Error while detaching policy AWSLambdaBasicExecutionRole, Tuplex setup corrupted? Details: {}'.format(e)) + "Error while detaching policy AWSLambdaBasicExecutionRole, Tuplex setup corrupted? Details: {}".format( + e + ) + ) - policy_names = iam_client.list_role_policies(RoleName=lambda_role)['PolicyNames'] + policy_names = iam_client.list_role_policies(RoleName=lambda_role)["PolicyNames"] for name in policy_names: try: iam_client.delete_role_policy(RoleName=lambda_role, PolicyName=name) except Exception as e: - logging.error('Error while detaching policy {}, Tuplex setup corrupted? Details: {}'.format(name, e)) + logging.error( + "Error while detaching policy {}, Tuplex setup corrupted? Details: {}".format( + name, e + ) + ) # delete role... iam_client.delete_role(RoleName=lambda_role) @@ -143,16 +172,18 @@ def remove_lambda_role(iam_client, lambda_role): def setup_lambda_role(iam_client, lambda_role, region, overwrite): try: response = iam_client.get_role(RoleName=lambda_role) - logging.info('Found Lambda role from {}'.format(response['Role']['CreateDate'])) + logging.info("Found Lambda role from {}".format(response["Role"]["CreateDate"])) # throw dummy exception to force overwrite if overwrite: remove_lambda_role(iam_client, lambda_role) - logging.info('Overwriting existing role {}'.format(lambda_role)) + logging.info("Overwriting existing role {}".format(lambda_role)) create_lambda_role(iam_client, lambda_role) - except iam_client.exceptions.NoSuchEntityException as e: - logging.info('Role {} was not found in {}, creating ...'.format(lambda_role, region)) + except iam_client.exceptions.NoSuchEntityException: + logging.info( + "Role {} was not found in {}, creating ...".format(lambda_role, region) + ) create_lambda_role(iam_client, lambda_role) @@ -166,7 +197,6 @@ def sizeof_fmt(num, suffix="B"): class ProgressPercentage(object): - def __init__(self, filename): self._filename = filename self._size = float(os.path.getsize(filename)) @@ -179,23 +209,37 @@ def __call__(self, bytes_amount): self._seen_so_far += bytes_amount percentage = (self._seen_so_far / self._size) * 100 sys.stdout.write( - "\r%s %s / %s (%.2f%%)" % ( - self._filename, sizeof_fmt(self._seen_so_far), sizeof_fmt(self._size), - percentage)) + "\r%s %s / %s (%.2f%%)" + % ( + self._filename, + sizeof_fmt(self._seen_so_far), + sizeof_fmt(self._size), + percentage, + ) + ) sys.stdout.flush() def s3_split_uri(uri): - assert '/' in uri, 'at least one / is required!' - uri = uri.replace('s3://', '') + assert "/" in uri, "at least one / is required!" 
+ uri = uri.replace("s3://", "") - bucket = uri[:uri.find('/')] - key = uri[uri.find('/') + 1:] + bucket = uri[: uri.find("/")] + key = uri[uri.find("/") + 1 :] return bucket, key -def upload_lambda(iam_client, lambda_client, lambda_function_name, lambda_role, - lambda_zip_file, overwrite=False, s3_client=None, s3_scratch_space=None, quiet=False): +def upload_lambda( + iam_client, + lambda_client, + lambda_function_name, + lambda_role, + lambda_zip_file, + overwrite=False, + s3_client=None, + s3_scratch_space=None, + quiet=False, +): # AWS only allows 50MB to be uploaded directly via request. Else, requires S3 upload. ZIP_UPLOAD_LIMIT_SIZE = 50000000 @@ -204,124 +248,150 @@ def upload_lambda(iam_client, lambda_client, lambda_function_name, lambda_role, # for runtime, choose https://docs.aws.amazon.com/lambda/latest/dg/lambda-runtimes.html RUNTIME = "provided.al2" HANDLER = "tplxlam" # this is how the executable is called... - ARCHITECTURES = ['x86_64'] + ARCHITECTURES = ["x86_64"] DEFAULT_MEMORY_SIZE = 1536 DEFAULT_TIMEOUT = 30 # 30s timeout if not os.path.isfile(lambda_zip_file): - raise Exception('Could not find local lambda zip file {}'.format(lambda_zip_file)) + raise Exception( + "Could not find local lambda zip file {}".format(lambda_zip_file) + ) file_size = os.stat(lambda_zip_file).st_size # if file size is smaller than limit, check how large the base64 encoded version is... CODE = None if file_size < ZIP_UPLOAD_LIMIT_SIZE: - logging.info('Encoding Lambda as base64 ({})'.format(sizeof_fmt(file_size))) - with open(lambda_zip_file, 'rb') as fp: + logging.info("Encoding Lambda as base64 ({})".format(sizeof_fmt(file_size))) + with open(lambda_zip_file, "rb") as fp: CODE = fp.read() CODE = base64.b64encode(CODE) b64_file_size = len(CODE) + 1 - logging.info('File size as base64 is {}'.format(sizeof_fmt(b64_file_size))) + logging.info("File size as base64 is {}".format(sizeof_fmt(b64_file_size))) else: b64_file_size = ZIP_UPLOAD_LIMIT_SIZE + 42 # to not trigger below if # get ARN of lambda role response = iam_client.get_role(RoleName=lambda_role) - lambda_role_arn = response['Role']['Arn'] + lambda_role_arn = response["Role"]["Arn"] # check if Lambda function already exists, if overwrite delete! 
-    l_response = lambda_client.list_functions(FunctionVersion='ALL')
-    functions = list(filter(lambda f: f['FunctionName'] == lambda_function_name, l_response['Functions']))
+    l_response = lambda_client.list_functions(FunctionVersion="ALL")
+    functions = list(
+        filter(
+            lambda f: f["FunctionName"] == lambda_function_name, l_response["Functions"]
+        )
+    )
 
     if len(functions) > 0:
         if len(functions) != 1:
-            logging.warning('Found multiple functions with name {}, deleting them all.'.format(lambda_function_name))
+            logging.warning(
+                "Found multiple functions with name {}, deleting them all.".format(
+                    lambda_function_name
+                )
+            )
         if not overwrite:
             raise Exception(
-                'Found existing Lambda function {}, specify overwrite=True to replace'.format(lambda_function_name))
+                "Found existing Lambda function {}, specify overwrite=True to replace".format(
+                    lambda_function_name
+                )
+            )
 
         for f in functions:
-            lambda_client.delete_function(FunctionName=f['FunctionName'])
-            logging.info('Removed existing function {} (Runtime={}, MemorySize={}) from {}'.format(f['FunctionName'],
-                                                                                                   f['Runtime'],
-                                                                                                   f['MemorySize'],
-                                                                                                   f['LastModified']))
+            lambda_client.delete_function(FunctionName=f["FunctionName"])
+            logging.info(
+                "Removed existing function {} (Runtime={}, MemorySize={}) from {}".format(
+                    f["FunctionName"], f["Runtime"], f["MemorySize"], f["LastModified"]
+                )
+            )
 
-    logging.info('Assigning role {} to runner'.format(lambda_role_arn))
+    logging.info("Assigning role {} to runner".format(lambda_role_arn))
 
     user = current_user()
     host = host_name()
-    DEPLOY_MESSAGE = "Auto-deployed Tuplex Lambda Runner function." \
-                     " Uploaded by {} from {} on {}".format(user, host, datetime.datetime.now())
+    DEPLOY_MESSAGE = (
+        "Auto-deployed Tuplex Lambda Runner function."
+        " Uploaded by {} from {} on {}".format(user, host, datetime.datetime.now())
+    )
 
     if b64_file_size < ZIP_UPLOAD_LIMIT_SIZE:
-        logging.info('Found packaged lambda ({})'.format(sizeof_fmt(file_size)))
+        logging.info("Found packaged lambda ({})".format(sizeof_fmt(file_size)))
 
-        logging.info('Loading local zipped lambda...')
+        logging.info("Loading local zipped lambda...")
 
-        logging.info('Uploading Lambda to AWS ({})'.format(sizeof_fmt(file_size)))
+        logging.info("Uploading Lambda to AWS ({})".format(sizeof_fmt(file_size)))
 
         try:
             # upload directly, we use Custom
-            response = lambda_client.create_function(FunctionName=lambda_function_name,
-                                                     Runtime=RUNTIME,
-                                                     Handler=HANDLER,
-                                                     Role=lambda_role_arn,
-                                                     Code={'ZipFile': CODE},
-                                                     Description=DEPLOY_MESSAGE,
-                                                     PackageType='Zip',
-                                                     MemorySize=DEFAULT_MEMORY_SIZE,
-                                                     Timeout=DEFAULT_TIMEOUT)
+            response = lambda_client.create_function(
+                FunctionName=lambda_function_name,
+                Runtime=RUNTIME,
+                Handler=HANDLER,
+                Role=lambda_role_arn,
+                Code={"ZipFile": CODE},
+                Description=DEPLOY_MESSAGE,
+                PackageType="Zip",
+                MemorySize=DEFAULT_MEMORY_SIZE,
+                Timeout=DEFAULT_TIMEOUT,
+            )
        except Exception as e:
-            logging.error('Failed with: {}'.format(type(e)))
-            logging.error('Details: {}'.format(str(e)[:2048]))
+            logging.error("Failed with: {}".format(type(e)))
+            logging.error("Details: {}".format(str(e)[:2048]))
             raise e
     else:
         if s3_client is None or s3_scratch_space is None:
-            raise Exception("Local packaged lambda to large to upload directly, " \
-                            "need S3. Please specify S3 client + scratch space")
-        logging.info("Lambda function is larger than current limit ({}) AWS allows, " \
-                     " deploying via S3...".format(sizeof_fmt(ZIP_UPLOAD_LIMIT_SIZE)))
+            raise Exception(
+                "Local packaged lambda too large to upload directly, "
+                "need S3. Please specify S3 client + scratch space"
+            )
+        logging.info(
+            "Lambda function is larger than the current limit ({}) AWS allows, "
+            "deploying via S3...".format(sizeof_fmt(ZIP_UPLOAD_LIMIT_SIZE))
+        )
 
         # upload to s3 temporarily
         s3_bucket, s3_key = s3_split_uri(s3_scratch_space)
 
         # scratch space, so naming doesn't matter
-        TEMP_NAME = 'lambda-deploy.zip'
-        s3_key_obj = s3_key + '/' + TEMP_NAME
-        s3_target_uri = 's3://' + s3_bucket + '/' + s3_key + '/' + TEMP_NAME
+        TEMP_NAME = "lambda-deploy.zip"
+        s3_key_obj = s3_key + "/" + TEMP_NAME
+        s3_target_uri = "s3://" + s3_bucket + "/" + s3_key + "/" + TEMP_NAME
 
         callback = ProgressPercentage(lambda_zip_file) if not quiet else None
         s3_client.upload_file(lambda_zip_file, s3_bucket, s3_key_obj, Callback=callback)
 
-        logging.info('Deploying Lambda from S3 ({})'.format(s3_target_uri))
+        logging.info("Deploying Lambda from S3 ({})".format(s3_target_uri))
         try:
             # upload directly, we use Custom
-            response = lambda_client.create_function(FunctionName=lambda_function_name,
-                                                     Runtime=RUNTIME,
-                                                     Handler=HANDLER,
-                                                     Role=lambda_role_arn,
-                                                     Code={'S3Bucket': s3_bucket, 'S3Key': s3_key_obj},
-                                                     Description=DEPLOY_MESSAGE,
-                                                     PackageType='Zip',
-                                                     MemorySize=DEFAULT_MEMORY_SIZE,
-                                                     Timeout=DEFAULT_TIMEOUT)
+            response = lambda_client.create_function(
+                FunctionName=lambda_function_name,
+                Runtime=RUNTIME,
+                Handler=HANDLER,
+                Role=lambda_role_arn,
+                Code={"S3Bucket": s3_bucket, "S3Key": s3_key_obj},
+                Description=DEPLOY_MESSAGE,
+                PackageType="Zip",
+                MemorySize=DEFAULT_MEMORY_SIZE,
+                Timeout=DEFAULT_TIMEOUT,
+            )
         except Exception as e:
-            logging.error('Failed with: {}'.format(type(e)))
-            logging.error('Details: {}'.format(str(e)[:2048]))
+            logging.error("Failed with: {}".format(type(e)))
+            logging.error("Details: {}".format(str(e)[:2048]))
 
             # delete S3 file from scratch
             s3_client.delete_object(Bucket=s3_bucket, Key=s3_key_obj)
-            logging.info('Removed {} from S3'.format(s3_target_uri))
+            logging.info("Removed {} from S3".format(s3_target_uri))
             raise e
 
         # delete S3 file from scratch
         s3_client.delete_object(Bucket=s3_bucket, Key=s3_key_obj)
-        logging.info('Removed {} from S3'.format(s3_target_uri))
+        logging.info("Removed {} from S3".format(s3_target_uri))
 
     # print out deployment details
-    logging.info('Lambda function {} deployed (MemorySize={}MB, Timeout={}).'.format(response['FunctionName'],
-                                                                                     response['MemorySize'],
-                                                                                     response['Timeout']))
+    logging.info(
+        "Lambda function {} deployed (MemorySize={}MB, Timeout={}).".format(
+            response["FunctionName"], response["MemorySize"], response["Timeout"]
        )
+    )
 
     # return lambda response
     return response
@@ -337,24 +407,26 @@ def find_lambda_package():
     this_directory = os.path.abspath(os.path.dirname(__file__))
 
     # check if folder other exists & file tplxlam.zip in it!
-    candidate_path = os.path.join(this_directory, 'other', 'tplxlam.zip')
+    candidate_path = os.path.join(this_directory, "other", "tplxlam.zip")
     if os.path.isfile(candidate_path):
-        logging.info('Found Lambda runner package in {}'.format(candidate_path))
+        logging.info("Found Lambda runner package in {}".format(candidate_path))
         return candidate_path
     return None
 
 
-def setup_aws(aws_access_key=None, aws_secret_key= None,
-              overwrite=True,
-              iam_user=None,
-              lambda_name=None,
-              lambda_role=None,
-              lambda_file=None,
-              region=None,
-              s3_scratch_uri=None,
-              quiet=False
-              ):
+def setup_aws(
+    aws_access_key=None,
+    aws_secret_key=None,
+    overwrite=True,
+    iam_user=None,
+    lambda_name=None,
+    lambda_role=None,
+    lambda_file=None,
+    region=None,
+    s3_scratch_uri=None,
+    quiet=False,
+):
     start_time = time.time()
 
     # detect defaults. Important to do this here, because don't want to always invoke boto3/botocore
@@ -372,19 +444,21 @@ def setup_aws(aws_access_key=None, aws_secret_key= None,
         s3_scratch_uri = default_scratch_dir()
 
     if lambda_file is None:
-        raise Exception('Must specify a lambda runner to upload, i.e. set ' \
-                        'parameter lambda_file=. Please check the REAMDE.md to ' \
-                        ' read about instructions on how to build the lambda runner or visit ' \
-                        'the project website to download prebuilt runners.')
+        raise Exception(
+            "Must specify a lambda runner to upload, i.e. set "
+            "parameter lambda_file=. Please check the README.md to "
+            "read about instructions on how to build the lambda runner or visit "
+            "the project website to download prebuilt runners."
+        )
 
-    assert lambda_file is not None, 'must specify file to upload'
+    assert lambda_file is not None, "must specify file to upload"
 
     # check credentials are existing on machine --> raises exception in case
-    logging.info('Validating AWS credentials')
+    logging.info("Validating AWS credentials")
     check_credentials(aws_access_key, aws_access_key)
 
-    logging.info('Setting up AWS Lambda backend for IAM user {}'.format(iam_user))
-    logging.info('Configuring backend in zone: {}'.format(region))
+    logging.info("Setting up AWS Lambda backend for IAM user {}".format(iam_user))
+    logging.info("Configuring backend in zone: {}".format(region))
 
     # check if iam user is found?
# --> skip for now, later properly authenticate using assume_role as described in @@ -392,13 +466,15 @@ def setup_aws(aws_access_key=None, aws_secret_key= None, # create all required client objects for setup # key credentials for clients - client_kwargs = {'aws_access_key_id': aws_access_key, - 'aws_secret_access_key': aws_secret_key, - 'region_name': region} + client_kwargs = { + "aws_access_key_id": aws_access_key, + "aws_secret_access_key": aws_secret_key, + "region_name": region, + } - iam_client = boto3.client('iam', **client_kwargs) - s3_client = boto3.client('s3', **client_kwargs) - lambda_client = boto3.client('lambda', **client_kwargs) + iam_client = boto3.client("iam", **client_kwargs) + s3_client = boto3.client("s3", **client_kwargs) + lambda_client = boto3.client("lambda", **client_kwargs) # Step 1: ensure S3 scratch space exists s3_bucket, s3_key = s3_split_uri(s3_scratch_uri) @@ -408,8 +484,18 @@ def setup_aws(aws_access_key=None, aws_secret_key= None, setup_lambda_role(iam_client, lambda_role, region, overwrite) # Step 3: upload/create Lambda - upload_lambda(iam_client, lambda_client, lambda_name, lambda_role, lambda_file, overwrite, s3_client, s3_scratch_uri, quiet) + upload_lambda( + iam_client, + lambda_client, + lambda_name, + lambda_role, + lambda_file, + overwrite, + s3_client, + s3_scratch_uri, + quiet, + ) # done, print if quiet was not set to False if not quiet: - print('\nCompleted lambda setup in {:.2f}s'.format(time.time() - start_time)) + print("\nCompleted lambda setup in {:.2f}s".format(time.time() - start_time)) diff --git a/tuplex/python/tuplex/exceptions.py b/tuplex/python/tuplex/exceptions.py index ae6abd992..0c50fa997 100644 --- a/tuplex/python/tuplex/exceptions.py +++ b/tuplex/python/tuplex/exceptions.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,7 +7,8 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# + def classToExceptionCode(cls): """ @@ -19,62 +20,63 @@ def classToExceptionCode(cls): """ - lookup = {BaseException : 100, - Exception : 101, - ArithmeticError : 102, - BufferError : 103, - LookupError : 104, - AssertionError : 105, - AttributeError : 106, - EOFError : 107, - GeneratorExit : 108, - ImportError : 109, - ModuleNotFoundError : 110, - IndexError : 111, - KeyError : 112, - KeyboardInterrupt : 113, - MemoryError : 114, - NameError : 115, - NotImplementedError : 116, - OSError : 117, - OverflowError : 118, - RecursionError : 119, - ReferenceError : 120, - RuntimeError : 121, - StopIteration : 122, - StopAsyncIteration : 123, - SyntaxError : 124, - IndentationError : 125, - TabError : 126, - SystemError : 127, - SystemExit : 128, - TypeError : 129, - UnboundLocalError : 130, - UnicodeError : 131, - UnicodeEncodeError : 132, - UnicodeDecodeError : 133, - UnicodeTranslateError : 134, - ValueError : 135, - ZeroDivisionError : 136, - EnvironmentError : 137, - IOError : 138, - BlockingIOError : 139, - ChildProcessError : 140, - ConnectionError : 141, - 
BrokenPipeError : 142, - ConnectionAbortedError : 143, - ConnectionRefusedError : 144, - FileExistsError : 145, - FileNotFoundError : 146, - InterruptedError : 147, - IsADirectoryError : 148, - NotADirectoryError : 149, - PermissionError : 150, - ProcessLookupError : 151, - TimeoutError : 152 - } + lookup = { + BaseException: 100, + Exception: 101, + ArithmeticError: 102, + BufferError: 103, + LookupError: 104, + AssertionError: 105, + AttributeError: 106, + EOFError: 107, + GeneratorExit: 108, + ImportError: 109, + ModuleNotFoundError: 110, + IndexError: 111, + KeyError: 112, + KeyboardInterrupt: 113, + MemoryError: 114, + NameError: 115, + NotImplementedError: 116, + OSError: 117, + OverflowError: 118, + RecursionError: 119, + ReferenceError: 120, + RuntimeError: 121, + StopIteration: 122, + StopAsyncIteration: 123, + SyntaxError: 124, + IndentationError: 125, + TabError: 126, + SystemError: 127, + SystemExit: 128, + TypeError: 129, + UnboundLocalError: 130, + UnicodeError: 131, + UnicodeEncodeError: 132, + UnicodeDecodeError: 133, + UnicodeTranslateError: 134, + ValueError: 135, + ZeroDivisionError: 136, + EnvironmentError: 137, + IOError: 138, + BlockingIOError: 139, + ChildProcessError: 140, + ConnectionError: 141, + BrokenPipeError: 142, + ConnectionAbortedError: 143, + ConnectionRefusedError: 144, + FileExistsError: 145, + FileNotFoundError: 146, + InterruptedError: 147, + IsADirectoryError: 148, + NotADirectoryError: 149, + PermissionError: 150, + ProcessLookupError: 151, + TimeoutError: 152, + } try: return lookup[cls] - except: - return \ No newline at end of file + except KeyError: + return None diff --git a/tuplex/python/tuplex/libexec/__init__.py b/tuplex/python/tuplex/libexec/__init__.py index f768b97bc..3ff8a069c 100644 --- a/tuplex/python/tuplex/libexec/__init__.py +++ b/tuplex/python/tuplex/libexec/__init__.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,4 +7,4 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# \ No newline at end of file +# ----------------------------------------------------------------------------------------------------------------------# diff --git a/tuplex/python/tuplex/libexec/_tuplex.py b/tuplex/python/tuplex/libexec/_tuplex.py index f768b97bc..3ff8a069c 100644 --- a/tuplex/python/tuplex/libexec/_tuplex.py +++ b/tuplex/python/tuplex/libexec/_tuplex.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,4 +7,4 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# \ No newline at end of file +# ----------------------------------------------------------------------------------------------------------------------# 
diff --git a/tuplex/python/tuplex/metrics.py b/tuplex/python/tuplex/metrics.py index 19903032f..481bc832f 100644 --- a/tuplex/python/tuplex/metrics.py +++ b/tuplex/python/tuplex/metrics.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,9 +7,10 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# import logging import typing + try: from .libexec.tuplex import _Context from .libexec.tuplex import _Metrics @@ -19,6 +20,7 @@ import json + class Metrics: """ Stores a reference to the metrics associated with a diff --git a/tuplex/python/tuplex/repl/__init__.py b/tuplex/python/tuplex/repl/__init__.py index 058b111ca..e355fb323 100644 --- a/tuplex/python/tuplex/repl/__init__.py +++ b/tuplex/python/tuplex/repl/__init__.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,27 +7,33 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# import os import sys -from tuplex.utils.common import is_in_interactive_mode, in_jupyter_notebook, in_google_colab +from tuplex.utils.common import ( + is_in_interactive_mode, + in_jupyter_notebook, + in_google_colab, +) + try: from tuplex.utils.version import __version__ except: - __version__ = 'dev' + __version__ = "dev" + def TuplexBanner(): - banner = '''Welcome to\n + banner = """Welcome to\n _____ _ |_ _| _ _ __ | | _____ __ | || | | | '_ \| |/ _ \ \/ / | || |_| | |_) | | __/> < |_| \__,_| .__/|_|\___/_/\_\\ {} |_| - '''.format(__version__) - banner += '\nusing Python {} on {}'.format(sys.version, sys.platform) + """.format(__version__) + banner += "\nusing Python {} on {}".format(sys.version, sys.platform) return banner @@ -36,14 +42,16 @@ def TuplexBanner(): if is_in_interactive_mode() and not in_jupyter_notebook() and not in_google_colab(): from tuplex.utils.interactive_shell import TuplexShell - os.system('clear') + + os.system("clear") from tuplex.context import Context + _locals = locals() - _locals = {key: _locals[key] for key in _locals if key in ['Context']} + _locals = {key: _locals[key] for key in _locals if key in ["Context"]} shell = TuplexShell() shell.init(locals=_locals) - shell.interact(banner=TuplexBanner() + '\n Interactive Shell mode') + shell.interact(banner=TuplexBanner() + "\n Interactive Shell mode") else: - print(TuplexBanner()) \ No newline at end of file + print(TuplexBanner()) diff --git a/tuplex/python/tuplex/utils/__init__.py 
b/tuplex/python/tuplex/utils/__init__.py index f768b97bc..3ff8a069c 100644 --- a/tuplex/python/tuplex/utils/__init__.py +++ b/tuplex/python/tuplex/utils/__init__.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,4 +7,4 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# \ No newline at end of file +# ----------------------------------------------------------------------------------------------------------------------# diff --git a/tuplex/python/tuplex/utils/common.py b/tuplex/python/tuplex/utils/common.py index ea9dcf51e..39b7916ac 100644 --- a/tuplex/python/tuplex/utils/common.py +++ b/tuplex/python/tuplex/utils/common.py @@ -17,14 +17,11 @@ import signal import yaml -import sys from datetime import datetime import json import urllib.request import os -import signal -import atexit import socket import shutil import psutil @@ -34,8 +31,6 @@ import re import tempfile import time -import shlex -import pathlib try: import pwd @@ -46,8 +41,8 @@ try: from tuplex.utils.version import __version__ -except: - __version__ = 'dev' +except ImportError: + __version__ = "dev" def cmd_exists(cmd): @@ -71,11 +66,14 @@ def is_shared_lib(path): """ # use file command - assert cmd_exists('file') + assert cmd_exists("file") - res = subprocess.check_output(['file', '--mime-type', path]) + res = subprocess.check_output(["file", "--mime-type", path]) mime_type = res.split()[-1].decode() - return mime_type == 'application/x-sharedlib' or mime_type == 'application/x-application' + return ( + mime_type == "application/x-sharedlib" + or mime_type == "application/x-application" + ) def current_timestamp(): @@ -105,7 +103,7 @@ def host_name(): Returns: some hostname as string """ - if socket.gethostname().find('.') >= 0: + if socket.gethostname().find(".") >= 0: return socket.gethostname() else: return socket.gethostbyaddr(socket.gethostname())[0] @@ -122,9 +120,10 @@ def post_json(url, data): """ - params = json.dumps(data).encode('utf8') - req = urllib.request.Request(url, data=params, - headers={'content-type': 'application/json'}) + params = json.dumps(data).encode("utf8") + req = urllib.request.Request( + url, data=params, headers={"content-type": "application/json"} + ) response = urllib.request.urlopen(req) return json.loads(response.read()) @@ -139,7 +138,7 @@ def get_json(url, timeout=10): python dictionary of decoded json """ - req = urllib.request.Request(url, headers={'content-type': 'application/json'}) + req = urllib.request.Request(url, headers={"content-type": "application/json"}) response = urllib.request.urlopen(req, timeout=timeout) return json.loads(response.read()) @@ -156,9 +155,9 @@ def in_jupyter_notebook(): try: # get_ipython won't be defined in standard python interpreter shell = get_ipython().__class__.__name__ - if shell == 'ZMQInteractiveShell': + if shell == "ZMQInteractiveShell": return True # Jupyter notebook or qtconsole - elif shell == 'TerminalInteractiveShell': + elif shell == "TerminalInteractiveShell": return False # Terminal running IPython else: return False # Other type (?) 
@@ -172,23 +171,14 @@ def in_google_colab(): Returns: True if Tuplex is running in Google Colab """ - found_colab_package = False - try: - import google.colab - found_colab_package = True - except: - pass shell_name_matching = False try: - shell_name_matching = 'google.colab' in str(get_ipython()) - except: + shell_name_matching = "google.colab" in str(get_ipython()) + except NameError: pass - if found_colab_package or shell_name_matching: - return True - else: - return False + return shell_name_matching def is_in_interactive_mode(): @@ -198,10 +188,10 @@ def is_in_interactive_mode(): """ - return bool(getattr(sys, 'ps1', sys.flags.interactive)) + return bool(getattr(sys, "ps1", sys.flags.interactive)) -def flatten_dict(d, sep='.', parent_key=''): +def flatten_dict(d, sep=".", parent_key=""): """ flattens a nested dictionary into a flat dictionary by concatenating keys with the separator. Args: d (dict): The dictionary to flatten @@ -222,7 +212,7 @@ def flatten_dict(d, sep='.', parent_key=''): return dict(items) -def unflatten_dict(dictionary, sep='.'): +def unflatten_dict(dictionary, sep="."): """ unflattens a dictionary into a nested dictionary according to sep Args: @@ -265,15 +255,15 @@ def beautify_nesting(d): else: return d - assert isinstance(file_path, str), 'file_path must be instance of str' + assert isinstance(file_path, str), "file_path must be instance of str" - with open(file_path, 'w') as f: - f.write('# Tuplex configuration file\n') - f.write('# created {} UTC\n'.format(datetime.utcnow())) + with open(file_path, "w") as f: + f.write("# Tuplex configuration file\n") + f.write("# created {} UTC\n".format(datetime.utcnow())) out = yaml.dump(beautify_nesting(unflatten_dict(conf))) # pyyaml prints { } around single item dicts. Remove by hand - out = out.replace('{', '').replace('}', '') + out = out.replace("{", "").replace("}", "") f.write(out) @@ -308,17 +298,17 @@ def parse_string(item): return item # do not use bool(...) to convert! 
- if item.lower() == 'true': + if item.lower() == "true": return True - if item.lower() == 'false': + if item.lower() == "false": return False try: return int(item) - except: + except ValueError: pass try: return float(item) - except: + except ValueError: pass return item @@ -352,9 +342,9 @@ def to_nested_dict(obj): resultDict[key] = val return resultDict - assert isinstance(file_path, str), 'file_path must be instance of str' + assert isinstance(file_path, str), "file_path must be instance of str" d = dict() - with open(file_path, 'r') as f: + with open(file_path, "r") as f: confs = list(yaml.safe_load_all(f)) for conf in confs: d.update(to_nested_dict(conf)) @@ -369,7 +359,7 @@ def stringify_dict(d): Returns: dictionary with keys and vals as strs """ - assert isinstance(d, dict), 'd must be a dictionary' + assert isinstance(d, dict), "d must be a dictionary" return {str(key): str(val) for key, val in d.items()} @@ -428,11 +418,13 @@ def logging_callback(level, time_info, logger_name, msg): # fix pathname/lineno if pathname is None: - pathname = '' + pathname = "" if lineno is None: lineno = 0 - log_record = logging.LogRecord(logger_name, level, pathname, lineno, msg, None, None) + log_record = logging.LogRecord( + logger_name, level, pathname, lineno, msg, None, None + ) log_record.created = ct log_record.msecs = (ct - int(ct)) * 1000 log_record.relativeCreated = log_record.created - logging._startTime @@ -462,13 +454,13 @@ def auto_shutdown_all(): for entry in __exit_handlers__: try: name, func, args, msg = entry - logging.debug('Attempting to shutdown {}...'.format(name)) + logging.debug("Attempting to shutdown {}...".format(name)) if msg: logging.info(msg) func(args) - logging.info('Shutdown {} successfully'.format(name)) - except Exception as e: - logging.error('Failed to shutdown {}'.format(name)) + logging.info("Shutdown {} successfully".format(name)) + except Exception: + logging.error("Failed to shutdown {}".format(name)) __exit_handlers__ = [] @@ -500,7 +492,7 @@ def is_process_running(name): return False -def mongodb_uri(mongodb_url, mongodb_port, db_name='tuplex-history'): +def mongodb_uri(mongodb_url, mongodb_port, db_name="tuplex-history"): """ constructs a fully qualified MongoDB URI Args: @@ -511,10 +503,12 @@ def mongodb_uri(mongodb_url, mongodb_port, db_name='tuplex-history'): Returns: string representing MongoDB URI """ - return 'mongodb://{}:{}/{}'.format(mongodb_url, mongodb_port, db_name) + return "mongodb://{}:{}/{}".format(mongodb_url, mongodb_port, db_name) -def check_mongodb_connection(mongodb_url, mongodb_port, db_name='tuplex-history', timeout=10.0): +def check_mongodb_connection( + mongodb_url, mongodb_port, db_name="tuplex-history", timeout=10.0 +): """ connects to a MongoDB database instance, raises exception if connection fails Args: @@ -530,36 +524,46 @@ def check_mongodb_connection(mongodb_url, mongodb_port, db_name='tuplex-history' # check whether one can connect to MongoDB from pymongo import MongoClient - from pymongo.errors import ServerSelectionTimeoutError start_time = time.time() connect_successful = False - logging.debug('Attempting to contact MongoDB under {}'.format(uri)) + logging.debug("Attempting to contact MongoDB under {}".format(uri)) connect_try = 1 while abs(time.time() - start_time) < timeout: - logging.debug('MongoDB connection try {}...'.format(connect_try)) + logging.debug("MongoDB connection try {}...".format(connect_try)) try: # set client connection to super low timeouts so the wait is not too long. 
- client = MongoClient(uri, serverSelectionTimeoutMS=100, connectTimeoutMS=1000) - info = client.server_info() # force a call to mongodb, alternative is client.admin.command('ismaster') + client = MongoClient( + uri, serverSelectionTimeoutMS=100, connectTimeoutMS=1000 + ) + client.server_info() # force a call to mongodb, alternative is client.admin.command('ismaster') connect_successful = True except Exception as e: - logging.debug('Connection try {} produced {} exception {}'.format(connect_try, type(e), str(e))) + logging.debug( + "Connection try {} produced {} exception {}".format( + connect_try, type(e), str(e) + ) + ) if connect_successful: timeout = 0 break time.sleep(0.05) # sleep for 50ms - logging.debug('Contacting MongoDB under {}... -- {:.2f}s of poll time left'.format(uri, timeout - ( - time.time() - start_time))) + logging.debug( + "Contacting MongoDB under {}... -- {:.2f}s of poll time left".format( + uri, timeout - (time.time() - start_time) + ) + ) connect_try += 1 if connect_successful is False: - raise Exception('Could not connect to MongoDB, check network connection. (ping must be < 100ms)') + raise Exception( + "Could not connect to MongoDB, check network connection. (ping must be < 100ms)" + ) - logging.debug('Connection test to MongoDB succeeded') + logging.debug("Connection test to MongoDB succeeded") def shutdown_process_via_kill(pid): @@ -571,11 +575,17 @@ def shutdown_process_via_kill(pid): Returns: None """ - logging.debug('Shutting down process PID={}'.format(pid)) + logging.debug("Shutting down process PID={}".format(pid)) os.kill(pid, signal.SIGKILL) -def find_or_start_mongodb(mongodb_url, mongodb_port, mongodb_datapath, mongodb_logpath, db_name='tuplex-history'): +def find_or_start_mongodb( + mongodb_url, + mongodb_port, + mongodb_datapath, + mongodb_logpath, + db_name="tuplex-history", +): """ attempts to connect to a MongoDB database. If no running local MongoDB is found, will auto-start a mongodb database. R aises exception when fails. @@ -591,17 +601,19 @@ def find_or_start_mongodb(mongodb_url, mongodb_port, mongodb_datapath, mongodb_l """ # is it localhost? - if 'localhost' in mongodb_url: - logging.debug('Using local MongoDB instance') + if "localhost" in mongodb_url: + logging.debug("Using local MongoDB instance") # first check whether mongod is on path - if not cmd_exists('mongod'): - raise Exception('MongoDB (mongod) not found on PATH. In order to use Tuplex\'s WebUI, you need MongoDB' - ' installed or point the framework to a running MongoDB instance') + if not cmd_exists("mongod"): + raise Exception( + "MongoDB (mongod) not found on PATH. In order to use Tuplex's WebUI, you need MongoDB" + " installed or point the framework to a running MongoDB instance" + ) # is mongod running on local machine? 
- if is_process_running('mongod'): - logging.debug('Found locally running MongoDB daemon process') + if is_process_running("mongod"): + logging.debug("Found locally running MongoDB daemon process") # process is running, try to connect check_mongodb_connection(mongodb_url, mongodb_port, db_name) @@ -614,11 +626,23 @@ def find_or_start_mongodb(mongodb_url, mongodb_port, mongodb_datapath, mongodb_l # startup via mongod --fork --logpath /var/log/mongodb/mongod.log --port 1234 --dbpath try: - cmd = ['mongod', '--fork', '--logpath', str(mongodb_logpath), '--port', str(mongodb_port), '--dbpath', - str(mongodb_datapath)] - - logging.debug('starting MongoDB daemon process via {}'.format(' '.join(cmd))) - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + cmd = [ + "mongod", + "--fork", + "--logpath", + str(mongodb_logpath), + "--port", + str(mongodb_port), + "--dbpath", + str(mongodb_datapath), + ] + + logging.debug( + "starting MongoDB daemon process via {}".format(" ".join(cmd)) + ) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) short_timeout = 2.5 max_mongodb_timeout = 10 # maximum timeout is 10s @@ -628,50 +652,70 @@ def find_or_start_mongodb(mongodb_url, mongodb_port, mongodb_datapath, mongodb_l except subprocess.TimeoutExpired: # try now with more time (up to max) logging.info( - "Could not start MongoDB daemon process in {}s, trying with timeout={}s".format(short_timeout, - max_mongodb_timeout)) - p_stdout, p_stderr = process.communicate(timeout=max_mongodb_timeout) + "Could not start MongoDB daemon process in {}s, trying with timeout={}s".format( + short_timeout, max_mongodb_timeout + ) + ) + p_stdout, p_stderr = process.communicate( + timeout=max_mongodb_timeout + ) # decode p_stdout = p_stdout.decode() p_stderr = p_stderr.decode() if len(p_stderr.strip()) > 0: - raise Exception('mongod produced following errors: {}'.format(p_stderr)) + raise Exception( + "mongod produced following errors: {}".format(p_stderr) + ) # find mongod pid - m = re.search(r'forked process: (\d+)', p_stdout) - assert m is not None, 'Could not find Child process ID when starting MongoDB' + m = re.search(r"forked process: (\d+)", p_stdout) + assert m is not None, ( + "Could not find Child process ID when starting MongoDB" + ) mongo_pid = int(m[1]) - logging.debug('MongoDB Daemon PID={}'.format(mongo_pid)) + logging.debug("MongoDB Daemon PID={}".format(mongo_pid)) # add a new shutdown func for mongod - register_auto_shutdown('mongod', shutdown_process_via_kill, mongo_pid) + register_auto_shutdown("mongod", shutdown_process_via_kill, mongo_pid) except Exception as e: - logging.error('Failed to start MongoDB daemon. Details: {}'.format(str(e))) + logging.error( + "Failed to start MongoDB daemon. 
Details: {}".format(str(e)) + ) # print out first 10 and last 10 lines of mongodb log if exists n_to_print = 15 mongodb_logpath = str(mongodb_logpath) if os.path.isfile(mongodb_logpath): - with open(mongodb_logpath, 'r') as fp_mongo: - lines = list(map(lambda line: line.strip(), fp_mongo.readlines())) - shortened_log = '' + with open(mongodb_logpath, "r") as fp_mongo: + lines = list( + map(lambda line: line.strip(), fp_mongo.readlines()) + ) + shortened_log = "" if len(lines) > 2 * n_to_print: - shortened_log = '\n'.join(lines[:n_to_print]) + '...\n' + '\n'.join(lines[-n_to_print:]) + shortened_log = ( + "\n".join(lines[:n_to_print]) + + "...\n" + + "\n".join(lines[-n_to_print:]) + ) else: - shortened_log = '\n'.join(lines) - logging.error('MongoDB daemon log:\n{}'.format(shortened_log)) + shortened_log = "\n".join(lines) + logging.error("MongoDB daemon log:\n{}".format(shortened_log)) else: - logging.error('Could not find MongoDB log under {}. Permission error?'.format(mongodb_logpath)) + logging.error( + "Could not find MongoDB log under {}. Permission error?".format( + mongodb_logpath + ) + ) raise e logging.debug("Attempting to connect to freshly started MongoDB daemon...") check_mongodb_connection(mongodb_url, mongodb_port, db_name) else: # remote MongoDB - logging.debug('Connecting to remote MongoDB instance') + logging.debug("Connecting to remote MongoDB instance") check_mongodb_connection(mongodb_url, mongodb_port, db_name) @@ -686,12 +730,15 @@ def log_gunicorn_errors(logpath): """ # parse log, check whether there's any line where [ERROR] is contined - with open(logpath, 'r') as fp: + with open(logpath, "r") as fp: lines = fp.readlines() - indices = map(lambda t: t[1], filter(lambda t: '[ERROR]' in t[0], zip(lines, range(len(lines))))) + indices = map( + lambda t: t[1], + filter(lambda t: "[ERROR]" in t[0], zip(lines, range(len(lines)))), + ) if indices: first_idx = min(indices) - logging.error('Gunicorn error log:\n {}'.format(''.join(lines[first_idx:]))) + logging.error("Gunicorn error log:\n {}".format("".join(lines[first_idx:]))) def find_or_start_webui(mongo_uri, hostname, port, web_logfile): @@ -706,27 +753,34 @@ def find_or_start_webui(mongo_uri, hostname, port, web_logfile): Returns: None, raises exceptions on failure """ - version_endpoint = '/api/version' # use this to connect and trigger WebUI connection + version_endpoint = ( + "/api/version" # use this to connect and trigger WebUI connection + ) - if not hostname.startswith('http://') and not hostname.startswith('https://'): - hostname = 'http://' + str(hostname) + if not hostname.startswith("http://") and not hostname.startswith("https://"): + hostname = "http://" + str(hostname) - base_uri = '{}:{}'.format(hostname, port) + base_uri = "{}:{}".format(hostname, port) version_info = None try: version_info = get_json(base_uri + version_endpoint) - except Exception as err: - logging.debug("Couldn't connect to {}, starting WebUI...".format(base_uri + version_endpoint)) + except Exception: + logging.debug( + "Couldn't connect to {}, starting WebUI...".format( + base_uri + version_endpoint + ) + ) if version_info is not None: # check version compatibility return version_info else: # start WebUI up! - if not cmd_exists('gunicorn'): + if not cmd_exists("gunicorn"): raise Exception( - 'Tuplex uses per default gunicorn with eventlet to run the WebUI. Please install via `pip3 install "gunicorn[eventlet]"` or add to PATH') + 'Tuplex uses per default gunicorn with eventlet to run the WebUI. 
Please install via `pip3 install "gunicorn[eventlet]"` or add to PATH' + ) # command for this is: # env MONGO_URI=$MONGO_URI gunicorn --daemon --worker-class eventlet --log-file $GUNICORN_LOGFILE -b $HOST:$PORT thserver:app @@ -734,41 +788,52 @@ def find_or_start_webui(mongo_uri, hostname, port, web_logfile): # directory needs to be the one where the history server is located in! # ==> from structure of file we can infer that dir_path = os.path.dirname(os.path.realpath(__file__)) - assert dir_path.endswith(os.path.join('tuplex', 'utils')), 'folder structure changed. Need to fix.' + assert dir_path.endswith(os.path.join("tuplex", "utils")), ( + "folder structure changed. Need to fix." + ) # get tuplex base dir tuplex_basedir = pathlib.Path(dir_path).parent # two options: Could be dev install or site-packages install, therefore check two folders - if not os.path.isdir(os.path.join(tuplex_basedir, 'historyserver', 'thserver')): + if not os.path.isdir(os.path.join(tuplex_basedir, "historyserver", "thserver")): # dev install or somehow different folder structure? # --> try to find root tuplex folder containing historyserver folder! path = pathlib.Path(tuplex_basedir) while path.parent != path: # check in path - if 'tuplex' in os.listdir(path) and 'historyserver' in os.listdir(os.path.join(path, 'tuplex')): - tuplex_basedir = os.path.join(str(path), 'tuplex') - logging.debug('Detected Tuplex rootfolder (dev) to be {}'.format(tuplex_basedir)) + if "tuplex" in os.listdir(path) and "historyserver" in os.listdir( + os.path.join(path, "tuplex") + ): + tuplex_basedir = os.path.join(str(path), "tuplex") + logging.debug( + "Detected Tuplex rootfolder (dev) to be {}".format( + tuplex_basedir + ) + ) break path = path.parent # check dir historyserver/thserver exists! - assert os.path.isdir(os.path.join(tuplex_basedir, 'historyserver', - 'thserver')), 'could not find Tuplex WebUI WebApp in {}'.format( - tuplex_basedir) - assert os.path.isfile(os.path.join(tuplex_basedir, 'historyserver', 'thserver', - '__init__.py')), 'could not find Tuplex WebUI __init__.py file in thserver folder' + assert os.path.isdir( + os.path.join(tuplex_basedir, "historyserver", "thserver") + ), "could not find Tuplex WebUI WebApp in {}".format(tuplex_basedir) + assert os.path.isfile( + os.path.join(tuplex_basedir, "historyserver", "thserver", "__init__.py") + ), "could not find Tuplex WebUI __init__.py file in thserver folder" # history server dir to use to start gunicorn - ui_basedir = os.path.join(tuplex_basedir, 'historyserver') - logging.debug('Launching gunicorn from {}'.format(ui_basedir)) + ui_basedir = os.path.join(tuplex_basedir, "historyserver") + logging.debug("Launching gunicorn from {}".format(ui_basedir)) # create temp PID file to get process ID to shutdown auto-started WebUI PID_FILE = tempfile.NamedTemporaryFile(delete=False).name ui_env = os.environ - ui_env['MONGO_URI'] = mongo_uri - gunicorn_host = '{}:{}'.format(hostname.replace('http://', '').replace('https://', ''), port) + ui_env["MONGO_URI"] = mongo_uri + gunicorn_host = "{}:{}".format( + hostname.replace("http://", "").replace("https://", ""), port + ) # need to convert everything to absolute paths (b.c. 
gunicorn fails else) web_logfile = os.path.abspath(web_logfile) @@ -778,14 +843,33 @@ def find_or_start_webui(mongo_uri, hostname, port, web_logfile): wl_path = pathlib.Path(web_logfile).parent os.makedirs(str(wl_path), exist_ok=True) except Exception as e: - logging.error("ensuring parent dir of {} exists, failed with {}".format(web_logfile, e)) - - cmd = ['gunicorn', '--daemon', '--worker-class', 'eventlet', '--chdir', ui_basedir, '--pid', PID_FILE, - '--log-file', web_logfile, '-b', gunicorn_host, 'thserver:app'] - - logging.debug('Starting gunicorn with command: {}'.format(' '.join(cmd))) - - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=ui_env) + logging.error( + "ensuring parent dir of {} exists, failed with {}".format( + web_logfile, e + ) + ) + + cmd = [ + "gunicorn", + "--daemon", + "--worker-class", + "eventlet", + "--chdir", + ui_basedir, + "--pid", + PID_FILE, + "--log-file", + web_logfile, + "-b", + gunicorn_host, + "thserver:app", + ] + + logging.debug("Starting gunicorn with command: {}".format(" ".join(cmd))) + + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=ui_env + ) # set a timeout of 2 seconds to keep everything interactive p_stdout, p_stderr = process.communicate(timeout=2) @@ -794,9 +878,9 @@ def find_or_start_webui(mongo_uri, hostname, port, web_logfile): p_stderr = p_stderr.decode() if len(p_stderr.strip()) > 0: - raise Exception('gunicorn produced following errors: {}'.format(p_stderr)) + raise Exception("gunicorn produced following errors: {}".format(p_stderr)) - logging.info('Gunicorn locally started...') + logging.info("Gunicorn locally started...") # find out process id of gunicorn ui_pid = None @@ -809,23 +893,40 @@ def find_or_start_webui(mongo_uri, hostname, port, web_logfile): time.sleep(0.05) # sleep for 50ms else: break - logging.debug('Polling for Gunicorn PID... -- {:.2f}s of poll time left'.format( - TIME_LIMIT - (time.time() - start_time))) + logging.debug( + "Polling for Gunicorn PID... -- {:.2f}s of poll time left".format( + TIME_LIMIT - (time.time() - start_time) + ) + ) ui_pid = None try: # Read PID file - with open(PID_FILE, 'r') as fp: + with open(PID_FILE, "r") as fp: ui_pid = int(fp.read()) except Exception as e: logging.debug("failed to retrieve PID for WebUI, details: {}".format(e)) - non_daemon_log = 'timeout - no log' + non_daemon_log = "timeout - no log" # something went wrong with starting gunicorn. 
Try to capture some meaningful output and abort try: - cmd = ['gunicorn', '--worker-class', 'eventlet', '--chdir', ui_basedir, '--pid', PID_FILE, - '--log-file', '-', '-b', gunicorn_host, 'thserver:app'] - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=ui_env) + cmd = [ + "gunicorn", + "--worker-class", + "eventlet", + "--chdir", + ui_basedir, + "--pid", + PID_FILE, + "--log-file", + "-", + "-b", + gunicorn_host, + "thserver:app", + ] + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=ui_env + ) # set a timeout of 5 seconds to keep everything interactive p_stdout, p_stderr = process.communicate(timeout=5) @@ -833,36 +934,45 @@ def find_or_start_webui(mongo_uri, hostname, port, web_logfile): p_stdout = p_stdout.decode() p_stderr = p_stderr.decode() - non_daemon_log = p_stdout + '\n' + p_stderr + non_daemon_log = p_stdout + "\n" + p_stderr except subprocess.TimeoutExpired: pass - logging.error('Gunicorn process log:\n' + non_daemon_log) - raise Exception("Failed to start gunicorn daemon, non-daemon run yielded:\n{}".format(non_daemon_log)) + logging.error("Gunicorn process log:\n" + non_daemon_log) + raise Exception( + "Failed to start gunicorn daemon, non-daemon run yielded:\n{}".format( + non_daemon_log + ) + ) - assert ui_pid is not None, 'Invalid PID for WebUI' - logging.info('Gunicorn PID={}'.format(ui_pid)) + assert ui_pid is not None, "Invalid PID for WebUI" + logging.info("Gunicorn PID={}".format(ui_pid)) # register daemon shutdown - logging.debug('Adding auto-shutdown of process with PID={} (WebUI)'.format(ui_pid)) + logging.debug( + "Adding auto-shutdown of process with PID={} (WebUI)".format(ui_pid) + ) def shutdown_gunicorn(pid): - pids_to_kill = [] # iterate over all gunicorn processes and kill them all for proc in psutil.process_iter(): try: # Get process name & pid from process object. - process_name = proc.name() - process_id = proc.pid - - sep_line = '|'.join(proc.cmdline()).lower() - if 'gunicorn' in sep_line: - + sep_line = "|".join(proc.cmdline()).lower() + if "gunicorn" in sep_line: # check whether that gunicorn instance matches what has been started - if 'thserver:app' in proc.cmdline() and gunicorn_host in proc.cmdline() and PID_FILE in proc.cmdline(): + if ( + "thserver:app" in proc.cmdline() + and gunicorn_host in proc.cmdline() + and PID_FILE in proc.cmdline() + ): pids_to_kill.append(proc.pid) - except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + except ( + psutil.NoSuchProcess, + psutil.AccessDenied, + psutil.ZombieProcess, + ): pass # kill all gunicorn processes @@ -870,14 +980,14 @@ def shutdown_gunicorn(pid): os.kill(pid, signal.SIGQUIT) os.kill(pid, signal.SIGKILL) os.kill(pid, signal.SIGTERM) - logging.debug('Shutdown gunicorn worker with PID={}'.format(pid)) - logging.debug('Shutdown gunicorn with PID={}'.format(pid)) + logging.debug("Shutdown gunicorn worker with PID={}".format(pid)) + logging.debug("Shutdown gunicorn with PID={}".format(pid)) - register_auto_shutdown('gunicorn', shutdown_gunicorn, ui_pid) + register_auto_shutdown("gunicorn", shutdown_gunicorn, ui_pid) version_info = get_json(base_uri + version_endpoint) if version_info is None: - raise Exception('Could not retrieve version info from WebUI') + raise Exception("Could not retrieve version info from WebUI") # perform checks (same MongoDB URI? Same Version?) 
return version_info @@ -901,42 +1011,57 @@ def ensure_webui(options): # {"tuplex.webui.mongodb.port", "27017"}, # {"tuplex.webui.mongodb.path", temp_mongodb_path} - assert options['tuplex.webui.enable'] is True, 'only call ensure webui when webui option is true' - - mongodb_url = options['tuplex.webui.mongodb.url'] - mongodb_port = options['tuplex.webui.mongodb.port'] - mongodb_datapath = os.path.join(options['tuplex.scratchDir'], 'webui', 'data') - mongodb_logpath = os.path.join(options['tuplex.scratchDir'], 'webui', 'logs', 'mongod.log') - gunicorn_logpath = os.path.join(options['tuplex.scratchDir'], 'webui', 'logs', 'gunicorn.log') - webui_url = options['tuplex.webui.url'] - webui_port = options['tuplex.webui.port'] + assert options["tuplex.webui.enable"] is True, ( + "only call ensure webui when webui option is true" + ) + + mongodb_url = options["tuplex.webui.mongodb.url"] + mongodb_port = options["tuplex.webui.mongodb.port"] + mongodb_datapath = os.path.join(options["tuplex.scratchDir"], "webui", "data") + mongodb_logpath = os.path.join( + options["tuplex.scratchDir"], "webui", "logs", "mongod.log" + ) + gunicorn_logpath = os.path.join( + options["tuplex.scratchDir"], "webui", "logs", "gunicorn.log" + ) + webui_url = options["tuplex.webui.url"] + webui_port = options["tuplex.webui.port"] try: - logging.debug('finding MongoDB...') - find_or_start_mongodb(mongodb_url, mongodb_port, mongodb_datapath, mongodb_logpath) + logging.debug("finding MongoDB...") + find_or_start_mongodb( + mongodb_url, mongodb_port, mongodb_datapath, mongodb_logpath + ) mongo_uri = mongodb_uri(mongodb_url, mongodb_port) - logging.debug('finding WebUI..') + logging.debug("finding WebUI..") # now it's time to do the same thing for the WebUI (and also check it's version v.s. the current one!) - version_info = find_or_start_webui(mongo_uri, webui_url, webui_port, gunicorn_logpath) + version_info = find_or_start_webui( + mongo_uri, webui_url, webui_port, gunicorn_logpath + ) - logging.debug('WebUI services found or started!') + logging.debug("WebUI services found or started!") # check that version of WebUI and Tuplex version match # exclude dev versions, i.e. silence warning there. - if 'dev' not in __version__ and version_info['version'] != __version__: - logging.warning('Version of Tuplex WebUI ({}) and Tuplex ({}) do not match.'.format(version_info['version'], - __version__)) + if "dev" not in __version__ and version_info["version"] != __version__: + logging.warning( + "Version of Tuplex WebUI ({}) and Tuplex ({}) do not match.".format( + version_info["version"], __version__ + ) + ) # all good, print out link so user can access WebUI easily - webui_uri = webui_url + ':' + str(webui_port) - if not webui_uri.startswith('http'): - webui_uri = 'http://' + webui_uri - print('Tuplex WebUI can be accessed under {}'.format(webui_uri)) + webui_uri = webui_url + ":" + str(webui_port) + if not webui_uri.startswith("http"): + webui_uri = "http://" + webui_uri + print("Tuplex WebUI can be accessed under {}".format(webui_uri)) except Exception as e: - logging.error('Failed to start or connect to Tuplex WebUI. Details: {}'.format(e)) + logging.error( + "Failed to start or connect to Tuplex WebUI. 
Details: {}".format(e) + ) # log gunicorn errors for local startup - if os.path.isfile(gunicorn_logpath) and 'localhost' == webui_url: + if os.path.isfile(gunicorn_logpath) and "localhost" == webui_url: log_gunicorn_errors(gunicorn_logpath) diff --git a/tuplex/python/tuplex/utils/errors.py b/tuplex/python/tuplex/utils/errors.py index a05d2f2c6..315ca7317 100644 --- a/tuplex/python/tuplex/utils/errors.py +++ b/tuplex/python/tuplex/utils/errors.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,10 +7,12 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# + class TuplexException(Exception): """ Base class for exceptions across the fraemwork """ - pass \ No newline at end of file + + pass diff --git a/tuplex/python/tuplex/utils/framework.py b/tuplex/python/tuplex/utils/framework.py index d5d36d225..34a3ff649 100644 --- a/tuplex/python/tuplex/utils/framework.py +++ b/tuplex/python/tuplex/utils/framework.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,13 +7,16 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 8/3/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # this file contains Framework specific exceptions class TuplexException(Exception): """Base Exception class on which all Tuplex Framework specific exceptions are based""" + pass + class UDFCodeExtractionError(TuplexException): """thrown when UDF code extraction/reflection failed""" - pass \ No newline at end of file + + pass diff --git a/tuplex/python/tuplex/utils/globs.py b/tuplex/python/tuplex/utils/globs.py index 9fba0e9ed..116442e9f 100644 --- a/tuplex/python/tuplex/utils/globs.py +++ b/tuplex/python/tuplex/utils/globs.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,40 +7,36 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# import types -import inspect -import re -import ast import weakref import dis 
import opcode -import types import itertools import sys + # ALWAYS import cloudpickle before dill, b.c. of https://github.com/uqfoundation/dill/issues/383 from cloudpickle.cloudpickle import _get_cell_contents -import dill # from cloudpickle # ---------------- _extract_code_globals_cache = weakref.WeakKeyDictionary() # relevant opcodes -STORE_GLOBAL = opcode.opmap['STORE_GLOBAL'] -DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL'] -LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL'] +STORE_GLOBAL = opcode.opmap["STORE_GLOBAL"] +DELETE_GLOBAL = opcode.opmap["DELETE_GLOBAL"] +LOAD_GLOBAL = opcode.opmap["LOAD_GLOBAL"] GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL) HAVE_ARGUMENT = dis.HAVE_ARGUMENT EXTENDED_ARG = dis.EXTENDED_ARG + def _extract_code_globals(co): """ Find all globals names read or written to by codeblock co """ out_names = _extract_code_globals_cache.get(co) if out_names is None: - names = co.co_names out_names = {opargval: None for opi, opargval in _walk_global_ops(co)} # Declaring a function inside another one using the "def ..." @@ -57,6 +53,8 @@ def _extract_code_globals(co): _extract_code_globals_cache[co] = out_names return out_names + + def _find_imported_submodules(code, top_level_dependencies): """ Find currently imported submodules used by a function. @@ -85,10 +83,13 @@ def func(): subimports = [] # check if any known dependency is an imported package for x in top_level_dependencies: - if (isinstance(x, types.ModuleType) and - hasattr(x, '__package__') and x.__package__): + if ( + isinstance(x, types.ModuleType) + and hasattr(x, "__package__") + and x.__package__ + ): # check if the package has any currently loaded sub-imports - prefix = x.__name__ + '.' + prefix = x.__name__ + "." # A concurrent thread could mutate sys.modules, # make sure we iterate over a copy to avoid exceptions for name in list(sys.modules): @@ -96,7 +97,7 @@ def func(): # sys.modules. if name is not None and name.startswith(prefix): # check whether the function can address the sub-module - tokens = set(name[len(prefix):].split('.')) + tokens = set(name[len(prefix) :].split(".")) if not tokens - set(code.co_names): subimports.append(sys.modules[name]) return subimports @@ -132,12 +133,12 @@ def _function_getstate(func): } f_globals_ref = _extract_code_globals(func.__code__) - f_globals = {k: func.__globals__[k] for k in f_globals_ref if k in - func.__globals__} + f_globals = {k: func.__globals__[k] for k in f_globals_ref if k in func.__globals__} closure_values = ( list(map(_get_cell_contents, func.__closure__)) - if func.__closure__ is not None else () + if func.__closure__ is not None + else () ) # Extract currently-imported submodules used by func. Storing these modules @@ -145,29 +146,33 @@ def _function_getstate(func): # trigger the side effect of importing these modules at unpickling time # (which is necessary for func to work correctly once depickled) slotstate["_cloudpickle_submodules"] = _find_imported_submodules( - func.__code__, itertools.chain(f_globals.values(), closure_values)) + func.__code__, itertools.chain(f_globals.values(), closure_values) + ) slotstate["__globals__"] = f_globals - - # add free vars to slotstate by decoding closure - try: - slotstate['__freevars__'] = {name: closure_values[i] for i, name in enumerate(func.__code__.co_freevars)} - except: - slotstate['__freevars__'] = {} + # Add free vars to slotstate by decoding closure. 
+ slotstate["__freevars__"] = { + name: closure_values[i] for i, name in enumerate(func.__code__.co_freevars) + } state = func.__dict__ return state, slotstate + + # -------------------- # end from cloudpickle + def get_globals(func): _, d = _function_getstate(func) - func_globals = d['__globals__'] - func_freevars = d['__freevars__'] + func_globals = d["__globals__"] + func_freevars = d["__freevars__"] # unify free vars with globals if len(set(func_globals.keys()).intersection(set(func_freevars.keys()))) != 0: - raise Exception('internal error, overlap between globals and freevars, should not occur.') + raise Exception( + "internal error, overlap between globals and freevars, should not occur." + ) # add free vars to global dict to have everything in one dict. func_globals.update(func_freevars) diff --git a/tuplex/python/tuplex/utils/interactive_shell.py b/tuplex/python/tuplex/utils/interactive_shell.py index 4d432b4c4..91fd74fbd 100644 --- a/tuplex/python/tuplex/utils/interactive_shell.py +++ b/tuplex/python/tuplex/utils/interactive_shell.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,7 +7,7 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# from __future__ import unicode_literals @@ -17,6 +17,7 @@ import logging from code import InteractiveConsole from prompt_toolkit.history import InMemoryHistory + # old version: 1.0 # from prompt_toolkit.layout.lexers import PygmentsLexer # from prompt_toolkit.styles import style_from_pygments @@ -46,27 +47,31 @@ def __init__(self, context_cls): def Context(self): return self._context_cls + # Interactive shell # check https://github.com/python/cpython/blob/master/Lib/code.py for overwriting this class class TuplexShell(InteractiveConsole): - # use BORG design pattern to make class singleton alike __shared_state = {} def __init__(self): self.__dict__ = self.__shared_state - def init(self, locals=None, filename="", histfile=os.path.expanduser("~/.console_history")): - + def init( + self, + locals=None, + filename="", + histfile=os.path.expanduser("~/.console_history"), + ): # add dummy helper for context - if locals is not None and 'Context' in locals.keys(): - locals['tuplex'] = TuplexModuleHelper(locals['Context']) + if locals is not None and "Context" in locals.keys(): + locals["tuplex"] = TuplexModuleHelper(locals["Context"]) self.initialized = True self.filename = "console-0" self.lineno = 0 InteractiveConsole.__init__(self, locals, self.filename) - self._lastLine = '' + self._lastLine = "" self.historyDict = {} def push(self, line): @@ -81,7 +86,7 @@ def push(self, line): value is 1 if more input is required, 0 if the line was dealt with in some way (this is the same as runsource()). 
""" - assert self.initialized, 'must call init on TuplexShell object first' + assert self.initialized, "must call init on TuplexShell object first" self.buffer.append(line) source = "\n".join(self.buffer) @@ -99,18 +104,17 @@ def push(self, line): self.historyDict[self.filename] = self.buffer.copy() # new filename - self.filename = 'console-{}'.format(self.lineno) + self.filename = "console-{}".format(self.lineno) self.resetbuffer() return more - def get_lambda_source(self, f): # Won't this work for functions as well? - assert self.initialized, 'must call init on TuplexShell object first' + assert self.initialized, "must call init on TuplexShell object first" - assert isinstance(f, LambdaType), 'object needs to be a lambda object' + assert isinstance(f, LambdaType), "object needs to be a lambda object" vault = SourceVault() @@ -118,40 +122,48 @@ def get_lambda_source(self, f): f_globs = get_globals(f) f_filename = f.__code__.co_filename f_lineno = f.__code__.co_firstlineno - f_colno = f.__code__.co_firstcolno if hasattr(f.__code__, 'co_firstcolno') else None + f_colno = ( + f.__code__.co_firstcolno if hasattr(f.__code__, "co_firstcolno") else None + ) # get source from history # Note: because firstlineno is 1-indexed, add a dummy line so everything works. - src_info = (['dummy'] + self.historyDict[f_filename], 0) + src_info = (["dummy"] + self.historyDict[f_filename], 0) vault.extractAndPutAllLambdas(src_info, f_filename, f_lineno, f_colno, f_globs) return vault.get(f, f_filename, f_lineno, f_colno, f_globs) def get_function_source(self, f): + assert self.initialized, "must call init on TuplexShell object first" - assert self.initialized, 'must call init on TuplexShell object first' - - assert isinstance(f, - FunctionType) and f.__code__.co_name != '', 'object needs to be a function (non-lambda) object' + assert isinstance(f, FunctionType) and f.__code__.co_name != "", ( + "object needs to be a function (non-lambda) object" + ) - # fetch all data - f_globs = get_globals(f) + # Fetch all data: f_filename = f.__code__.co_filename - f_lineno = f.__code__.co_firstlineno - f_colno = f.__code__.co_firstcolno if hasattr(f.__code__, 'co_firstcolno') else None + + # # TODO: Include lineno/colno information in AST. + # f_globs = get_globals(f) + # f_lineno = f.__code__.co_firstlineno + # f_colno = ( + # f.__code__.co_firstcolno if hasattr(f.__code__, "co_firstcolno") else None + # ) # retrieve func source from historyDict lines = self.historyDict[f_filename] # check whether def is found in here - source = '\n'.join(lines).strip() + source = "\n".join(lines).strip() function_name = f.__code__.co_name regex = r"def\s*{}\(.*\)\s*:[\t ]*\n".format(function_name) prog = re.compile(regex) if not prog.search(source): - logging.error('Could not find function "{}" in source'.format(function_name)) + logging.error( + 'Could not find function "{}" in source'.format(function_name) + ) return None return source @@ -178,7 +190,7 @@ def interact(self, banner=None, exitmsg=None): style_trafo = None # check if env TUPLEX_COLORSCHEME is set, then pygments style may be used. - scheme = os.environ.get('TUPLEX_COLORSCHEME', None) + scheme = os.environ.get("TUPLEX_COLORSCHEME", None) if scheme: # define here style for python prompt toolkit @@ -199,9 +211,10 @@ def interact(self, banner=None, exitmsg=None): sys.ps2 = "... " cprt = 'Type "help", "copyright", "credits" or "license" for more information.' 
if banner is None: - self.write("Python %s on %s\n%s\n(%s)\n" % - (sys.version, sys.platform, cprt, - self.__class__.__name__)) + self.write( + "Python %s on %s\n%s\n(%s)\n" + % (sys.version, sys.platform, cprt, self.__class__.__name__) + ) elif banner: self.write("%s\n" % str(banner)) more = 0 @@ -214,18 +227,22 @@ def interact(self, banner=None, exitmsg=None): try: # use prompt toolkit here for more stylish input & tab completion # raw python prompt - #line = self.raw_input(prompt) - + # line = self.raw_input(prompt) # look here http://python-prompt-toolkit.readthedocs.io/en/stable/pages/asking_for_input.html#hello-world # on how to style the prompt better # use patch_stdout=True to output stuff above prompt - line = ptprompt(prompt, lexer=PygmentsLexer(Python3Lexer), style=style, - style_transformation=style_trafo, history=history, - completer=JediCompleter(lambda: self.locals), - complete_style=CompleteStyle.READLINE_LIKE, - complete_while_typing=False) + line = ptprompt( + prompt, + lexer=PygmentsLexer(Python3Lexer), + style=style, + style_transformation=style_trafo, + history=history, + completer=JediCompleter(lambda: self.locals), + complete_style=CompleteStyle.READLINE_LIKE, + complete_while_typing=False, + ) except EOFError: self.write("\n") @@ -237,6 +254,6 @@ def interact(self, banner=None, exitmsg=None): self.resetbuffer() more = 0 if exitmsg is None: - self.write('now exiting %s...\n' % self.__class__.__name__) - elif exitmsg != '': - self.write('%s\n' % exitmsg) \ No newline at end of file + self.write("now exiting %s...\n" % self.__class__.__name__) + elif exitmsg != "": + self.write("%s\n" % exitmsg) diff --git a/tuplex/python/tuplex/utils/jedi_completer.py b/tuplex/python/tuplex/utils/jedi_completer.py index f5f1bb517..b1260942e 100644 --- a/tuplex/python/tuplex/utils/jedi_completer.py +++ b/tuplex/python/tuplex/utils/jedi_completer.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,18 +7,17 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# from prompt_toolkit.completion import Completer, Completion -import jedi from jedi import Interpreter from jedi import settings + class JediCompleter(Completer): """REPL Completer using jedi""" def __init__(self, get_locals): - # per default jedi is case insensitive, however we want it to be case sensitive settings.case_insensitive_completion = False @@ -30,20 +29,20 @@ def get_completions(self, document, complete_event): # Jedi API changed, reflect this here completions = [] - if hasattr(interpreter, 'completions'): + if hasattr(interpreter, "completions"): completions = interpreter.completions() - elif hasattr(interpreter, 'complete'): + elif hasattr(interpreter, "complete"): completions = interpreter.complete() else: - raise Exception('Unknown Jedi API, please update or install older version (0.18)') + raise Exception( + "Unknown Jedi API, please update or install older version (0.18)" + ) for completion in completions: - - if 
completion.name_with_symbols.startswith('_'): + if completion.name_with_symbols.startswith("_"): continue - if len(document.text) > len(completion.name_with_symbols) - len(completion.complete): - last_char = document.text[len(completion.complete) - len(completion.name_with_symbols) - 1] - else: - last_char = None - yield Completion(completion.name_with_symbols, len(completion.complete) - len(completion.name_with_symbols)) \ No newline at end of file + yield Completion( + completion.name_with_symbols, + len(completion.complete) - len(completion.name_with_symbols), + ) diff --git a/tuplex/python/tuplex/utils/jupyter.py b/tuplex/python/tuplex/utils/jupyter.py index f2651ab88..10bf7791a 100644 --- a/tuplex/python/tuplex/utils/jupyter.py +++ b/tuplex/python/tuplex/utils/jupyter.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,7 +7,7 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# import json import os.path @@ -17,35 +17,50 @@ from urllib.parse import urljoin from notebook.notebookapp import list_running_servers + def get_jupyter_notebook_info(): """ retrieve infos about the currently running jupyter notebook if possible Returns: dict with several info attributes. If info for current notebook could not be retrieved, returns empty dict """ + def get(url): - req = urllib.request.Request(url, headers={'content-type': 'application/json'}) + req = urllib.request.Request(url, headers={"content-type": "application/json"}) response = urllib.request.urlopen(req) return json.loads(response.read()) - kernel_id = re.search('kernel-(.*).json', - ipykernel.connect.get_connection_file()).group(1) + kernel_id = re.search( + "kernel-(.*).json", ipykernel.connect.get_connection_file() + ).group(1) servers = list_running_servers() for ss in servers: # there may be a 403 from jupyter... try: - notebook_infos = get(urljoin(ss['url'], 'api/sessions?token={}'.format(ss.get('token', '')))) + notebook_infos = get( + urljoin(ss["url"], "api/sessions?token={}".format(ss.get("token", ""))) + ) # search for match for ninfo in notebook_infos: - if ninfo['kernel']['id'] == kernel_id: - return {'kernelID' : kernel_id, 'notebookID' : ninfo['id'], - 'kernelName' : ninfo['kernel']['name'], - 'path' : os.path.join(ss['notebook_dir'], ninfo['notebook']['path']), - 'url' : urljoin(ss['url'],'notebooks/{}?token={}'.format(ninfo['path'], ss.get('token', '')))} + if ninfo["kernel"]["id"] == kernel_id: + return { + "kernelID": kernel_id, + "notebookID": ninfo["id"], + "kernelName": ninfo["kernel"]["name"], + "path": os.path.join( + ss["notebook_dir"], ninfo["notebook"]["path"] + ), + "url": urljoin( + ss["url"], + "notebooks/{}?token={}".format( + ninfo["path"], ss.get("token", "") + ), + ), + } except urllib.error.HTTPError as e: # ignore 403s (i.e. 
no allowed access) - if e.getcode() != 403: + if e.getcode() != 403: raise e - return {} \ No newline at end of file + return {} diff --git a/tuplex/python/tuplex/utils/reflection.py b/tuplex/python/tuplex/utils/reflection.py index fd4e6a295..7066ba74b 100644 --- a/tuplex/python/tuplex/utils/reflection.py +++ b/tuplex/python/tuplex/utils/reflection.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,36 +7,35 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# import types import inspect import re + # ALWAYS import cloudpickle before dill, b.c. of https://github.com/uqfoundation/dill/issues/383 -import cloudpickle import dill import ast -import weakref -import dis -import opcode -import types -import itertools -import sys from tuplex.utils.errors import TuplexException from tuplex.utils.globs import get_globals from tuplex.utils.source_vault import SourceVault, supports_lambda_closure -from tuplex.utils.common import in_jupyter_notebook, in_google_colab, is_in_interactive_mode +from tuplex.utils.common import ( + in_jupyter_notebook, + in_google_colab, + is_in_interactive_mode, +) # only export get_source function, rest shall be private. -__all__ = ['get_source', 'get_globals', 'supports_lambda_closure'] +__all__ = ["get_source", "get_globals", "supports_lambda_closure"] + def get_jupyter_raw_code(function_name): - # ignore here unresolved reference - history_manager = get_ipython().history_manager + # Ignore here unresolved reference, get_ipython() works in jupyter notebook. 
+ history_manager = get_ipython().history_manager # noqa: F821 hist = history_manager.get_range() regex = r"def\s*{}\(.*\)\s*:[\t ]*\n".format(function_name) - signature = 'hist = history_manager.get_range()' + signature = "hist = history_manager.get_range()" prog = re.compile(regex) matched_cells = [] @@ -47,7 +46,7 @@ def get_jupyter_raw_code(function_name): if signature in inline: continue - if 'get_function_code' in inline: + if "get_function_code" in inline: continue if prog.search(test_str): @@ -55,6 +54,7 @@ def get_jupyter_raw_code(function_name): return matched_cells[-1][2] + def extractFunctionByName(code, func_name, return_linenos=False): class FunctionVisitor(ast.NodeVisitor): def __init__(self): @@ -62,50 +62,51 @@ def __init__(self): self.funcInfo = [] def visit_FunctionDef(self, node): - print(self.lastStmtLineno) self.generic_visit(node) print(self.lastStmtLineno) def visit(self, node): funcStartLineno = -1 - if hasattr(node, 'lineno'): + if hasattr(node, "lineno"): self.lastStmtLineno = node.lineno if isinstance(node, ast.FunctionDef): funcStartLineno = node.lineno self.generic_visit(node) if isinstance(node, ast.FunctionDef): - self.funcInfo.append({'name': node.name, - 'start': funcStartLineno - 1, - 'end': self.lastStmtLineno - 1}) + self.funcInfo.append( + { + "name": node.name, + "start": funcStartLineno - 1, + "end": self.lastStmtLineno - 1, + } + ) root = ast.parse(code) fv = FunctionVisitor() fv.visit(root) # find function with name - candidates = filter(lambda x: x['name'] == func_name, fv.funcInfo) + candidates = filter(lambda x: x["name"] == func_name, fv.funcInfo) def indent(s): - return len(s) - len(s.lstrip(' \t')) + return len(s) - len(s.lstrip(" \t")) - lines = code.split('\n') + lines = code.split("\n") # find out level - candidates = map(lambda x: {**x, 'level': indent(lines[x['start']])}, candidates) + candidates = map(lambda x: {**x, "level": indent(lines[x["start"]])}, candidates) - info = sorted(candidates, key=lambda x: x['level'])[0] + info = sorted(candidates, key=lambda x: x["level"])[0] - func_code = '\n'.join(lines[info['start']:info['end'] + 1]) + func_code = "\n".join(lines[info["start"] : info["end"] + 1]) if return_linenos: - return func_code, info['start'], info['end'] + return func_code, info["start"], info["end"] else: return func_code def extract_function_code(function_name, raw_code): - - # remove greedily up to num_tabs and num_spaces def remove_tabs_and_spaces(line, num_tabs, num_spaces): t = 0 @@ -113,15 +114,15 @@ def remove_tabs_and_spaces(line, num_tabs, num_spaces): pos = 0 while pos < len(line): c = line[pos] - if c == ' ': + if c == " ": s += 1 - elif c == '\t': + elif c == "\t": t += 1 else: break pos += 1 - return ' ' * max(s - num_spaces, 0) + '\t' * max(t - num_tabs, 0) + line[pos:] + return " " * max(s - num_spaces, 0) + "\t" * max(t - num_tabs, 0) + line[pos:] # remove leading spaces / tabs assert len(raw_code) >= 1 @@ -133,19 +134,21 @@ def remove_tabs_and_spaces(line, num_tabs, num_spaces): start_idx = match.start() first_line = raw_code[start_idx:] - first_line_num_tabs = len(first_line) - len(first_line.lstrip('\t')) - first_line_num_spaces = len(first_line) - len(first_line.lstrip(' ')) + first_line_num_tabs = len(first_line) - len(first_line.lstrip("\t")) + first_line_num_spaces = len(first_line) - len(first_line.lstrip(" ")) - - func_lines = [remove_tabs_and_spaces(line, first_line_num_tabs, first_line_num_spaces) \ - for line in raw_code[start_idx:].split('\n')] + func_lines = [ + remove_tabs_and_spaces(line, 
first_line_num_tabs, first_line_num_spaces) + for line in raw_code[start_idx:].split("\n") + ] # greedily remove for each line tabs/spaces - out = '\n'.join(func_lines) + out = "\n".join(func_lines) return extractFunctionByName(out, function_name) + def get_function_code(f): - """ jupyter notebook, retrieve function history """ + """jupyter notebook, retrieve function history""" assert isinstance(f, types.FunctionType) function_name = f.__code__.co_name assert isinstance(function_name, str) @@ -171,19 +174,22 @@ def get_function_code(f): vault = SourceVault() + def get_source(f): - """ Jupyter notebook code reflection """ + """Jupyter notebook code reflection""" if isinstance(f, types.FunctionType): - # lambda function? # use inspect module # need to clean out lambda... - if f.__name__ == '': + if f.__name__ == "": # interpreter in interactive mode or not? # beware jupyter notebook also returns true for interactive mode! - if is_in_interactive_mode() and not in_jupyter_notebook() and not in_google_colab(): - + if ( + is_in_interactive_mode() + and not in_jupyter_notebook() + and not in_google_colab() + ): # import here, avoids also trouble with jupyter notebooks from tuplex.utils.interactive_shell import TuplexShell @@ -201,26 +207,29 @@ def get_source(f): f_globs = get_globals(f) f_filename = f.__code__.co_filename f_lineno = f.__code__.co_firstlineno - f_colno = f.__code__.co_firstcolno if hasattr(f.__code__, 'co_firstcolno') else None + f_colno = ( + f.__code__.co_firstcolno + if hasattr(f.__code__, "co_firstcolno") + else None + ) # special case: some unknown jupyter magic has been used... - if (in_jupyter_notebook() or in_google_colab()) and (f_filename == '' or f_filename == ''): - raise TuplexException('%%time magic not supported for Tuplex code') + if (in_jupyter_notebook() or in_google_colab()) and ( + f_filename == "" or f_filename == "" + ): + raise TuplexException("%%time magic not supported for Tuplex code") src_info = inspect.getsourcelines(f) - vault.extractAndPutAllLambdas(src_info, - f_filename, - f_lineno, - f_colno, - f_globs) + vault.extractAndPutAllLambdas( + src_info, f_filename, f_lineno, f_colno, f_globs + ) return vault.get(f, f_filename, f_lineno, f_colno, f_globs) else: # works always, because functions can be only defined on a single line! return get_function_code(f) else: - # TODO: for constants, create dummy source code, i.e. lambda x: 20 # when desired to retrieve a constant or so! 
- return '' \ No newline at end of file + return "" diff --git a/tuplex/python/tuplex/utils/source_vault.py b/tuplex/python/tuplex/utils/source_vault.py index 7a6aabdeb..48cb17493 100644 --- a/tuplex/python/tuplex/utils/source_vault.py +++ b/tuplex/python/tuplex/utils/source_vault.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,7 +7,7 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# import ast import astor @@ -16,6 +16,7 @@ from types import LambdaType, CodeType import logging + def supports_lambda_closure(): """ source code of lambdas can't be extracted, because there's no column information available @@ -24,8 +25,9 @@ def supports_lambda_closure(): Returns: True if operated with patched interpreter, False otherwise """ - f = lambda x: x * x # dummy function - return hasattr(f.__code__, 'co_firstcolno') + # Check with a dummy function. + f = lambda x: x * x # noqa: E731 + return hasattr(f.__code__, "co_firstcolno") def extract_all_lambdas(tree): @@ -43,7 +45,7 @@ def visit_Lambda(self, node): # extract for lambda incl. default values # annotations are not possible with the current syntax... def args_for_lambda_ast(lam): - return [getattr(n, 'arg') for n in lam.args.args] + return [getattr(n, "arg") for n in lam.args.args] def gen_code_for_lambda(lam): @@ -55,12 +57,12 @@ def gen_code_for_lambda(lam): # astor generates here lambda : # but we want lambda: if 0 == len(lam.args.args): - assert 'lambda :' in s - s = s.replace('lambda :', 'lambda:') + assert "lambda :" in s + s = s.replace("lambda :", "lambda:") return s.strip()[1:-1] except Exception as e: - logging.debug('gen_code_for_lambda via astor failed with {}'.format(e)) + logging.debug("gen_code_for_lambda via astor failed with {}".format(e)) # python3.9+ has ast.unparse if sys.version_info.major >= 3 and sys.version_info.minor >= 9: @@ -70,9 +72,11 @@ def gen_code_for_lambda(lam): s = ast.unparse(lam) return s except Exception as e: - logging.debug('gen_code_for_lambda via ast (python3.9+) failed with {}'.format(e)) + logging.debug( + "gen_code_for_lambda via ast (python3.9+) failed with {}".format(e) + ) - return '' + return "" def hash_code_object(code): @@ -80,16 +84,16 @@ def hash_code_object(code): # need to hash contents # for this use bytecode, varnames & constants # the list comprehension constant shows up as a code object in itself, so we have to recursively hash the constants - ret = code.co_code + bytes(str(code.co_varnames), 'utf8') + b'(' + ret = code.co_code + bytes(str(code.co_varnames), "utf8") + b"(" for c in code.co_consts: if isinstance(c, CodeType): ret += hash_code_object(c) - elif isinstance(c, str) and c.endswith('..'): + elif isinstance(c, str) and c.endswith(".."): continue else: - ret += bytes(str(c), 'utf8') - ret += b',' - return ret + b')' + ret += bytes(str(c), "utf8") + ret += b"," + return ret + b")" # join lines and remove stupid \\n @@ -103,12 +107,12 @@ def remove_line_breaks(source_lines): joined 
source without \ line breaks """ - source = '' + source = "" last_line_had_break = False for line in source_lines: this_line_had_break = False - if line.endswith('\\\n'): - line = line[:-len('\\\n')] + if line.endswith("\\\n"): + line = line[: -len("\\\n")] this_line_had_break = True # remove leading whitespace if last line had break @@ -141,7 +145,7 @@ def __init__(self): # assert isinstance(obj, LambdaType), 'object needs to be a lambda object' # return self.lambdaDict[hash_code_object(obj.__code__)] def get(self, ftor, filename, lineno, colno, globs): - assert isinstance(ftor, LambdaType), 'object needs to be a lambda object' + assert isinstance(ftor, LambdaType), "object needs to be a lambda object" # perform multiway lookup for code if filename and lineno: @@ -154,26 +158,30 @@ def get(self, ftor, filename, lineno, colno, globs): # if i.e. a call is placed within a loop. if len(entries) == 1: - return entries[0]['code'] + return entries[0]["code"] else: # patched interpreter? - if hasattr(ftor.__code__, 'co_firstcolno'): - raise Exception('patched interpreter not yet implemented') + if hasattr(ftor.__code__, "co_firstcolno"): + raise Exception("patched interpreter not yet implemented") else: # multiple lambda entries. Can only search for lambda IFF no globs if len(globs) != 0: - raise KeyError("Multiple lambdas found in {}:+{}, can't extract source code for " - "lambda expression. Please either patch the interpreter or write at " - "most a single lambda using global variables " - "per line.".format(os.path.basename(filename), lineno)) + raise KeyError( + "Multiple lambdas found in {}:+{}, can't extract source code for " + "lambda expression. Please either patch the interpreter or write at " + "most a single lambda using global variables " + "per line.".format(os.path.basename(filename), lineno) + ) # search for entry with matching hash of codeobject! codeobj_hash = hash_code_object(ftor.__code__) for entry in entries: - if entry['code_hash'] == codeobj_hash: - return entry['code'] - raise KeyError('Multiple lambdas found, but failed to retrieve code for this lambda expression.') + if entry["code_hash"] == codeobj_hash: + return entry["code"] + raise KeyError( + "Multiple lambdas found, but failed to retrieve code for this lambda expression." + ) else: - raise KeyError('could not find lambda function') + raise KeyError("could not find lambda function") def extractAndPutAllLambdas(self, src_info, filename, lineno, colno, globals): """ @@ -184,32 +192,34 @@ def extractAndPutAllLambdas(self, src_info, filename, lineno, colno, globals): lines, start_lineno = src_info - assert lineno >= start_lineno, 'line numbers sound off. please fix!' - f_lines = lines[lineno - start_lineno:] + assert lineno >= start_lineno, "line numbers sound off. please fix!" + f_lines = lines[lineno - start_lineno :] take_only_first_lambda = False # are there two lambda's defined in this line? # if so, in unpatched interpreter raise exception! - lam_count_in_target_line = f_lines[0].count('lambda') + lam_count_in_target_line = f_lines[0].count("lambda") if lam_count_in_target_line != 1: if lam_count_in_target_line == 0: - raise Exception('internal extract error, no lambda in source lines?') + raise Exception("internal extract error, no lambda in source lines?") if len(globals) != 0 and not supports_lambda_closure(): - raise Exception('Found {} lambda expressions in {}:{}. 
Please patch your interpreter or ' - 'reformat so Tuplex can extract the source code.'.format(lam_count_in_target_line, - os.path.basename(filename), - lineno)) + raise Exception( + "Found {} lambda expressions in {}:{}. Please patch your interpreter or " + "reformat so Tuplex can extract the source code.".format( + lam_count_in_target_line, os.path.basename(filename), lineno + ) + ) else: if supports_lambda_closure(): - assert colno, 'colno has to be valid' + assert colno, "colno has to be valid" # simply cut off based on col no! f_lines[0] = f_lines[0][colno:] take_only_first_lambda = True # if the first line contains only one lambda, simply the first lambda is taken. # else, multiple lambdas per - if f_lines[0].count('lambda') <= 1: + if f_lines[0].count("lambda") <= 1: take_only_first_lambda = True # get the line corresponding to the object @@ -221,22 +231,21 @@ def extractAndPutAllLambdas(self, src_info, filename, lineno, colno, globals): # special case for line breaks (this is a bad HACK! However, don't want to write own AST parser again in python) try: tree = ast.parse(source.lstrip()) - except SyntaxError as se: + except SyntaxError: # we could have a lambda that is broken because of \ at the end of lines # i.e. the source object is something like '\t\t.filter(lambda x: x * x)' # search till first lambda keyword - source = source[source.find('lambda'):] + source = source[source.find("lambda") :] try: # now another exception may be raised, i.e. when parsing fails tree = ast.parse(source.strip()) except SyntaxError as se2: - # try to parse partially till where syntax error occured. - source_lines = source.split('\n') - lines = source_lines[:se2.lineno] - lines[se2.lineno - 1] = lines[se2.lineno - 1][:se2.offset - 1] - source = '\n'.join(lines) + source_lines = source.split("\n") + lines = source_lines[: se2.lineno] + lines[se2.lineno - 1] = lines[se2.lineno - 1][: se2.offset - 1] + source = "\n".join(lines) tree = ast.parse(source.strip()) Lams = extract_all_lambdas(tree) @@ -253,7 +262,7 @@ def extractAndPutAllLambdas(self, src_info, filename, lineno, colno, globals): code = gen_code_for_lambda(lam) if 0 == len(code): - raise Exception('Couldn\'t generate code again for lambda function.') + raise Exception("Couldn't generate code again for lambda function.") # Note: can get colno from ast! colno = lam.col_offset + len(source) - len(source.lstrip()) @@ -261,23 +270,33 @@ def extractAndPutAllLambdas(self, src_info, filename, lineno, colno, globals): # however, to simplify code, use astor. key = (filename, lineno) - codeobj = compile(code, '', 'eval') + codeobj = compile(code, "", "eval") # hash evaluated code object's code codeobj_hash = hash_code_object(eval(codeobj).__code__) - entry = {'code': code, 'code_hash': codeobj_hash, - 'globals': globals, 'colno': colno} + entry = { + "code": code, + "code_hash": codeobj_hash, + "globals": globals, + "colno": colno, + } if key in self.lambdaFileDict.keys(): # when declaration is placed within a loop, and e.g. globals are updated things might change. # in particular, the code + code_hash stay the same, yet the source code changes - existing_entries = self.lambdaFileDict[key] # how many can there be? assume 1 at most! + existing_entries = self.lambdaFileDict[ + key + ] # how many can there be? assume 1 at most! 
updated_existing = False for i, existing_entry in enumerate(existing_entries): - if existing_entry['code'] == entry['code'] and \ - existing_entry['code_hash'] == entry['code_hash'] and \ - existing_entry['colno'] == entry['colno']: - self.lambdaFileDict[key][i] = entry # update entry in existing file/lineno dict + if ( + existing_entry["code"] == entry["code"] + and existing_entry["code_hash"] == entry["code_hash"] + and existing_entry["colno"] == entry["colno"] + ): + self.lambdaFileDict[key][i] = ( + entry # update entry in existing file/lineno dict + ) updated_existing = True if not updated_existing: # add new entry @@ -287,31 +306,39 @@ def extractAndPutAllLambdas(self, src_info, filename, lineno, colno, globals): else: # check that there are no globals when extracting function! if colno is None and len(globals) != 0: - raise Exception('Found more than one lambda expression on {}:+{}. Either use ' - 'a patched interpreter, which supports __code__.co_firstcolno for lambda ' - 'expressions or make sure to have at most one lambda expression ' - 'on this line'.format(os.path.basename(filename), lineno)) + raise Exception( + "Found more than one lambda expression on {}:+{}. Either use " + "a patched interpreter, which supports __code__.co_firstcolno for lambda " + "expressions or make sure to have at most one lambda expression " + "on this line".format(os.path.basename(filename), lineno) + ) for lam in Lams: code = gen_code_for_lambda(lam) if 0 == len(code): - raise Exception('Couldn\'t generate code again for lambda function.') + raise Exception("Couldn't generate code again for lambda function.") lam_colno = lam.col_offset + len(source) - len(source.lstrip()) # => could also extract code from the string then via col_offsets etc.s # however, to simplify code, use astor. key = (filename, lineno) - codeobj = compile(code, '', 'eval') + codeobj = compile(code, "", "eval") # hash evaluated code object's code codeobj_hash = hash_code_object(eval(codeobj).__code__) if colno is None: # interpreter not patched - assert len(globals) == 0, 'this path should only be taken if there are no globs' + assert len(globals) == 0, ( + "this path should only be taken if there are no globs" + ) # can't associate globals clearly - entry = {'code': code, 'code_hash': codeobj_hash, - 'globals': {}, 'colno': lam_colno} + entry = { + "code": code, + "code_hash": codeobj_hash, + "globals": {}, + "colno": lam_colno, + } if key in self.lambdaFileDict.keys(): self.lambdaFileDict[key].append(entry) @@ -319,8 +346,12 @@ def extractAndPutAllLambdas(self, src_info, filename, lineno, colno, globals): self.lambdaFileDict[key] = [entry] else: # simply add the lambda with colno & co. 
- entry = {'code': code, 'code_hash': codeobj_hash, - 'globals': globals, 'colno': colno} + entry = { + "code": code, + "code_hash": codeobj_hash, + "globals": globals, + "colno": colno, + } if key in self.lambdaFileDict.keys(): self.lambdaFileDict[key].append(entry) diff --git a/tuplex/python/tuplex/utils/tracebacks.py b/tuplex/python/tuplex/utils/tracebacks.py index 480ca2d4c..18fb138e8 100644 --- a/tuplex/python/tuplex/utils/tracebacks.py +++ b/tuplex/python/tuplex/utils/tracebacks.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# # # # Tuplex: Blazing Fast Python Data Science # # # @@ -7,14 +7,15 @@ # (c) 2017 - 2021, Tuplex team # # Created by Leonhard Spiegelberg first on 1/1/2021 # # License: Apache 2.0 # -#----------------------------------------------------------------------------------------------------------------------# +# ----------------------------------------------------------------------------------------------------------------------# import traceback import linecache import re from .reflection import get_source -__all__ = ['traceback_from_udf'] +__all__ = ["traceback_from_udf"] + def format_traceback(tb, function_name): """ @@ -28,15 +29,13 @@ def format_traceback(tb, function_name): """ fnames = set() - out = '' + out = "" for frame, lineno in traceback.walk_tb(tb): co = frame.f_code filename = co.co_filename - name = co.co_name fnames.add(filename) linecache.lazycache(filename, frame.f_globals) - f_locals = frame.f_locals line = linecache.getline(filename, lineno).strip() # @Todo: maybe this is faster possible when strip is ignored, by counting tabs or so @@ -47,7 +46,10 @@ def format_traceback(tb, function_name): # need here open match for line breaks in function definition. # note the use of ^ to make sure docstrings are not matched wrongly regex = r"^[\t ]*def\s*{}\(.*".format(function_name) - while not re.match(regex, linecache.getline(filename, start_lineno).strip()) and start_lineno > 0: + while ( + not re.match(regex, linecache.getline(filename, start_lineno).strip()) + and start_lineno > 0 + ): start_lineno -= 1 # get line where function def starts via # linecache.getline(filename, start_lineno).strip() @@ -55,13 +57,14 @@ def format_traceback(tb, function_name): # UI is currently formatted with line numbering starting at 1 lineno_correction = -start_lineno + 1 - out += 'line {}, in {}:'.format(lineno + lineno_correction, function_name) - out += '\n\t{}'.format(line) + out += "line {}, in {}:".format(lineno + lineno_correction, function_name) + out += "\n\t{}".format(line) for filename in fnames: linecache.checkcache(filename) return out + # get traceback from sample def traceback_from_udf(udf, x): """ @@ -80,21 +83,25 @@ def traceback_from_udf(udf, x): try: udf(x) except Exception as e: - assert e.__traceback__.tb_next # make sure no exception within this function was raised + assert ( + e.__traceback__.tb_next + ) # make sure no exception within this function was raised etype_name = type(e).__name__ e_msg = e.__str__() - formatted_tb = '' + formatted_tb = "" # case (1): lambda function --> simply use get_source module - if udf.__name__ == '': + if udf.__name__ == "": # Lambda expressions in python consist of one line only. 
simply iterate code here - formatted_tb = 'line 1, in :\n\t' + get_source(udf) # use reflection module + formatted_tb = "line 1, in :\n\t" + get_source( + udf + ) # use reflection module # case (2) function defined via def else: # print out traceback (with relative line numbers!) formatted_tb = format_traceback(e.__traceback__.tb_next, fname) # return traceback and add exception type + its message - return formatted_tb + '\n\n{}: {}'.format(etype_name, e_msg) - return '' \ No newline at end of file + return formatted_tb + "\n\n{}: {}".format(etype_name, e_msg) + return "" diff --git a/tuplex/python/tuplex/utils/version.py b/tuplex/python/tuplex/utils/version.py index 8a14b5846..40995e4a1 100644 --- a/tuplex/python/tuplex/utils/version.py +++ b/tuplex/python/tuplex/utils/version.py @@ -1,2 +1,2 @@ # (c) L.Spiegelberg 2017 - 2025 -__version__="0.3.7" \ No newline at end of file +__version__ = "0.3.7" From 414504ff7899b5507ba57f64c912ca2042086edf Mon Sep 17 00:00:00 2001 From: Leonhard Spiegelberg Date: Sat, 8 Mar 2025 09:19:23 -0800 Subject: [PATCH 2/5] progress --- tuplex/python/tuplex/dataset.py | 16 ++++++++-------- tuplex/python/tuplex/distributed.py | 9 ++++----- tuplex/python/tuplex/metrics.py | 3 ++- tuplex/python/tuplex/repl/__init__.py | 5 +++-- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/tuplex/python/tuplex/dataset.py b/tuplex/python/tuplex/dataset.py index 7feeeb33f..81d2f2e1d 100644 --- a/tuplex/python/tuplex/dataset.py +++ b/tuplex/python/tuplex/dataset.py @@ -521,7 +521,9 @@ def tocsv( code = get_udf_source(part_name_generator) except UDFCodeExtractionError as e: logging.warn( - "Could not extract code for {}. Details:\n{}".format(ftor, e) + "Could not extract code for {}. Details:\n{}".format( + part_name_generator, e + ) ) # clamp max rows @@ -564,7 +566,9 @@ def toorc( code = get_udf_source(part_name_generator) except UDFCodeExtractionError as e: logging.warn( - "Could not extract code for {}. Details:\n{}".format(ftor, e) + "Could not extract code for {}. Details:\n{}".format( + part_name_generator, e + ) ) if num_rows > max_rows: @@ -649,9 +653,7 @@ def aggregateByKey(self, combine, aggregate, initial_value, key_columns): comb_code_pickled = cloudpickle.dumps(combine) except UDFCodeExtractionError as e: logging.warn( - "Could not extract code for combine UDF {}. Details:\n{}".format( - ftor, e - ) + "Could not extract code for combine UDFs. Details:\n{}".format(e) ) try: @@ -660,9 +662,7 @@ def aggregateByKey(self, combine, aggregate, initial_value, key_columns): agg_code_pickled = cloudpickle.dumps(aggregate) except UDFCodeExtractionError as e: logging.warn( - "Could not extract code for aggregate UDF {}. Details:\n{}".format( - ftor, e - ) + "Could not extract code for aggregate UDFs. Details:\n{}".format(e) ) g_comb = get_globals(combine) diff --git a/tuplex/python/tuplex/distributed.py b/tuplex/python/tuplex/distributed.py index 43a59ca75..b0a090f19 100644 --- a/tuplex/python/tuplex/distributed.py +++ b/tuplex/python/tuplex/distributed.py @@ -132,11 +132,11 @@ def create_lambda_role(iam_client, lambda_role): ) logging.info("Created Tuplex AWS Lambda runner role ({})".format(lambda_role)) - # check it exists + # Check that role exists. 
try: - response = iam_client.get_role(RoleName=lambda_role) - except: - raise Exception("Failed to create AWS Lambda Role") + iam_client.get_role(RoleName=lambda_role) + except botocore.exceptions.ClientError: + raise Exception("Failed to create AWS Lambda Role.") def remove_lambda_role(iam_client, lambda_role): @@ -248,7 +248,6 @@ def upload_lambda( # for runtime, choose https://docs.aws.amazon.com/lambda/latest/dg/lambda-runtimes.html RUNTIME = "provided.al2" HANDLER = "tplxlam" # this is how the executable is called... - ARCHITECTURES = ["x86_64"] DEFAULT_MEMORY_SIZE = 1536 DEFAULT_TIMEOUT = 30 # 30s timeout diff --git a/tuplex/python/tuplex/metrics.py b/tuplex/python/tuplex/metrics.py index 481bc832f..4c155dec2 100644 --- a/tuplex/python/tuplex/metrics.py +++ b/tuplex/python/tuplex/metrics.py @@ -12,7 +12,8 @@ import typing try: - from .libexec.tuplex import _Context + # Module import needed to initialize capture, should revisit. + from .libexec.tuplex import _Context # noqa: F401 from .libexec.tuplex import _Metrics except ModuleNotFoundError as e: logging.error("need to compiled Tuplex first, details: {}".format(e)) diff --git a/tuplex/python/tuplex/repl/__init__.py b/tuplex/python/tuplex/repl/__init__.py index e355fb323..bc0122ecb 100644 --- a/tuplex/python/tuplex/repl/__init__.py +++ b/tuplex/python/tuplex/repl/__init__.py @@ -20,7 +20,7 @@ try: from tuplex.utils.version import __version__ -except: +except (ImportError, NameError): __version__ = "dev" @@ -45,7 +45,8 @@ def TuplexBanner(): os.system("clear") - from tuplex.context import Context + # Module import needed to initialize defaults, should revisit. + from tuplex.context import Context # noqa: F401 _locals = locals() _locals = {key: _locals[key] for key in _locals if key in ["Context"]} From a752af32898b75ac294bb27a9b6453a33f9e5775 Mon Sep 17 00:00:00 2001 From: Leonhard Spiegelberg Date: Sat, 8 Mar 2025 09:23:41 -0800 Subject: [PATCH 3/5] ruff formatting --- tuplex/python/tuplex/__init__.py | 9 ++++++--- tuplex/python/tuplex/context.py | 2 +- tuplex/python/tuplex/dataset.py | 5 +++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tuplex/python/tuplex/__init__.py b/tuplex/python/tuplex/__init__.py index a8f0aa3d1..a6dd057be 100644 --- a/tuplex/python/tuplex/__init__.py +++ b/tuplex/python/tuplex/__init__.py @@ -9,14 +9,17 @@ # License: Apache 2.0 # # ----------------------------------------------------------------------------------------------------------------------# -from tuplex.repl import * +from tuplex.repl import ( + in_jupyter_notebook as in_jupyter_notebook, + in_google_colab as in_google_colab, +) from .context import Context -from .dataset import DataSet +from .dataset import DataSet as DataSet # expose aws setup for better convenience import tuplex.distributed import logging -from tuplex.distributed import setup_aws +from tuplex.distributed import setup_aws as setup_aws from tuplex.utils.version import __version__ as __version__ diff --git a/tuplex/python/tuplex/context.py b/tuplex/python/tuplex/context.py index 0750dc643..2dcf4b90c 100644 --- a/tuplex/python/tuplex/context.py +++ b/tuplex/python/tuplex/context.py @@ -12,7 +12,7 @@ import logging try: - from .libexec.tuplex import _Context, _DataSet, getDefaultOptionsAsJSON + from .libexec.tuplex import _Context, getDefaultOptionsAsJSON except ModuleNotFoundError as e: logging.error("need to compiled Tuplex first, details: {}".format(e)) diff --git a/tuplex/python/tuplex/dataset.py b/tuplex/python/tuplex/dataset.py index 81d2f2e1d..d3d75d97d 
100644 --- a/tuplex/python/tuplex/dataset.py +++ b/tuplex/python/tuplex/dataset.py @@ -13,7 +13,8 @@ import logging try: - from .libexec.tuplex import _Context, _DataSet + # Checks that compiled tuplex extension object is present and compatible. + from .libexec.tuplex import _Context, _DataSet # noqa: F401 except ModuleNotFoundError as e: logging.error("need to compiled Tuplex first, details: {}".format(e)) from tuplex.utils.reflection import get_source as get_udf_source @@ -27,7 +28,7 @@ class DataSet: def __init__(self): - self._dataSet = None + self._dataSet: _DataSet = None def unique(self): """removes duplicates from Dataset (out-of-order). Equivalent to a DISTINCT clause in a SQL-statement. From 4808aaa3ed8c3bd1ae90b3d65cb0dcddfeba28d2 Mon Sep 17 00:00:00 2001 From: Leonhard Spiegelberg Date: Sat, 8 Mar 2025 09:33:25 -0800 Subject: [PATCH 4/5] add import sort to ruff --- .pre-commit-config.yaml | 20 ++++++------- tuplex/python/tuplex/__init__.py | 14 ++++------ tuplex/python/tuplex/context.py | 28 ++++++++++--------- tuplex/python/tuplex/dataset.py | 8 ++++-- tuplex/python/tuplex/distributed.py | 4 +-- tuplex/python/tuplex/metrics.py | 6 ++-- tuplex/python/tuplex/repl/__init__.py | 4 +-- tuplex/python/tuplex/utils/common.py | 26 ++++++++--------- tuplex/python/tuplex/utils/globs.py | 6 ++-- .../python/tuplex/utils/interactive_shell.py | 10 ++++--- tuplex/python/tuplex/utils/jedi_completer.py | 3 +- tuplex/python/tuplex/utils/jupyter.py | 3 +- tuplex/python/tuplex/utils/reflection.py | 12 ++++---- tuplex/python/tuplex/utils/source_vault.py | 7 +++-- tuplex/python/tuplex/utils/tracebacks.py | 3 +- 15 files changed, 80 insertions(+), 74 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7231bef27..0d6b6778a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,13 +1,13 @@ repos: -#- repo: https://github.com/pre-commit/pre-commit-hooks -# rev: v2.3.0 -# hooks: -# - id: check-yaml -# exclude: ["tuplex/test/resources"] -# - id: end-of-file-fixer -# exclude: ["tuplex/test/resources"] -# - id: trailing-whitespace -# exclude: ["tuplex/test/resources"] +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: check-yaml + files: ^tuplex/python/tuplex.*\.py$ + - id: end-of-file-fixer + files: ^tuplex/python/tuplex.*\.py$ + - id: trailing-whitespace + files: ^tuplex/python/tuplex.*\.py$ - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. rev: v0.9.9 @@ -15,7 +15,7 @@ repos: # Run the linter. - id: ruff files: ^tuplex/python/tuplex.*\.py$ - args: [ --fix ] + args: [ "--fix", "--select", "I" ] types_or: [ python, pyi ] # Run the formatter. 
- id: ruff-format diff --git a/tuplex/python/tuplex/__init__.py b/tuplex/python/tuplex/__init__.py index a6dd057be..ad8d14b5e 100644 --- a/tuplex/python/tuplex/__init__.py +++ b/tuplex/python/tuplex/__init__.py @@ -9,20 +9,18 @@ # License: Apache 2.0 # # ----------------------------------------------------------------------------------------------------------------------# -from tuplex.repl import ( - in_jupyter_notebook as in_jupyter_notebook, - in_google_colab as in_google_colab, -) -from .context import Context -from .dataset import DataSet as DataSet +import logging # expose aws setup for better convenience import tuplex.distributed -import logging from tuplex.distributed import setup_aws as setup_aws - +from tuplex.repl import in_google_colab as in_google_colab +from tuplex.repl import in_jupyter_notebook as in_jupyter_notebook from tuplex.utils.version import __version__ as __version__ +from .context import Context +from .dataset import DataSet as DataSet + # for convenience create a dummy function to return a default-configured Lambda context def LambdaContext(conf=None, name=None, s3_scratch_dir=None, **kwargs): diff --git a/tuplex/python/tuplex/context.py b/tuplex/python/tuplex/context.py index 2dcf4b90c..46763b72d 100644 --- a/tuplex/python/tuplex/context.py +++ b/tuplex/python/tuplex/context.py @@ -16,29 +16,31 @@ except ModuleNotFoundError as e: logging.error("need to compiled Tuplex first, details: {}".format(e)) -from .dataset import DataSet -import os import glob +import json +import os import sys +import uuid + from tuplex.utils.common import ( + current_user, + ensure_webui, flatten_dict, - load_conf_yaml, - stringify_dict, - unflatten_dict, - save_conf_yaml, - in_jupyter_notebook, + host_name, in_google_colab, + in_jupyter_notebook, is_in_interactive_mode, - current_user, is_shared_lib, - host_name, - ensure_webui, - pythonize_options, + load_conf_yaml, logging_callback, + pythonize_options, registerLoggingCallback, + save_conf_yaml, + stringify_dict, + unflatten_dict, ) -import uuid -import json + +from .dataset import DataSet from .metrics import Metrics diff --git a/tuplex/python/tuplex/dataset.py b/tuplex/python/tuplex/dataset.py index d3d75d97d..27e0d37a7 100644 --- a/tuplex/python/tuplex/dataset.py +++ b/tuplex/python/tuplex/dataset.py @@ -9,17 +9,19 @@ # License: Apache 2.0 # # ----------------------------------------------------------------------------------------------------------------------# -import cloudpickle import logging +import cloudpickle + try: # Checks that compiled tuplex extension object is present and compatible. 
from .libexec.tuplex import _Context, _DataSet # noqa: F401 except ModuleNotFoundError as e: logging.error("need to compiled Tuplex first, details: {}".format(e)) -from tuplex.utils.reflection import get_source as get_udf_source -from tuplex.utils.reflection import get_globals from tuplex.utils.framework import UDFCodeExtractionError +from tuplex.utils.reflection import get_globals +from tuplex.utils.reflection import get_source as get_udf_source + from .exceptions import classToExceptionCode # signed 64bit limit diff --git a/tuplex/python/tuplex/distributed.py b/tuplex/python/tuplex/distributed.py index b0a090f19..096bf56a3 100644 --- a/tuplex/python/tuplex/distributed.py +++ b/tuplex/python/tuplex/distributed.py @@ -17,10 +17,10 @@ pass # raise Exception('To use distributed version, please install boto3') -import logging -import os import base64 import datetime +import logging +import os import sys import threading import time diff --git a/tuplex/python/tuplex/metrics.py b/tuplex/python/tuplex/metrics.py index 4c155dec2..0bcf3fd34 100644 --- a/tuplex/python/tuplex/metrics.py +++ b/tuplex/python/tuplex/metrics.py @@ -13,8 +13,10 @@ try: # Module import needed to initialize capture, should revisit. - from .libexec.tuplex import _Context # noqa: F401 - from .libexec.tuplex import _Metrics + from .libexec.tuplex import ( + _Context, # noqa: F401 + _Metrics, + ) except ModuleNotFoundError as e: logging.error("need to compiled Tuplex first, details: {}".format(e)) _Metrics = typing.Any diff --git a/tuplex/python/tuplex/repl/__init__.py b/tuplex/python/tuplex/repl/__init__.py index bc0122ecb..ef1fb8e83 100644 --- a/tuplex/python/tuplex/repl/__init__.py +++ b/tuplex/python/tuplex/repl/__init__.py @@ -13,9 +13,9 @@ import sys from tuplex.utils.common import ( - is_in_interactive_mode, - in_jupyter_notebook, in_google_colab, + in_jupyter_notebook, + is_in_interactive_mode, ) try: diff --git a/tuplex/python/tuplex/utils/common.py b/tuplex/python/tuplex/utils/common.py index 39b7916ac..e0708fc3d 100644 --- a/tuplex/python/tuplex/utils/common.py +++ b/tuplex/python/tuplex/utils/common.py @@ -9,28 +9,26 @@ # License: Apache 2.0 # # ----------------------------------------------------------------------------------------------------------------------# import atexit -import sys import collections - import collections.abc -import pathlib -import signal - -import yaml -from datetime import datetime - import json -import urllib.request +import logging import os -import socket +import pathlib +import re import shutil -import psutil +import signal +import socket import subprocess -import logging -import iso8601 -import re +import sys import tempfile import time +import urllib.request +from datetime import datetime + +import iso8601 +import psutil +import yaml try: import pwd diff --git a/tuplex/python/tuplex/utils/globs.py b/tuplex/python/tuplex/utils/globs.py index 116442e9f..f938e5035 100644 --- a/tuplex/python/tuplex/utils/globs.py +++ b/tuplex/python/tuplex/utils/globs.py @@ -9,12 +9,12 @@ # License: Apache 2.0 # # ----------------------------------------------------------------------------------------------------------------------# -import types -import weakref import dis -import opcode import itertools +import opcode import sys +import types +import weakref # ALWAYS import cloudpickle before dill, b.c. 
of https://github.com/uqfoundation/dill/issues/383 from cloudpickle.cloudpickle import _get_cell_contents diff --git a/tuplex/python/tuplex/utils/interactive_shell.py b/tuplex/python/tuplex/utils/interactive_shell.py index 91fd74fbd..56a929b02 100644 --- a/tuplex/python/tuplex/utils/interactive_shell.py +++ b/tuplex/python/tuplex/utils/interactive_shell.py @@ -11,11 +11,13 @@ from __future__ import unicode_literals +import logging import os -import sys import re -import logging +import sys from code import InteractiveConsole +from types import FunctionType, LambdaType + from prompt_toolkit.history import InMemoryHistory # old version: 1.0 @@ -30,10 +32,10 @@ from prompt_toolkit.styles.pygments import style_from_pygments_cls from pygments.lexers import Python3Lexer from pygments.styles import get_style_by_name + +from tuplex.utils.globs import get_globals from tuplex.utils.jedi_completer import JediCompleter from tuplex.utils.source_vault import SourceVault -from types import LambdaType, FunctionType -from tuplex.utils.globs import get_globals # this is a helper to allow for tuplex.Context syntax diff --git a/tuplex/python/tuplex/utils/jedi_completer.py b/tuplex/python/tuplex/utils/jedi_completer.py index b1260942e..deecf8517 100644 --- a/tuplex/python/tuplex/utils/jedi_completer.py +++ b/tuplex/python/tuplex/utils/jedi_completer.py @@ -9,9 +9,8 @@ # License: Apache 2.0 # # ----------------------------------------------------------------------------------------------------------------------# +from jedi import Interpreter, settings from prompt_toolkit.completion import Completer, Completion -from jedi import Interpreter -from jedi import settings class JediCompleter(Completer): diff --git a/tuplex/python/tuplex/utils/jupyter.py b/tuplex/python/tuplex/utils/jupyter.py index 10bf7791a..40fa34f70 100644 --- a/tuplex/python/tuplex/utils/jupyter.py +++ b/tuplex/python/tuplex/utils/jupyter.py @@ -12,9 +12,10 @@ import json import os.path import re -import ipykernel import urllib.request from urllib.parse import urljoin + +import ipykernel from notebook.notebookapp import list_running_servers diff --git a/tuplex/python/tuplex/utils/reflection.py b/tuplex/python/tuplex/utils/reflection.py index 7066ba74b..258da0b27 100644 --- a/tuplex/python/tuplex/utils/reflection.py +++ b/tuplex/python/tuplex/utils/reflection.py @@ -9,22 +9,22 @@ # License: Apache 2.0 # # ----------------------------------------------------------------------------------------------------------------------# -import types +import ast import inspect import re +import types # ALWAYS import cloudpickle before dill, b.c. of https://github.com/uqfoundation/dill/issues/383 import dill -import ast -from tuplex.utils.errors import TuplexException -from tuplex.utils.globs import get_globals -from tuplex.utils.source_vault import SourceVault, supports_lambda_closure from tuplex.utils.common import ( - in_jupyter_notebook, in_google_colab, + in_jupyter_notebook, is_in_interactive_mode, ) +from tuplex.utils.errors import TuplexException +from tuplex.utils.globs import get_globals +from tuplex.utils.source_vault import SourceVault, supports_lambda_closure # only export get_source function, rest shall be private. 
__all__ = ["get_source", "get_globals", "supports_lambda_closure"] diff --git a/tuplex/python/tuplex/utils/source_vault.py b/tuplex/python/tuplex/utils/source_vault.py index 48cb17493..6e25e8d60 100644 --- a/tuplex/python/tuplex/utils/source_vault.py +++ b/tuplex/python/tuplex/utils/source_vault.py @@ -10,11 +10,12 @@ # ----------------------------------------------------------------------------------------------------------------------# import ast -import astor +import logging import os import sys -from types import LambdaType, CodeType -import logging +from types import CodeType, LambdaType + +import astor def supports_lambda_closure(): diff --git a/tuplex/python/tuplex/utils/tracebacks.py b/tuplex/python/tuplex/utils/tracebacks.py index 18fb138e8..eb5ba3aed 100644 --- a/tuplex/python/tuplex/utils/tracebacks.py +++ b/tuplex/python/tuplex/utils/tracebacks.py @@ -9,9 +9,10 @@ # License: Apache 2.0 # # ----------------------------------------------------------------------------------------------------------------------# -import traceback import linecache import re +import traceback + from .reflection import get_source __all__ = ["traceback_from_udf"] From b14d3eaec0493c8baa38b3f035a86efb295d6ecc Mon Sep 17 00:00:00 2001 From: Leonhard Spiegelberg Date: Sat, 8 Mar 2025 12:41:42 -0800 Subject: [PATCH 5/5] fix order to prevent wait issue --- tuplex/core/src/Executor.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tuplex/core/src/Executor.cc b/tuplex/core/src/Executor.cc index 5078d4bc2..7b74a9937 100644 --- a/tuplex/core/src/Executor.cc +++ b/tuplex/core/src/Executor.cc @@ -104,15 +104,17 @@ namespace tuplex { // save which thread executed this task task->setID(std::this_thread::get_id()); - _numPendingTasks.fetch_add(-1, std::memory_order_release); - - // add task to done list + // Add task to done list, execute before decreasing pending task. TRACE_LOCK("completedTasks"); _completedTasksMutex.lock(); _completedTasks.push_back(std::move(task)); _completedTasksMutex.unlock(); _numCompletedTasks.fetch_add(1, std::memory_order_release); TRACE_UNLOCK("completedTasks"); + + // This needs to come last, because other threads may be waiting on it. + _numPendingTasks.fetch_add(-1, std::memory_order_release); + return true; } } else {