diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 58708f13..65900e99 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -32,12 +32,8 @@ jobs: - uses: actions/checkout@v4 with: fetch-tags: 1 # Essential to later commitizen - fetch-depth: 0 # Reccommended by the action + fetch-depth: 0 # Recommended by the action token: ${{ secrets.PUSH_ACCESS }} - - uses: actions/setup-python@v5 - with: - python-version: "3.10" - cache: pip - run: git tag # Debug statement - name: Create bump and changelog uses: commitizen-tools/commitizen-action@master diff --git a/.gitignore b/.gitignore index 42f279fe..d2f0e29d 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ example-* results typings unknown-trial-bucket +.ipynb_checkpoints diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bf4bfd31..cfa95ba4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: - "--ignore-missing-imports" - "--show-traceback" - repo: https://github.com/python-jsonschema/check-jsonschema - rev: 0.27.3 + rev: 0.27.4 hooks: - id: check-github-workflows files: '^github/workflows/.*\.ya?ml$' @@ -46,11 +46,11 @@ repos: - id: check-dependabot files: '^\.github/dependabot\.ya?ml$' - repo: https://github.com/commitizen-tools/commitizen - rev: v3.13.0 + rev: v3.14.0 hooks: - id: commitizen - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.1.14 + rev: v0.2.0 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix, --no-cache] diff --git a/CHANGELOG.md b/CHANGELOG.md index eb172db2..a22db719 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,74 @@ +## 1.12.1 (2024-08-13) + +### Fix + +- **pipeline**: `request` now fails without default (#284) + +## 1.12.0 (2024-04-24) + +### Feat + +- **PyTorch**: Add functionality to construct a PyTorch Model from a pipeline (#276) + +### Fix + +- Pass in sampler to `create_study` (#282) +- **Pytorch**: Fix builders.py (#280) +- precommit issues from #276 (#277) + +## 1.11.0 (2024-02-29) + +### Feat + +- **CVEvaluator**: Add feature for post_split and post_processing (#260) +- **sklearn**: `X_test`, `y_test` to CVEvaluator (#258) +- CVEarlyStopping (#254) +- **sklearn**: CVEvaluator allows `configure` and `build` params (#250) +- **sklearn**: Provide a standard CVEvaluator (#244) + +### Fix + +- **trial**: Don't record metric values for deserialized NaN's or None (#263) +- **pipeline**: Ensure optimizer is updated with report (#261) +- **scheduling**: Safe termination of processes, avoiding lifetime race condition (#256) +- **metalearning**: Portfolio Check for Dataframe as Input (#253) +- **CVEvaluator**: `clone` the estimator before use (#249) +- **Node**: Ensure that parent name does not conflict with children (#248) +- **CVEvaluator**: When on_error="raise", inform of which trial failed (#247) +- **Trial**: Give trials a created_at stamp (#246) + +### Refactor + +- **pipeline**: `optimize` now requires one of `timeout` or (#252) +- **Metric, Trial**: Cleanup of metrics and `Trial` (#242) + +## 1.10.1 (2024-01-28) + +### Fix + +- **dask-jobqueue**: Make sure to close client + +### Refactor + +- Make things more context manager +- **trial**: Remove `begin()` (#238) + +## 1.10.0 (2024-01-26) + +### Feat + +- **Pipeline**: Optimize pipelines directly with `optimize()` (#230) + +## 1.9.0 (2024-01-26) + +### Feat + +- **Optimizer**: Allow for batch ask requests (#224) + +### Fix + +- **Pynisher**: Ensure system supports limit (#223) + ## 1.8.0 (2024-01-22) ### Feat 
diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 00000000..2ad3c7b7 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,90 @@ +cff-version: "1.2.0" +authors: +- family-names: Bergman + given-names: Edward + orcid: "https://orcid.org/0009-0003-4390-7614" +- family-names: Feurer + given-names: Matthias + orcid: "https://orcid.org/0000-0001-9611-8588" +- family-names: Bahram + given-names: Aron + orcid: "https://orcid.org/0009-0002-8896-2863" +- family-names: Balef + given-names: Amir Rezaei + orcid: "https://orcid.org/0000-0002-6882-0051" +- family-names: Purucker + given-names: Lennart + orcid: "https://orcid.org/0009-0001-1181-0549" +- family-names: Segel + given-names: Sarah + orcid: "https://orcid.org/0009-0005-2966-266X" +- family-names: Lindauer + given-names: Marius + orcid: "https://orcid.org/0000-0002-9675-3175" +- family-names: Hutter + given-names: Frank + orcid: "https://orcid.org/0000-0002-2037-3694" +- family-names: Eggensperger + given-names: Katharina + orcid: "https://orcid.org/0000-0002-0309-401X" +contact: +- family-names: Bergman + given-names: Edward + orcid: "https://orcid.org/0009-0003-4390-7614" +- family-names: Feurer + given-names: Matthias + orcid: "https://orcid.org/0000-0001-9611-8588" +- family-names: Lindauer + given-names: Marius + orcid: "https://orcid.org/0000-0002-9675-3175" +- family-names: Hutter + given-names: Frank + orcid: "https://orcid.org/0000-0002-2037-3694" +- family-names: Eggensperger + given-names: Katharina + orcid: "https://orcid.org/0000-0002-0309-401X" +doi: 10.5281/zenodo.13309537 +message: If you use this software, please cite our article in the + Journal of Open Source Software. +preferred-citation: + authors: + - family-names: Bergman + given-names: Edward + orcid: "https://orcid.org/0009-0003-4390-7614" + - family-names: Feurer + given-names: Matthias + orcid: "https://orcid.org/0000-0001-9611-8588" + - family-names: Bahram + given-names: Aron + orcid: "https://orcid.org/0009-0002-8896-2863" + - family-names: Balef + given-names: Amir Rezaei + orcid: "https://orcid.org/0000-0002-6882-0051" + - family-names: Purucker + given-names: Lennart + orcid: "https://orcid.org/0009-0001-1181-0549" + - family-names: Segel + given-names: Sarah + orcid: "https://orcid.org/0009-0005-2966-266X" + - family-names: Lindauer + given-names: Marius + orcid: "https://orcid.org/0000-0002-9675-3175" + - family-names: Hutter + given-names: Frank + orcid: "https://orcid.org/0000-0002-2037-3694" + - family-names: Eggensperger + given-names: Katharina + orcid: "https://orcid.org/0000-0002-0309-401X" + date-published: 2024-08-14 + doi: 10.21105/joss.06367 + issn: 2475-9066 + issue: 100 + journal: Journal of Open Source Software + publisher: + name: Open Journals + start: 6367 + title: "AMLTK: A Modular AutoML Toolkit in Python" + type: article + url: "https://joss.theoj.org/papers/10.21105/joss.06367" + volume: 9 +title: "AMLTK: A Modular AutoML Toolkit in Python" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9b38fdec..9018a063 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -49,6 +49,8 @@ just test # Run the documentation, fix any warnings just docs +# just docs-code # Run code and display output (slower) +# just docs-full # Run examples and code (slowest) # Run pre-commit checks just check @@ -297,6 +299,9 @@ You can find a collection of features for custom documentation [here](https://squidfunk.github.io/mkdocs-material/reference/) as well as code reference documentation [here](https://mkdocstrings.github.io/usage/) +You can find the 
entry point for the documentation infrastructure of `mkdocs` +in `mkdocs.yml`. + ### Viewing Documentation You can live view documentation changes by running `just docs`, diff --git a/README.md b/README.md index be946425..d771ed8d 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ [![image](https://img.shields.io/pypi/pyversions/amltk.svg)](https://pypi.python.org/pypi/amltk) [![Actions](https://github.com/automl/amltk/actions/workflows/test.yml/badge.svg)](https://github.com/automl/amltk/actions) [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) +[![DOI](https://joss.theoj.org/papers/10.21105/joss.06367/status.svg)](https://doi.org/10.21105/joss.06367) # AutoML Toolkit A framework for building an AutoML System. The toolkit is designed to be modular and extensible, allowing you to @@ -67,33 +68,32 @@ Here's a brief overview of 3 of the core components from the toolkit: ### Pipelines Define **parametrized** machine learning pipelines using a fluid API: ```python -from amltk.pipeline import Component, Choice, Sequential -from sklearn.ensemble import RandomForestClasifier -from sklearn.preprocessing import OneHotEncoder +from sklearn.ensemble import RandomForestClassifier from sklearn.impute import SimpleImputer +from sklearn.preprocessing import OneHotEncoder from sklearn.svm import SVC +from amltk.pipeline import Choice, Component, Sequential + pipeline = ( Sequential(name="my_pipeline") - >> Component(SimpleImputer, space={"strategy": ["mean", "median"]}), # Choose either mean or median + >> Component(SimpleImputer, space={"strategy": ["mean", "median"]}) # Choose either mean or median >> OneHotEncoder(drop="first") # No parametrization, no problem >> Choice( # Our pipeline can choose between two different estimators Component( RandomForestClassifier, - space={ - "n_estimators": (10, 100), - "criterion": ["gini", "log_loss"] - }, - config={"max_depth":3} + space={"n_estimators": (10, 100), "criterion": ["gini", "log_loss"]}, + config={"max_depth": 3}, ), Component(SVC, space={"kernel": ["linear", "rbf", "poly"]}), - name="estimator" + name="estimator", ) ) # Parser the search space with implemented or you custom parser -search_space = pipeline.search_space(parser=...) +search_space = pipeline.search_space(parser="configspace") +config = search_space.sample_configuration() # Configure a pipeline configured_pipeline = pipeline.configure(config) @@ -114,24 +114,23 @@ accuracy = Metric("accuracy", maximize=True, bounds=(0. 1)) inference_time = Metric("inference_time", maximize=False) def evaluate(trial: Trial) -> Trial.Report: + model = pipeline.configure(trial.config).build("sklearn") - # Say when and where you trial begins - with trial.begin(): - model = pipeline.configure(trial.config).build("sklearn") - + try: # Profile the things you'd like with trial.profile("fit"): model.fit(...) - # Record anything else you'd like - trial.summary["model_size"] = ... + except Exception as e: + # Generate reports from exceptions easily + return trial.fail(exception=e) - # Store whatever you'd like - trial.store({"model.pkl": model, "predictions.npy": predictions}), - return trial.success(accuracy=0.8, inference_time=...) + # Record anything else you'd like + trial.summary["model_size"] = ... - if trial.exception: - return trial.fail() + # Store whatever you'd like + trial.store({"model.pkl": model, "predictions.npy": predictions}), + return trial.success(accuracy=0.8, inference_time=...) 
# Easily swap between optimizers, without needing to change the rest of your code from amltk.optimization.optimizers.smac import SMACOptimizer diff --git a/docs/_templates/python/material/attribute.html b/docs/_templates/python/material/attribute.html deleted file mode 100644 index 4f234307..00000000 --- a/docs/_templates/python/material/attribute.html +++ /dev/null @@ -1,70 +0,0 @@ -{{ log.debug("Rendering " + attribute.path) }} - -
-{% with html_id = attribute.path %} - - {% if root %} - {% set show_full_path = config.show_root_full_path %} - {% set root_members = True %} - {% elif root_members %} - {% set show_full_path = config.show_root_members_full_path or config.show_object_full_path %} - {% set root_members = False %} - {% else %} - {% set show_full_path = config.show_object_full_path %} - {% endif %} - - {% if not root or config.show_root_heading %} - - {% filter heading(heading_level, - role="data" if attribute.parent.kind.value == "module" else "attr", - id=html_id, - class="doc doc-heading", - toc_label=attribute.name, - ) - %} - - {% if config.separate_signature %} - {% if show_full_path %}{{ attribute.path }}{% else %}{{ attribute.name }}{% endif %} - {% else %} - {% filter highlight(language="python", inline=True) %} - {% if show_full_path %}{{ attribute.path }}{% else %}{{ attribute.name }}{% endif %} - {% if attribute.annotation %}: {{ attribute.annotation }}{% endif %} - {% endfilter %} - {% endif %} - - {% with labels = attribute.labels %} - {% include "labels.html" with context %} - {% endwith %} - - {% endfilter %} - - {% if config.separate_signature %} - {% filter highlight(language="python", inline=False) %} - {% filter format_code(config.line_length) %} - {% if show_full_path %}{{ attribute.path }}{% else %}{{ attribute.name }}{% endif %} - {% if attribute.annotation %}: {{ attribute.annotation|safe }}{% endif %} - # {% if attribute.value %} = {{ attribute.value|safe }}{% endif %} - {% endfilter %} - {% endfilter %} - {% endif %} - - {% else %} - {% if config.show_root_toc_entry %} - {% filter heading(heading_level, - role="data" if attribute.parent.kind.value == "module" else "attr", - id=html_id, - toc_label=attribute.path if config.show_root_full_path else attribute.name, - hidden=True) %} - {% endfilter %} - {% endif %} - {% set heading_level = heading_level - 1 %} - {% endif %} - -
- {% with docstring_sections = attribute.docstring.parsed %} - {% include "docstring.html" with context %} - {% endwith %} -
- -{% endwith %} -
diff --git a/docs/_templates/python/material/class.html b/docs/_templates/python/material/class.html deleted file mode 100644 index 0b5c8124..00000000 --- a/docs/_templates/python/material/class.html +++ /dev/null @@ -1,115 +0,0 @@ -{{ log.debug("Rendering " + class.path) }} - -
-{% with html_id = class.path %} - - {% if root %} - {% set show_full_path = config.show_root_full_path %} - {% set root_members = True %} - {% elif root_members %} - {% set show_full_path = config.show_root_members_full_path or config.show_object_full_path %} - {% set root_members = False %} - {% else %} - {% set show_full_path = config.show_object_full_path %} - {% endif %} - - {% if not root or config.show_root_heading %} - - {% filter heading(heading_level, - role="class", - id=html_id, - class="doc doc-heading", - toc_label=class.name) %} - - {% if config.separate_signature %} - {% if show_full_path %}{{ class.path }}{% else %}{{ class.name }}{% endif %} - {% elif config.merge_init_into_class and "__init__" in class.members -%} - {%- with function = class.members["__init__"] -%} - {%- filter highlight(language="python", inline=True) -%} - {% if show_full_path %}{{ class.path }}{% else %}class {{ class.name }}{% endif %} - {%- include "signature.html" with context -%} - {%- endfilter -%} - {%- endwith -%} - {% else %} - {% filter highlight(language="python", inline=True) %} - {% if show_full_path %}{{ class.path }}{% else %}class {{ class.name }}{% endif %} - {% endfilter %} - {% endif %} - - {% with labels = class.labels %} - {% include "labels.html" with context %} - {% endwith %} - - {% endfilter %} - - {% if config.separate_signature and config.merge_init_into_class %} - {% if "__init__" in class.members %} - {% with function = class.members["__init__"] %} - {% filter highlight(language="python", inline=False) %} - {% filter format_signature(config.line_length) %} - {% if show_full_path %}{{ class.path }}{% else %}{{ class.name }}{% endif %} - {% include "signature.html" with context %} - {% endfilter %} - {% endfilter %} - {% endwith %} - {% endif %} - {% endif %} - - {% else %} - {% if config.show_root_toc_entry %} - {% filter heading(heading_level, - role="class", - id=html_id, - toc_label=class.path if config.show_root_full_path else class.name, - hidden=True) %} - {% endfilter %} - {% endif %} - {% set heading_level = heading_level - 1 %} - {% endif %} - -
- {% if config.show_bases and class.bases %} -

- Bases: {% for expression in class.bases -%} - {% include "expression.html" with context %}{% if not loop.last %}, {% endif %} - {% endfor -%} -

- {% endif %} - - {% with docstring_sections = class.docstring.parsed %} - {% include "docstring.html" with context %} - {% endwith %} - - {% if config.merge_init_into_class %} - {% if "__init__" in class.members and class.members["__init__"].has_docstring %} - {% with docstring_sections = class.members["__init__"].docstring.parsed %} - {% include "docstring.html" with context %} - {% endwith %} - {% endif %} - {% endif %} - - {% if config.show_source %} - {% if config.merge_init_into_class %} - {% if "__init__" in class.members and class.members["__init__"].source %} -
- Source code in {{ class.relative_filepath }} - {{ class.members["__init__"].source|highlight(language="python", linestart=class.members["__init__"].lineno, linenums=True) }} -
- {% endif %} - {% elif class.source %} -
- Source code in {{ class.relative_filepath }} - {{ class.source|highlight(language="python", linestart=class.lineno, linenums=True) }} -
- {% endif %} - {% endif %} - - {% with obj = class %} - {% set root = False %} - {% set heading_level = heading_level + 1 %} - {% include "children.html" with context %} - {% endwith %} -
- -{% endwith %} -
diff --git a/docs/_templates/python/material/function.html b/docs/_templates/python/material/function.html deleted file mode 100644 index 171721ee..00000000 --- a/docs/_templates/python/material/function.html +++ /dev/null @@ -1,74 +0,0 @@ -{{ log.debug("Rendering " + function.path) }} - -
-{% with html_id = function.path %} - - {% if root %} - {% set show_full_path = config.show_root_full_path %} - {% set root_members = True %} - {% elif root_members %} - {% set show_full_path = config.show_root_members_full_path or config.show_object_full_path %} - {% set root_members = False %} - {% else %} - {% set show_full_path = config.show_object_full_path %} - {% endif %} - - {% if not root or config.show_root_heading %} - - {% filter heading(heading_level, - role="function", - id=html_id, - class="doc doc-heading", - toc_label=function.name ~ "()") %} - - {% if config.separate_signature %} - {% if show_full_path %}{{ function.path }}{% else %}{{ function.name }}{% endif %} - {% else %} - {% filter highlight(language="python", inline=True) %} - {% if show_full_path %}{{ function.path }}{% else %}def {{ function.name }}{% endif %} - {% include "signature.html" with context %} - {% endfilter %} - {% endif %} - - {% with labels = function.labels %} - {% include "labels.html" with context %} - {% endwith %} - - {% endfilter %} - - {% if config.separate_signature %} - {% filter highlight(language="python", inline=False) %} - {% filter format_signature(config.line_length) %} - {% if show_full_path %}{{ function.path }}{% else %}{{ function.name }}{% endif %} - {% include "signature.html" with context %} - {% endfilter %} - {% endfilter %} - {% endif %} - - {% else %} - {% if config.show_root_toc_entry %} - {% filter heading(heading_level, - role="function", - id=html_id, - toc_label=function.path if config.show_root_full_path else function.name, - hidden=True) %} - {% endfilter %} - {% endif %} - {% set heading_level = heading_level - 1 %} - {% endif %} - -
- {% with docstring_sections = function.docstring.parsed %} - {% include "docstring.html" with context %} - {% endwith %} - - {% if config.show_source and function.source %} -
- Source code in {{ function.relative_filepath }} - {{ function.source|highlight(language="python", linestart=function.lineno, linenums=True) }} -
- {% endif %} -
- -{% endwith %} -
diff --git a/docs/_templates/python/material/labels.html b/docs/_templates/python/material/labels.html deleted file mode 100644 index 27fd24c9..00000000 --- a/docs/_templates/python/material/labels.html +++ /dev/null @@ -1,9 +0,0 @@ -{% if labels %} - {{ log.debug("Rendering labels") }} - -
- {% for label in labels|sort %} - {{ label|replace("property", "prop")|replace("instance-attribute", "attr")|replace("dataclass", "dataclass")|replace("class-attribute", "classvar") }} - {% endfor %} -
-{% endif %} diff --git a/docs/_templates/python/material/signature.html b/docs/_templates/python/material/signature.html deleted file mode 100644 index ae267f5f..00000000 --- a/docs/_templates/python/material/signature.html +++ /dev/null @@ -1,48 +0,0 @@ -{%- if config.show_signature -%} - {{ log.debug("Rendering signature") }} - {%- with -%} - - {%- set ns = namespace(has_pos_only=False, render_pos_only_separator=True, render_kw_only_separator=True, equal="=") -%} - - {%- if config.show_signature_annotations -%} - {%- set ns.equal = " = " -%} - {%- endif -%} - - ( - {%- for parameter in function.parameters -%} - {%- if parameter.name not in ("self", "cls") or loop.index0 > 0 or not (function.parent and function.parent.is_class) -%} - - {%- if parameter.kind.value == "positional-only" -%} - {%- set ns.has_pos_only = True -%} - {%- else -%} - {%- if ns.has_pos_only and ns.render_pos_only_separator -%} - {%- set ns.render_pos_only_separator = False %}/, {% endif -%} - {%- if parameter.kind.value == "keyword-only" -%} - {%- if ns.render_kw_only_separator -%} - {%- set ns.render_kw_only_separator = False %}*, {% endif -%} - {%- endif -%} - {%- endif -%} - - {%- if config.show_signature_annotations and parameter.annotation is not none -%} - {%- set annotation = ": " + parameter.annotation|safe -%} - {%- endif -%} - - {%- if parameter.default is not none and parameter.kind.value != "variadic positional" and parameter.kind.value != "variadic keyword" -%} - {%- set default = ns.equal + parameter.default|safe -%} - {%- endif -%} - - {%- if parameter.kind.value == "variadic positional" -%} - {%- set ns.render_kw_only_separator = False -%} - {%- endif -%} - - {% if parameter.kind.value == "variadic positional" %}*{% elif parameter.kind.value == "variadic keyword" %}**{% endif -%} - {{ parameter.name }}{{ annotation }}{{ default }} - {%- if not loop.last %}, {% endif -%} - - {%- endif -%} - {%- endfor -%} - ) - {%- if config.show_signature_annotations and function.annotation %} -> {{ function.annotation|safe }}{%- endif -%} - - {%- endwith -%} -{%- endif -%} diff --git a/docs/api_generator.py b/docs/api_generator.py index b8279ce0..038ed9dd 100644 --- a/docs/api_generator.py +++ b/docs/api_generator.py @@ -11,6 +11,13 @@ logger = logging.getLogger(__name__) +# Modules whose members should not include inherited attributes or methods +# NOTE: Given the current setup, we can only operate at a module level. +# Ideally we specify options (at least at a module level) and we render +# them into strings using a yaml parser. 
For now this is fine though +NO_INHERITS = ("sklearn.evaluation",) +TAB = " " + for path in sorted(Path("src").rglob("*.py")): module_path = path.relative_to("src").with_suffix("") doc_path = path.relative_to("src").with_suffix(".md") @@ -28,4 +35,8 @@ ident = ".".join(parts) fd.write(f"::: {ident}") + if ident.endswith(NO_INHERITS): + fd.write(f"\n{TAB}options:") + fd.write(f"\n{TAB}{TAB}inherited_members: false") + mkdocs_gen_files.set_edit_path(full_doc_path, path) diff --git a/docs/changelog.md b/docs/changelog.md new file mode 100644 index 00000000..786b75d5 --- /dev/null +++ b/docs/changelog.md @@ -0,0 +1 @@ +--8<-- "CHANGELOG.md" diff --git a/docs/example_runner.py b/docs/example_runner.py index 03c24c86..ef0c84f0 100644 --- a/docs/example_runner.py +++ b/docs/example_runner.py @@ -13,10 +13,9 @@ import mkdocs_gen_files from more_itertools import first_true, peekable -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.WARNING) +logger = logging.getLogger("mkdocs") -ENV_VAR = "AMLTK_DOC_RENDER_EXAMPLES" +RUN_EXAMPLES_ENV_VAR = "AMLTK_DOC_RENDER_EXAMPLES" @dataclass @@ -144,9 +143,10 @@ def should_execute(cls, *, name: str, runnable: bool) -> bool: if not runnable: return False - env_var = os.environ.get(ENV_VAR, None) - if env_var is None: + env_var = os.environ.get(RUN_EXAMPLES_ENV_VAR, "all") + if env_var in ("false", "", "0", "no", "off"): return False + if env_var == "all": return True @@ -270,6 +270,12 @@ def copy_section(self) -> str: ) +if os.environ.get(RUN_EXAMPLES_ENV_VAR, "all") in ("false", "", "0", "no", "off"): + logger.warning( + f"Env variable {RUN_EXAMPLES_ENV_VAR} not set - not running examples." + " Use `just docs-full` to run and render examples.", + ) + for path in sorted(Path("examples").rglob("*.py")): module_path = path.relative_to("examples").with_suffix("") doc_path = path.relative_to("examples").with_suffix(".md") diff --git a/docs/guides/optimization.md b/docs/guides/optimization.md index 31fb7ca3..fad24c77 100644 --- a/docs/guides/optimization.md +++ b/docs/guides/optimization.md @@ -52,7 +52,7 @@ s = Searchable( {"x": (-10.0, 10.0)}, name="my-searchable" ) -from amltk._doc import doc_print; doc_print(print, s) +from amltk._doc import doc_print; doc_print(print, s) # markdown-exec: hide ``` @@ -128,16 +128,14 @@ for _ in range(10): # Access the the trial's config x = trial.config["my-searchable:x"] - # Begin the trial - with trial.begin(): + try: score = poly(x) - - if trial.exception is None: - # Generate a success report - report: Trial.Report = trial.success(score=score) - else: + except ZeroDivisionError as e: # Generate a failed report (i.e. poly(x) raised divide by zero exception with x=0) - report: Trial.Report = trial.fail() + report = trial.fail(e) + else: + # Generate a success report + report = trial.success(score=score) # Store artifacts with the trial, using file extensions to infer how to store it trial.store({ "config.json": trial.config, "array.npy": [1, 2, 3] }) @@ -168,12 +166,6 @@ functionality to help you during optimization. The `.config` will contain name spaced parameters, in this case, `my-searchable:x`, based on the pipeline/search space you specified. -We also wrap the actual evaluation of the function in a [`with trial.begin():`][amltk.optimization.Trial.begin] which -will time and profile the evaluation of the function and handle any exceptions that occur within the block. 
- -If an exception occured in the `#!python with trial.begin():` block, any exception/traceback that occured will be -attached to [`.exception`][amltk.optimization.Trial.exception] and [`.traceback`][amltk.optimization.Trial.traceback]. - It's also quite typical to store artifacts with the trial, a common feature of things like TensorBoard, MLFlow, etc. We provide a primitive way to store artifacts with the trial using [`.store()`][amltk.optimization.Trial.store] which takes a dictionary of file names to file contents. The file extension is used to infer how to store the file, for example, @@ -198,23 +190,21 @@ at the end. from amltk.optimization import History, Trial, Metric from amltk.store import PathBucket -bucket = PathBucket("my-bucket") metric = Metric("score", minimize=False, bounds=(0, 5)) history = History() trials = [ - Trial(name="trial-1", config={"x": 1.0}, bucket=bucket, metrics=[metric]), - Trial(name="trial-2", config={"x": 2.0}, bucket=bucket, metrics=[metric]), - Trial(name="trial-3", config={"x": 3.0}, bucket=bucket, metrics=[metric]), + Trial.create(name="trial-1", config={"x": 1.0}, metrics=[metric]), + Trial.create(name="trial-2", config={"x": 2.0}, metrics=[metric]), + Trial.create(name="trial-3", config={"x": 3.0}, metrics=[metric]), ] for trial in trials: - with trial.begin(): - x = trial.config["x"] - if x >= 2: - report = trial.fail() - else: - report = trial.success(score=x) + x = trial.config["x"] + if x >= 2: + report = trial.fail() + else: + report = trial.success(score=x) history.add(report) @@ -223,6 +213,7 @@ print(df) best = history.best() print(best) +for t in trials: t.bucket.rmdir() # markdown-exec: hide ``` You can use the [`History.df()`][amltk.optimization.History.df] method to get a dataframe of the history and @@ -269,31 +260,19 @@ my_pipeline = ( from amltk._doc import doc_print; doc_print(print, my_pipeline) # markdown-exec: hide ``` -Next up, we need to define a simple target function we want to evaluate on. With that, we'll also -store our data, so that on each evaluate call, we load it in. This doesn't make much sense for a single in-process -call but when scaling up to using multiple processes or remote compute, this is a good practice to follow. For this -we use a [`PathBucket`][amltk.store.PathBucket] and get a [`StoredValue`][amltk.store.StoredValue] from it, basically -a reference to some object we can load back in later. +Next up, we need to define a simple target function we want to evaluate on. 
```python exec="true" result="python" source="material-block" session="optimizing-an-sklearn-pipeline" -from sklearn.datasets import load_iris from sklearn.model_selection import cross_validate from amltk.optimization import Trial -from amltk.store import PathBucket, StoredValue +from amltk.store import Stored import numpy as np -# Load in our data -_X, _y = load_iris(return_X_y=True) - -# Store our data in a bucket -bucket = PathBucket("my-bucket") -bucket.update({"X.npy": _X, "y.npy": _y}) - def evaluate( trial: Trial, pipeline: Sequential, - X: StoredValue[str, np.ndarray], - y: StoredValue[str, np.ndarray], + X: Stored[np.ndarray], + y: Stored[np.ndarray], ) -> Trial.Report: # Configure our pipeline and build it sklearn_pipeline = ( @@ -303,11 +282,11 @@ def evaluate( ) # Load in our data - X = X.value() - y = y.value() + X = X.load() + y = y.load() # Use sklearns.cross_validate as our evaluator - with trial.begin(): + with trial.profile("cross-validate"): results = cross_validate(sklearn_pipeline, X, y, scoring="accuracy", cv=3, return_estimator=True) test_scores = results["test_score"] @@ -318,6 +297,26 @@ def evaluate( return trial.success(acc=mean_test_score) ``` +With that, we'll also store our data, so that on each evaluate call, we load it in. +This doesn't make much sense for a single in-process call but when scaling up to using +multiple processes or remote compute, this is a good practice to follow. + +For this we use a [`PathBucket`][amltk.store.PathBucket] and get +a [`Stored`][amltk.store.Stored] from it, a reference to some object we can `load()` back in later. + +```python exec="true" result="python" source="material-block" session="optimizing-an-sklearn-pipeline" +from sklearn.datasets import load_iris +from amltk.store import PathBucket + +# Load in our data +_X, _y = load_iris(return_X_y=True) + +# Store our data in a bucket +bucket = PathBucket("my-bucket") +stored_X = bucket["X.npy"].put(_X) +stored_y = bucket["y.npy"].put(_y) +``` + Lastly, we'll create our optimizer and run it. In this example, we'll use the [`SMACOptimizer`][amltk.optimization.optimizers.smac.SMACOptimizer] but you can refer to the [optimizer reference](../reference/optimization/optimizers.md) for other optimizers. For basic @@ -328,7 +327,6 @@ from amltk.optimization.optimizers.smac import SMACOptimizer from amltk.optimization import Metric, History metric = Metric("acc", minimize=False, bounds=(0, 1)) -bucket = PathBucket("my-bucket") optimizer = SMACOptimizer.create( space=my_pipeline, # Let it know what to optimize metrics=metric, # And let it know what to expect @@ -336,8 +334,6 @@ optimizer = SMACOptimizer.create( ) history = History() -stored_X = bucket["X.npy"].as_stored_value() -stored_y = bucket["y.npy"].as_stored_value() for _ in range(10): # Get a trial from the optimizer diff --git a/docs/guides/scheduling.md b/docs/guides/scheduling.md index 94e7b1e3..52890940 100644 --- a/docs/guides/scheduling.md +++ b/docs/guides/scheduling.md @@ -107,8 +107,8 @@ from amltk._doc import doc_print; doc_print(print, scheduler) # markdown-exec: ### Running the Scheduler You may have noticed from the above example that there are many events the scheduler will emit, -such as `@on_start` or `@on_future_done`. One particularly important one is -[`@on_start`][amltk.scheduling.Scheduler.on_start], an event to signal +such as `@start` or `@future-done`. 
One particularly important one is +[`@start`][amltk.scheduling.Scheduler.on_start], an event to signal the scheduler has started and is ready to accept tasks. ```python exec="true" source="material-block" html="True" @@ -123,7 +123,7 @@ from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fon ``` From the output, we can see that the `print_hello()` function was registered -to the event `@on_start`, but it was never called and no `#!python "hello"` was printed. +to the event `@start`, but it was never called and no `#!python "hello"` was printed. For this to happen, we actually have to [`run()`][amltk.scheduling.Scheduler.run] the scheduler. @@ -140,7 +140,7 @@ scheduler.run() from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide ``` -Now the output will show a little yellow number next to the `@on_start` +Now the output will show a little yellow number next to the `@start` and the `print_hello()`, indicating that event was triggered and the callback was called. @@ -181,7 +181,7 @@ defining these units of compute, it is beneficial to see how the `Scheduler` operates directly with `submit()`, without abstractions. In the below example, we will use the -[`@on_future_result`][amltk.scheduling.Scheduler.on_future_result] +[`@future-result`][amltk.scheduling.Scheduler.on_future_result] event to submit more compute once the previous computation has returned a result. ```python exec="true" source="material-block" html="True" hl_lines="10 13 17" @@ -221,47 +221,47 @@ for a complete list. !!! example "`@events`" - === "`@on_start`" + === "`@start`" ::: amltk.scheduling.Scheduler.on_start - === "`@on_future_result`" + === "`@future-result`" ::: amltk.scheduling.Scheduler.on_future_result - === "`@on_future_exception`" + === "`@future-exception`" ::: amltk.scheduling.Scheduler.on_future_exception - === "`@on_future_submitted`" + === "`@future-submitted`" ::: amltk.scheduling.Scheduler.on_future_submitted - === "`@on_future_done`" + === "`@future-done`" ::: amltk.scheduling.Scheduler.on_future_done - === "`@on_future_cancelled`" + === "`@future-cancelled`" ::: amltk.scheduling.Scheduler.on_future_cancelled - === "`@on_timeout`" + === "`@timeout`" ::: amltk.scheduling.Scheduler.on_timeout - === "`@on_stop`" + === "`@stop`" ::: amltk.scheduling.Scheduler.on_stop - === "`@on_finishing`" + === "`@finishing`" ::: amltk.scheduling.Scheduler.on_finishing - === "`@on_finished`" + === "`@finished`" ::: amltk.scheduling.Scheduler.on_finished - === "`@on_empty`" + === "`@empty`" ::: amltk.scheduling.Scheduler.on_empty @@ -272,7 +272,7 @@ it was emitted as the values. ### Controlling Callbacks There's a few parameters you can pass to any event subscriber -such as `@on_start` or `@on_future_result`. +such as `@start` or `@future-result`. These control the behavior of what happens when its event is fired and can be used to control the flow of your system. @@ -391,7 +391,7 @@ However, there are more explicit methods. scheduler to stop immediately with [`run(wait=False)`][amltk.scheduling.Scheduler.run]. You'll notice this in the event count of the Scheduler where the event - `@on_future_cancelled` was fired. + `@future-cancelled` was fired. ```python exec="true" source="material-block" html="True" hl_lines="13-15" import time @@ -420,10 +420,12 @@ However, there are more explicit methods. 
You can also tell the `Scheduler` to stop after a certain amount of time with the `timeout=` argument to [`run()`][amltk.scheduling.Scheduler.run]. - This will also trigger the `@on_timeout` event as seen in the `Scheduler` output. + This will also trigger the `@timeout` event as seen in the `Scheduler` output. ```python exec="true" source="material-block" html="True" hl_lines="20" import time + from asyncio import Future + from amltk.scheduling import Scheduler scheduler = Scheduler.with_processes(1) @@ -437,7 +439,7 @@ However, there are more explicit methods. def submit_calculations() -> None: scheduler.submit(expensive_function) - # The will endlessly loop the scheduler + # This will endlessly loop the scheduler @scheduler.on_future_done def submit_again(future: Future) -> None: if scheduler.running(): @@ -453,7 +455,7 @@ to clarify that there are two kinds of exceptions that can occur within the Sche The 1st kind that can happen is within some function submitted with [`submit()`][amltk.scheduling.Scheduler.submit]. When this happens, -the `@on_future_exception` will be emitted, passing the exception to the callback. +the `@future-exception` will be emitted, passing the exception to the callback. By default, the `Scheduler` will then raise the exception that occurred up to your program and end its computations. This is done by setting @@ -468,6 +470,7 @@ the default, but it also takes three other possibilities: One example is to just `stop()` the scheduler when some exception occurs. ```python exec="true" source="material-block" html="True" hl_lines="12-15" +from asyncio import Future from amltk.scheduling import Scheduler scheduler = Scheduler.with_processes(1) diff --git a/docs/hooks/cleanup_log_output.py b/docs/hooks/cleanup_log_output.py new file mode 100644 index 00000000..98c65421 --- /dev/null +++ b/docs/hooks/cleanup_log_output.py @@ -0,0 +1,51 @@ +"""This module is a hook which disables warnings and log messages which pollute the +doc build output. + +One possible downside is if one of these modules ends up giving an actual +error, such as OpenML failing to retrieve a dataset. I tried to make sure ERROR +log messages are still allowed through. +""" +import logging +import warnings +from typing import Any + +import mkdocs +import mkdocs.plugins +import mkdocs.structure.pages + +from amltk.exceptions import AutomaticParameterWarning + +log = logging.getLogger("mkdocs") + + +@mkdocs.plugins.event_priority(-50) +def on_startup(**kwargs: Any): + # We get a load of deprecation warnings from SMAC + warnings.filterwarnings("ignore", category=DeprecationWarning) + + # We ignore AutoWarnings as our examples tend to rely on + # a lot of the `"auto"` parameters + warnings.filterwarnings("ignore", category=AutomaticParameterWarning) + + # ConvergenceWarning from sklearn + warnings.filterwarnings("ignore", module="sklearn") + + +def on_pre_page( + page: mkdocs.structure.pages.Page, + config: Any, + files: Any, +) -> mkdocs.structure.pages.Page | None: + # NOTE: mkdocs says they're always normalized to be '/' separated + # which means this should work on windows as well. + + # This error is actually demonstrated to the user which causes amltk + # to log the error. I don't know how to disable it for that one code cell + # but I can at least limit it to the file it's in.
+ if page.file.src_uri == "guides/scheduling.md": + scheduling_logger = logging.getLogger("amltk.scheduling.task") + scheduling_logger.setLevel(logging.CRITICAL) + + logging.getLogger("smac").setLevel(logging.ERROR) + logging.getLogger("openml").setLevel(logging.ERROR) + return page diff --git a/docs/hooks/debug_which_page_is_being_rendered.py b/docs/hooks/debug_which_page_is_being_rendered.py new file mode 100644 index 00000000..6ad00827 --- /dev/null +++ b/docs/hooks/debug_which_page_is_being_rendered.py @@ -0,0 +1,19 @@ +"""This module is a hook that, when any code is being rendered, will +print the path to the file being rendered. + +This makes it easier to identify which file is being rendered when an error happens.""" +import logging +from typing import Any + +import mkdocs +import mkdocs.plugins +import mkdocs.structure.pages + +log = logging.getLogger("mkdocs") + +def on_pre_page( + page: mkdocs.structure.pages.Page, + config: Any, + files: Any, +) -> mkdocs.structure.pages.Page | None: + log.info(f"{page.file.src_path}") diff --git a/docs/hooks/disable_markdown_exec.py b/docs/hooks/disable_markdown_exec.py new file mode 100644 index 00000000..8df35d76 --- /dev/null +++ b/docs/hooks/disable_markdown_exec.py @@ -0,0 +1,46 @@ +"""This disables markdown_exec based on an environment variable. +This speeds up the build of the docs for faster iteration. + +This is done by overwriting the module responsible for compiling and executing the code +by overriding the `exec(...)` global variable that is used to run the code. +We hijack it and print a helpful message about how to run the code cell instead. + +https://github.com/pawamoy/markdown-exec/blob/adff40b2928dbb2d22f27684e085f02d39a07291/src/markdown_exec/formatters/python.py#L42-L70 +""" +from __future__ import annotations + +import logging +import os +from typing import Any + +import mkdocs +import mkdocs.plugins +import mkdocs.structure.pages + +RUN_CODE_BLOCKS_ENV_VAR = "AMLTK_EXEC_DOCS" + +logger = logging.getLogger("mkdocs") + + +def _print_msg(compiled_code: Any, code_block_id: int, exec_globals: dict) -> None: + _print = exec_globals["print"] + _print( + f"Env variable {RUN_CODE_BLOCKS_ENV_VAR}=0 - No code to display." + "\nUse `just docs-code` (or `just docs-full` for examples) to run" + " the code block and display output." + ) + +truthy_values = {"yes", "on", "true", "1"} + +@mkdocs.plugins.event_priority(100) +def on_startup(**kwargs: Any): + run_code_blocks = os.environ.get(RUN_CODE_BLOCKS_ENV_VAR, "true") + if run_code_blocks.lower() not in truthy_values: + logger.warning( + f"Disabling markdown-exec due to {RUN_CODE_BLOCKS_ENV_VAR}={run_code_blocks}" + "\nUse `just docs-full` to run and render examples.", + ) + from markdown_exec.formatters import python + + setattr(python, "exec_python", _print_msg) + diff --git a/docs/reference/metalearning/index.md b/docs/reference/metalearning/index.md index 5592cb96..066918f6 100644 --- a/docs/reference/metalearning/index.md +++ b/docs/reference/metalearning/index.md @@ -11,18 +11,344 @@ to help implement these methods. ## MetaFeatures -::: amltk.metalearning.metafeatures - options: - members: false +A [`MetaFeature`][amltk.metalearning.MetaFeature] is some +statistic about a dataset/task that can be used to make datasets or +tasks more comparable, thus enabling meta-learning methods. + +Calculating meta-features of a dataset is quite straightforward.
+ +```python exec="true" source="material-block" result="python" title="Metafeatures" hl_lines="10" +import openml +from amltk.metalearning import compute_metafeatures + +dataset = openml.datasets.get_dataset( + 31, # credit-g + download_data=True, + download_features_meta_data=False, + download_qualities=False, +) +X, y, _, _ = dataset.get_data( + dataset_format="dataframe", + target=dataset.default_target_attribute, +) + +mfs = compute_metafeatures(X, y) + +print(mfs) +``` + +By default [`compute_metafeatures()`][amltk.metalearning.compute_metafeatures] will +calculate all the [`MetaFeature`][amltk.metalearning.MetaFeature] implemented, +iterating through their subclasses to do so. You can pass an explicit list +as well to `compute_metafeatures(X, y, features=[...])`. + +Implementing your own is also quite straightforward: + +```python exec="true" source="material-block" result="python" title="Create Metafeature" hl_lines="10 11 12 13 14 15 16 17 18 19" +from amltk.metalearning import MetaFeature, compute_metafeatures +import openml +import pandas as pd + +dataset = openml.datasets.get_dataset( + 31, # credit-g + download_data=True, + download_features_meta_data=False, + download_qualities=False, +) +X, y, _, _ = dataset.get_data( + dataset_format="dataframe", + target=dataset.default_target_attribute, +) + +class TotalValues(MetaFeature): + + @classmethod + def compute( + cls, + x: pd.DataFrame, + y: pd.Series | pd.DataFrame, + dependancy_values: dict, + ) -> int: + return int(x.shape[0] * x.shape[1]) + +mfs = compute_metafeatures(X, y, features=[TotalValues]) +print(mfs) +``` + +As many metafeatures rely on pre-computed dataset statistics, and they do not +need to be calculated more than once, you can specify the dependencies of +a meta feature. When a metafeature would return something other than a single +value, i.e. a `dict` or a `pd.DataFrame`, we instead call those a +[`DatasetStatistic`][amltk.metalearning.DatasetStatistic]. These will +**not** be included in the result of [`compute_metafeatures()`][amltk.metalearning.compute_metafeatures]. +These `DatasetStatistic`s will only be calculated once on a call to `compute_metafeatures()` so +they can be re-used across all `MetaFeature`s that require that dependency.
+ +```python exec="true" source="material-block" result="python" title="Metafeature Dependancy" hl_lines="10 11 12 13 14 15 16 17 18 19 20 23 26 35" +from amltk.metalearning import MetaFeature, DatasetStatistic, compute_metafeatures +import openml +import pandas as pd + +dataset = openml.datasets.get_dataset( + 31, # credit-g + download_data=True, + download_features_meta_data=False, + download_qualities=False, +) +X, y, _, _ = dataset.get_data( + dataset_format="dataframe", + target=dataset.default_target_attribute, +) + +class NAValues(DatasetStatistic): + """A mask of all NA values in a dataset""" + + @classmethod + def compute( + cls, + x: pd.DataFrame, + y: pd.Series | pd.DataFrame, + dependancy_values: dict, + ) -> pd.DataFrame: + return x.isna() + + +class PercentageNA(MetaFeature): + """The percentage of values missing""" + + dependencies = (NAValues,) + + @classmethod + def compute( + cls, + x: pd.DataFrame, + y: pd.Series | pd.DataFrame, + dependancy_values: dict, + ) -> int: + na_values = dependancy_values[NAValues] + n_na = na_values.sum().sum() + n_values = int(x.shape[0] * x.shape[1]) + return float(n_na / n_values) + +mfs = compute_metafeatures(X, y, features=[PercentageNA]) +print(mfs) +``` + +To view the description of a particular `MetaFeature`, you can call +[`.description()`][amltk.metalearning.DatasetStatistic.description] +on it. Otherwise you can access all of them in the following way: + +```python exec="true" source="tabbed-left" result="python" title="Metafeature Descriptions" hl_lines="4" +from pprint import pprint +from amltk.metalearning import metafeature_descriptions + +descriptions = metafeature_descriptions() +for name, description in descriptions.items(): + print("---") + print(name) + print("---") + print(" * " + description) +``` ## Dataset Distances +One common way to define how similar two datasets are is to compute some "similarity" +between them. This notion of "similarity" requires computing some features of a dataset +(**metafeatures**) first, such that we can numerically compute some distance function. + +Let's see how we can quickly compute the distance between some datasets with +[`dataset_distance()`][amltk.metalearning.dataset_distance]! + +```python exec="true" source="material-block" result="python" title="Dataset Distances P.1" session='dd' +import pandas as pd +import openml + +from amltk.metalearning import compute_metafeatures + +def get_dataset(dataset_id: int) -> tuple[pd.DataFrame, pd.Series]: + dataset = openml.datasets.get_dataset( + dataset_id, + download_data=True, + download_features_meta_data=False, + download_qualities=False, + ) + X, y, _, _ = dataset.get_data( + dataset_format="dataframe", + target=dataset.default_target_attribute, + ) + return X, y + +d31 = get_dataset(31) +d3 = get_dataset(3) +d4 = get_dataset(4) + +metafeatures_dict = { + "dataset_31": compute_metafeatures(*d31), + "dataset_3": compute_metafeatures(*d3), + "dataset_4": compute_metafeatures(*d4), +} -::: amltk.metalearning.dataset_distances - options: - members: false +metafeatures = pd.DataFrame(metafeatures_dict) +print(metafeatures) +``` + +Now we want to know which one of `#!python "dataset_3"` or `#!python "dataset_4"` is +more _similar_ to `#!python "dataset_31"`. 
+ +```python exec="true" source="material-block" result="python" title="Dataset Distances P.2" session='dd' +from amltk.metalearning import dataset_distance + +target = metafeatures_dict.pop("dataset_31") +others = metafeatures_dict + +distances = dataset_distance(target, others, distance_metric="l2") +print(distances) +``` + +Seems like `#!python "dataset_3"` is, by this measure, closer to `#!python "dataset_31"` +than `#!python "dataset_4"`. However, the metafeatures are not all on the same scale. +For example, many lie between `#!python (0, 1)` but some like `instance_count` can completely +dominate the show. + +Let's repeat the computation but specify that we should apply a `#!python "minmax"` scaling +across the rows. + +```python exec="true" source="material-block" result="python" title="Dataset Distances P.3" session='dd' hl_lines="5" +distances = dataset_distance( + target, + others, + distance_metric="l2", + scaler="minmax" +) +print(distances) +``` + +Now `#!python "dataset_3"` is considered more similar but the difference between the two is a lot less +dramatic. In general, applying some scaling to values of different scales is required for metalearning. + +You can also use an [sklearn.preprocessing.MinMaxScaler][] or any other scaler from scikit-learn, +for that matter. + +```python exec="true" source="material-block" result="python" title="Dataset Distances P.4" session='dd' hl_lines="7" +from sklearn.preprocessing import MinMaxScaler + +distances = dataset_distance( + target, + others, + distance_metric="l2", + scaler=MinMaxScaler() +) +print(distances) +``` ## Portfolio Selection +A portfolio in meta-learning is a set (ordered or not) of configurations +that maximize some notion of coverage across datasets or tasks. +The intuition here is that this also means that any new dataset is also covered! + +Suppose we have the performances of some configurations across some datasets. +```python exec="true" source="material-block" result="python" title="Initial Portfolio" +import pandas as pd + +performances = { + "c1": [90, 60, 20, 10], + "c2": [20, 10, 90, 20], + "c3": [10, 20, 40, 90], + "c4": [90, 10, 10, 10], +} +portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) +print(portfolio) +``` + +If we could only choose `#!python k=3` of these configurations on some new dataset, which ones would +we choose and in what priority? +Here is where we can apply [`portfolio_selection()`][amltk.metalearning.portfolio_selection]! + +The idea is that we pick a subset of these algorithms that maximize some value of utility for +the portfolio. We do this by adding a single configuration from the entire set, 1-by-1 until +we reach `k`, beginning with the empty portfolio. + +Let's see this in action! + +```python exec="true" source="material-block" result="python" title="Portfolio Selection" hl_lines="12 13 14 15 16" +import pandas as pd +from amltk.metalearning import portfolio_selection + +performances = { + "c1": [90, 60, 20, 10], + "c2": [20, 10, 90, 20], + "c3": [10, 20, 40, 90], + "c4": [90, 10, 10, 10], +} +portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) + +selected_portfolio, trajectory = portfolio_selection( + portfolio, + k=3, + scaler="minmax" +) + +print(selected_portfolio) +print() +print(trajectory) +``` + +The trajectory tells us which configuration was added at each time stamp along with the utility +of the portfolio with that configuration added.
However, we haven't specified how _exactly_ we define the +utility of a given portfolio. We could define our own function to do so: + +```python exec="true" source="material-block" result="python" title="Portfolio Selection Custom" hl_lines="12 13 14 20" +import pandas as pd +from amltk.metalearning import portfolio_selection + +performances = { + "c1": [90, 60, 20, 10], + "c2": [20, 10, 90, 20], + "c3": [10, 20, 40, 90], + "c4": [90, 10, 10, 10], +} +portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) + +def my_function(p: pd.DataFrame) -> float: + # Take the maximum score for each dataset and then take the mean across them. + return p.max(axis=1).mean() + +selected_portfolio, trajectory = portfolio_selection( + portfolio, + k=3, + scaler="minmax", + portfolio_value=my_function, +) + +print(selected_portfolio) +print() +print(trajectory) +``` + +This notion of reducing across all configurations for a dataset and then aggregating the results is common +enough that you can simply specify these two operations directly and we will perform the rest. + +```python exec="true" source="material-block" result="python" title="Portfolio Selection With Reduction" hl_lines="17 18" +import pandas as pd +import numpy as np +from amltk.metalearning import portfolio_selection + +performances = { + "c1": [90, 60, 20, 10], + "c2": [20, 10, 90, 20], + "c3": [10, 20, 40, 90], + "c4": [90, 10, 10, 10], +} +portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) + +selected_portfolio, trajectory = portfolio_selection( + portfolio, + k=3, + scaler="minmax", + row_reducer=np.max, # This is actually the default + aggregator=np.mean, # This is actually the default +) -::: amltk.metalearning.portfolio - options: - members: false +print(selected_portfolio) +print() +print(trajectory) +``` diff --git a/docs/reference/optimization/history.md b/docs/reference/optimization/history.md index 65c50039..b632aeab 100644 --- a/docs/reference/optimization/history.md +++ b/docs/reference/optimization/history.md @@ -22,18 +22,18 @@ def quadratic(x): history = History() trials = [ - Trial(name=f"trial_{count}", config={"x": i}, metrics=[loss]) + Trial.create(name=f"trial_{count}", config={"x": i}, metrics=[loss]) for count, i in enumerate(range(-5, 5)) ] reports = [] for trial in trials: - with trial.begin(): - x = trial.config["x"] - report = trial.success(loss=quadratic(x)) - history.add(report) + x = trial.config["x"] + report = trial.success(loss=quadratic(x)) + history.add(report) print(history.df()) +for trial in trials: trial.bucket.rmdir() # markdown-exec: hide ``` Typically, to use this inside of an optimization run, you would add the reports inside @@ -55,9 +55,8 @@ see the [optimization guide](../../guides/optimization.md) for more details.
def target_function(trial: Trial) -> Trial.Report: x = trial.config["x"] - with trial.begin(): - cost = quadratic(x) - return trial.success(cost=cost) + cost = quadratic(x) + return trial.success(cost=cost) optimizer = SMACOptimizer(space=searchable, metrics=Metric("cost", minimize=True), seed=42) @@ -93,7 +92,7 @@ print(history[last_report.name]) ```python exec="true" source="material-block" result="python" session="ref-history" for report in history: - print(report.name, f"loss = {report.metrics['loss']}") + print(report.name, f"loss = {report.values['loss']}") ``` ```python exec="true" source="material-block" result="python" session="ref-history" diff --git a/docs/reference/optimization/metrics.md b/docs/reference/optimization/metrics.md index b6c9b379..0b5211f3 100644 --- a/docs/reference/optimization/metrics.md +++ b/docs/reference/optimization/metrics.md @@ -1,5 +1,31 @@ ## Metric +A [`Metric`][amltk.optimization.Metric] lets optimizers know how to +handle numeric values properly. -::: amltk.optimization.metric - options: - members: False +A `Metric` is defined by a `.name: str` and whether it is better to `.minimize: bool` +the metric. Further, you can specify `.bounds: tuple[lower, upper]` which can +help optimizers and other code know how to treat metrics. + +To easily convert between `loss` and +`score` of some value you can use the [`loss()`][amltk.optimization.Metric.loss] +and [`score()`][amltk.optimization.Metric.score] methods. + +If the metric is bounded, you can also make use of the +[`distance_to_optimal()`][amltk.optimization.Metric.distance_to_optimal] +function, which gives the distance to the optimal value. + +In the case of optimization, we provide a +[`normalized_loss()`][amltk.optimization.Metric.normalized_loss] which +normalizes the value into a minimization loss, which is also bounded +if the metric itself is bounded. + +```python exec="true" source="material-block" result="python" +from amltk.optimization import Metric + +acc = Metric("accuracy", minimize=False, bounds=(0, 100)) + +print(f"Distance: {acc.distance_to_optimal(90)}") # Distance to optimal. +print(f"Loss: {acc.loss(90)}") # Something that can be minimized +print(f"Score: {acc.score(90)}") # Something that can be maximized +print(f"Normalized loss: {acc.normalized_loss(90)}") # Normalized loss +``` diff --git a/docs/reference/optimization/optimizers.md b/docs/reference/optimization/optimizers.md index 777b5d38..7573447b 100644 --- a/docs/reference/optimization/optimizers.md +++ b/docs/reference/optimization/optimizers.md @@ -40,29 +40,28 @@ the [`Report`][amltk.optimization.Trial.Report], as this will be different for e to worry that the internal state of the optimizer is updated accordingly to these two _"Ask"_ and _"Tell"_ events and that's it.
-For a reference on implementing an optimizer you can refer to any of the following: - - -## SMAC - -::: amltk.optimization.optimizers.smac - options: - members: false - -## NePs - -::: amltk.optimization.optimizers.neps - options: - members: false - -## Optuna - -::: amltk.optimization.optimizers.optuna - options: - members: false +For a reference on implementing an optimizer you can refer to any of the following +API Docs: + +* [SMAC][amltk.optimization.optimizers.smac] +* [NePs][amltk.optimization.optimizers.neps] +* [Optuna][amltk.optimization.optimizers.optuna] +* [Random Search][amltk.optimization.optimizers.random_search] ## Integrating your own - -::: amltk.optimization.optimizer - options: - members: false +The base [`Optimizer`][amltk.optimization.optimizer.Optimizer] class +defines the API we require optimizers to implement. + +* [`ask()`][amltk.optimization.optimizer.Optimizer.ask] - Ask the optimizer for a + new [`Trial`][amltk.optimization.trial.Trial] to evaluate. +* [`tell()`][amltk.optimization.optimizer.Optimizer.tell] - Tell the optimizer + the result of the sampled config. This comes in the form of a + [`Trial.Report`][amltk.optimization.trial.Trial.Report]. + +Additionally, to aid users in switching between optimizers, the +[`preferred_parser()`][amltk.optimization.optimizer.Optimizer.preferred_parser] +method should return either a `parser` function or a string that can be used +with [`node.search_space(parser=...)`][amltk.pipeline.Node.search_space] to +extract the search space for the optimizer. + +Please refer to the code of [Random Search][amltk.optimization.optimizers.random_search] +on GitHub for an example of how to implement a new optimizer. diff --git a/docs/reference/optimization/profiling.md b/docs/reference/optimization/profiling.md index 7a2b3147..5755a0cf 100644 --- a/docs/reference/optimization/profiling.md +++ b/docs/reference/optimization/profiling.md @@ -1,5 +1,77 @@ ## Profiling +Whether for debugging, building an AutoML system or for optimization +purposes, we provide a powerful [`Profiler`][amltk.profiling.Profiler], +which can generate a [`Profile`][amltk.profiling.Profile] of different sections +of code. This is particularly useful with [`Trial`][amltk.optimization.Trial]s, +so much so that we attach one to every `Trial` made as +[`trial.profiler`][amltk.optimization.Trial.profiler]. -:: amltk.profiling.profiler - options: - members: False +When done profiling, you can export all generated profiles as a dataframe using +[`profiler.df()`][amltk.profiling.Profiler.df]. + +```python exec="true" result="python" source="material-block" +from amltk.profiling import Profiler +import numpy as np + +profiler = Profiler() + +with profiler("loading-data"): + X = np.random.rand(1000, 1000) + +with profiler("training-model"): + model = np.linalg.inv(X) + +with profiler("predicting"): + y = model @ X + +print(profiler.df()) +``` + +You'll find these profiles as keys in the [`Profiler`][amltk.profiling.Profiler], +e.g. `#!python profiler["loading-data"]`. + +This will measure both the time spent within the block and +the memory consumed before and after the block finishes, allowing +you to get an estimate of the memory consumed. + + +??? tip "Memory, vms vs rss" + + While not entirely accurate, this should be enough info + for most use cases. + + Given the main process uses 2GB of memory and the process + then spawns a new process in which you are profiling, as you + might do from a [`Task`][amltk.scheduling.Task].
In this new + process you use another 2GB on top of that, then: + + * The virtual memory size (**vms**) will show 4GB as the + new process will share the 2GB with the main process and + have it's own 2GB. + + * The resident set size (**rss**) will show 2GB as the + new process will only have 2GB of it's own memory. + + +If you need to profile some iterator, like a for loop, you can use +[`Profiler.each()`][amltk.profiling.Profiler.each] which will measure +the entire loop but also each individual iteration. This can be useful +for iterating batches of a deep-learning model, splits of a cross-validator +or really any loop with work you want to profile. + +```python exec="true" result="python" source="material-block" +from amltk.profiling import Profiler +import numpy as np + +profiler = Profiler() + +for i in profiler.each(range(3), name="for-loop"): + X = np.random.rand(1000, 1000) + +print(profiler.df()) +``` + +Lastly, to disable profiling without editing much code, +you can always use [`Profiler.disable()`][amltk.profiling.Profiler.disable] +and [`Profiler.enable()`][amltk.profiling.Profiler.enable] to toggle +profiling on and off. diff --git a/docs/reference/optimization/trials.md b/docs/reference/optimization/trials.md index b71a41a0..cccdfbb9 100644 --- a/docs/reference/optimization/trials.md +++ b/docs/reference/optimization/trials.md @@ -1,11 +1,197 @@ -## Trial +## Trial and Report -::: amltk.optimization.trial - options: - members: False +[`Trial`][amltk.optimization.trial.Trial] - typically the output of +[`Optimizer.ask()`][amltk.optimization.Optimizer.ask], indicating +what the optimizer would like to evaluate next. +e provide a host of convenience methods attached to the `Trial` to make it easy to +save results, store artifacts, and more. -### History +[`Trial.Report`][amltk.optimization.trial.Trial.Report] - +the output of a [`trial.success(cost=...)`][amltk.optimization.trial.Trial.success] or +[`trial.fail(cost=...)`][amltk.optimization.trial.Trial.fail] call. +Provides an easy way to report back to the optimizer's +[`tell()`][amltk.optimization.Optimizer.tell]. -::: amltk.optimization.history - options: - members: False + + +### Trial +A [`Trial`][amltk.optimization.Trial] encapsulates some configuration +that needs to be evaluated. Typically, this is what is generated by an +[`Optimizer.ask()`][amltk.optimization.Optimizer.ask] call. + +- [`trial.success()`][amltk.optimization.Trial.success] to generate a +success [`Report`][amltk.optimization.Trial.Report], typically +passing what your chosen optimizer expects, e.g., `"loss"` or `"cost"`. + +- [`trial.fail()`][amltk.optimization.Trial.fail] to generate a +failure [`Report`][amltk.optimization.Trial.Report]. +If an exception is passed to `fail()`, it will be attached to the report along with any traceback it can deduce. +Each [`Optimizer`][amltk.optimization.Optimizer] will take care of what to do from here. + +```python exec="true" source="material-block" result="python" +from amltk.optimization import Trial, Metric +from amltk.store import PathBucket + +cost = Metric("cost", minimize=True) + +def target_function(trial: Trial) -> Trial.Report: + x = trial.config["x"] + y = trial.config["y"] + + with trial.profile("expensive-calculation"): + cost = x**2 - y + + return trial.success(cost=cost) + +# ... 
usually obtained from an optimizer +trial = Trial.create( + name="some-unique-name", + config={"x": 1, "y": 2}, + metrics=[cost] +) + +report = target_function(trial) +print(report.df()) +trial.bucket.rmdir() # markdown-exec: hide +``` + + +What you can return with [`trial.success()`][amltk.optimization.Trial.success] +or [`trial.fail()`][amltk.optimization.Trial.fail] depends on the +[`metrics`][amltk.optimization.Trial.metrics] of the trial. Typically, +an optimizer will provide the trial with the list of [metrics](../optimization/metrics.md) + +Some important properties are that they have a unique +[`.name`][amltk.optimization.Trial.name] given the optimization run, +a candidate [`.config`][amltk.optimization.Trial.config] to evaluate, +a possible [`.seed`][amltk.optimization.Trial.seed] to use, +and an [`.info`][amltk.optimization.Trial.info] object, which is the optimizer +specific information, if required by you. + +!!! tip "Reporting success (or failure)" + + When using the [`success()`][amltk.optimization.trial.Trial.success] + method, make sure to provide values for all metrics specified in the + [`.metrics`][amltk.optimization.Trial.metrics] attribute. + Usually these are set by the optimizer generating the `Trial`. + + If you instead report using [`fail()`][amltk.optimization.trial.Trial.success], + any metric not specified will be set to the + [`.worst`][amltk.optimization.Metric.worst] value of the metric. + + Each metric has a unique name, and it's crucial to use the correct names when + reporting success, otherwise an error will occur. + + ??? example "Reporting success for metrics" + + For example: + + ```python exec="true" result="python" source="material-block" + from amltk.optimization import Trial, Metric + + # Gotten from some optimizer usually, i.e. via `optimizer.ask()` + trial = Trial.create( + name="example_trial", + config={"param": 42}, + metrics=[Metric(name="accuracy", minimize=False)] + ) + + # Incorrect usage (will raise an error) + try: + report = trial.success(invalid_metric=0.95) + except ValueError as error: + print(error) + + # Correct usage + report = trial.success(accuracy=0.95) + trial.bucket.rmdir() # markdown-exec: hide + ``` + +If using [`Plugins`][amltk.scheduling.plugins.Plugin], they may insert +some extra objects in the [`.extra`][amltk.optimization.Trial.extras] dict. + +To profile your trial, you can wrap the logic you'd like to check with +[`trial.profile()`][amltk.optimization.Trial.profile], which will automatically +profile the block of code for memory before and after as well as time taken. + +If you've [`profile()`][amltk.optimization.Trial.profile]'ed any intervals, +you can access them by name through +[`trial.profiles`][amltk.optimization.Trial.profiles]. +Please see the [`Profiler`][amltk.profiling.profiler.Profiler] +for more. + +??? example "Profiling with a trial." + + ```python exec="true" source="material-block" result="python" title="profile" + from amltk.optimization import Trial + + trial = Trial.create(name="some-unique-name", config={}) + + # ... somewhere where you've begun your trial. + with trial.profile("some_interval"): + for work in range(100): + pass + + print(trial.profiler.df()) + trial.bucket.rmdir() # markdown-exec: hide + ``` + +You can also record anything you'd like into the +[`.summary`][amltk.optimization.Trial.summary], a plain `#!python dict` +or use [`trial.store()`][amltk.optimization.Trial.store] to store artifacts +related to the trial. + +??? tip "What to put in `.summary`?" + + For large items, e.g. 
predictions or models, these are highly advised to + [`.store()`][amltk.optimization.Trial.store] to disk, especially if using + a `Task` for multiprocessing. + + Further, if serializing the report using the + [`report.df()`][amltk.optimization.Trial.Report.df], + returning a single row, + or a [`History`][amltk.optimization.History] + with [`history.df()`][amltk.optimization.History.df] for a dataframe consisting + of many of the reports, then you'd likely only want to store things + that are scalar and can be serialised to disk by a pandas DataFrame. + + +### Report +The [`Trial.Report`][amltk.optimization.Trial.Report] encapsulates +a [`Trial`][amltk.optimization.Trial], its status and any metrics/exceptions +that may have occured. + +Typically you will not create these yourself, but instead use +[`trial.success()`][amltk.optimization.Trial.success] or +[`trial.fail()`][amltk.optimization.Trial.fail] to generate them. + +```python exec="true" source="material-block" result="python" +from amltk.optimization import Trial, Metric + +loss = Metric("loss", minimize=True) + +trial = Trial.create(name="trial", config={"x": 1}, metrics=[loss]) + +with trial.profile("fitting"): + # Do some work + # ... + report = trial.success(loss=1) + +print(report.df()) +trial.bucket.rmdir() # markdown-exec: hide +``` + +These reports are used to report back metrics to an +[`Optimizer`][amltk.optimization.Optimizer] +with [`Optimizer.tell()`][amltk.optimization.Optimizer.tell] but can also be +stored for your own uses. + +You can access the original trial with the +[`.trial`][amltk.optimization.Trial.Report.trial] attribute, and the +[`Status`][amltk.optimization.Trial.Status] of the trial with the +[`.status`][amltk.optimization.Trial.Report.status] attribute. + +You may also want to check out the [`History`][amltk.optimization.History] class +for storing a collection of `Report`s, allowing for an easier time to convert +them to a dataframe or perform some common Hyperparameter optimization parsing +of metrics. diff --git a/docs/reference/pipelines/pipeline.md b/docs/reference/pipelines/pipeline.md index 533917c4..f19edd19 100644 --- a/docs/reference/pipelines/pipeline.md +++ b/docs/reference/pipelines/pipeline.md @@ -1,15 +1,9 @@ -## Pieces of a Pipeline A pipeline is a collection of [`Node`][amltk.pipeline.node.Node]s that are connected together to form a directed acylic graph, where the nodes follow a parent-child relation ship. The purpose of these is to form some _abstract_ representation of what you want to search over/optimize and then build into a concrete object. -These [`Node`][amltk.pipeline.node.Node]s allow you to specific the function/object that -will be used there, it's search space and any configuration you want to explicitly apply. -There are various components listed below which gives these nodes extract syntatic meaning, -e.g. a [`Choice`](#choice) which represents some choice between it's children while -a [`Sequential`](#sequential) indicates that each child follows one after the other. - +## Key Operations Once a pipeline is created, you can perform 3 very critical operations on it: * [`search_space(parser=...)`][amltk.pipeline.node.Node.search_space] - This will return the @@ -21,14 +15,155 @@ Once a pipeline is created, you can perform 3 very critical operations on it: concrete object from a configured pipeline. You can find the reference to the [available builders here](../pipelines/builders.md). 
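+
+As a rough sketch of how these three operations fit together on a single component, assuming
+the ConfigSpace parser and the sklearn builder are installed; note the exact config key names
+may be prefixed with the node name depending on how the configuration was produced:
+
+```python
+from amltk.pipeline import Component
+from sklearn.ensemble import RandomForestClassifier
+
+rf = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}, name="rf")
+
+space = rf.search_space(parser="configspace")    # 1. a search space representation
+configured = rf.configure({"n_estimators": 50})  # 2. pin down a concrete configuration
+model = configured.build("sklearn")              # 3. a ready-to-fit sklearn estimator
+
+print(space)
+print(model)
+```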
-### Components
+## Node
+A [`Node`][amltk.pipeline.node.Node] is the basic building block of a pipeline.
+It contains various attributes, such as:
+
+- [`.name`][amltk.pipeline.node.Node.name] - The name of the node, which is used
+  to identify it in the pipeline.
+- [`.item`][amltk.pipeline.node.Node.item] - The concrete object or some function to construct one
+- [`.space`][amltk.pipeline.node.Node.space] - A search space to consider for this node
+- [`.config`][amltk.pipeline.node.Node.config] - The specific configuration to use for this
+  node once `build` is called.
+- [`.nodes`][amltk.pipeline.node.Node.nodes] - Other nodes that this node links to.
+
+To give syntactic meaning to these nodes, we have various subclasses. For example,
+[`Sequential`][amltk.pipeline.components.Sequential] is a node where the order of the
+`nodes` it contains matters, while a [`Component`][amltk.pipeline.components.Component] is a node
+that can be used to parametrize and construct a concrete object, but does not lead to anything else.
+
+Each node type here is either a _leaf_ or a _branch_, where a _branch_ has children,
+while a _leaf_ does not.
+
+The various component types are listed here:
+
+### [`Component`][amltk.pipeline.Component] - `leaf`
+A parametrizable node type with some way to build an object, given a configuration.
+
+```python exec="true" source="material-block" html="true"
+from amltk.pipeline import Component
+from dataclasses import dataclass
+
+@dataclass
+class Model:
+    x: float
+
+c = Component(Model, space={"x": (0.0, 1.0)}, name="model")
+from amltk._doc import doc_print; doc_print(print, c) # markdown-exec: hide
+```
+
+### [`Searchable`][amltk.pipeline.Searchable] - `leaf`
+A parametrizable node type that contains a search space that should be searched over,
+but does not provide a concrete object.
+
+```python exec="true" source="material-block" html="true"
+from amltk.pipeline import Searchable
+
+def run_script(mode, n):
+    # ... run some actual script
+    pass
+
+script_space = Searchable({"mode": ["orange", "blue", "red"], "n": (10, 100)})
+from amltk._doc import doc_print; doc_print(print, script_space) # markdown-exec: hide
+```
+
+### [`Fixed`][amltk.pipeline.Fixed] - `leaf`
+A _non-parametrizable_ node type that contains an object that should be used as is.
+
+```python exec="true" source="material-block" html="true"
+from amltk.pipeline import Component, Fixed, Sequential
+from sklearn.ensemble import RandomForestClassifier
+
+estimator = RandomForestClassifier()
+# ... pretend it was fit
+fitted_estimator = Fixed(estimator)
+from amltk._doc import doc_print; doc_print(print, fitted_estimator) # markdown-exec: hide
+```
+
+### [`Sequential`][amltk.pipeline.Sequential] - `branch`
+A node type which signifies an order between its children, such as a sequential
+set of preprocessing steps and an estimator through which the data should flow.
+
+```python exec="true" source="material-block" html="true"
+from amltk.pipeline import Component, Sequential
+from sklearn.decomposition import PCA
+from sklearn.ensemble import RandomForestClassifier
+
+pipeline = Sequential(
+    PCA(n_components=3),
+    Component(RandomForestClassifier, space={"n_estimators": (10, 100)}),
+    name="my_pipeline"
+)
+from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide
+```
+
+### [`Choice`][amltk.pipeline.Choice] - `branch`
+A node type that signifies a choice between multiple children, usually chosen during configuration.
+ +```python exec="true" source="material-block" html="true" +from amltk.pipeline import Choice, Component +from sklearn.ensemble import RandomForestClassifier +from sklearn.neural_network import MLPClassifier + +rf = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) +mlp = Component(MLPClassifier, space={"activation": ["logistic", "relu", "tanh"]}) + +estimator_choice = Choice(rf, mlp, name="estimator") +from amltk._doc import doc_print; doc_print(print, estimator_choice) # markdown-exec: hide +``` + +### [`Split`][amltk.pipeline.Split] - `branch` +A node where the output of the previous node is split amongst its children, +according to it's configuration. + +```python exec="true" source="material-block" html="true" +from amltk.pipeline import Component, Split +from sklearn.impute import SimpleImputer +from sklearn.preprocessing import OneHotEncoder +from sklearn.compose import make_column_selector + +categorical_pipeline = [ + SimpleImputer(strategy="constant", fill_value="missing"), + OneHotEncoder(drop="first"), +] +numerical_pipeline = Component(SimpleImputer, space={"strategy": ["mean", "median"]}) + +preprocessor = Split( + {"categories": categorical_pipeline, "numerical": numerical_pipeline}, + name="my_split" +) +from amltk._doc import doc_print; doc_print(print, preprocessor) # markdown-exec: hide +``` + +### [`Join`][amltk.pipeline.Join] - `branch` +A node where the output of the previous node is sent all of its children. + +```python exec="true" source="material-block" html="true" +from amltk.pipeline import Join, Component +from sklearn.decomposition import PCA +from sklearn.feature_selection import SelectKBest + +pca = Component(PCA, space={"n_components": (1, 3)}) +kbest = Component(SelectKBest, space={"k": (1, 3)}) + +join = Join(pca, kbest, name="my_feature_union") +from amltk._doc import doc_print; doc_print(print, join) # markdown-exec: hide +``` + +## Syntax Sugar +You can connect these nodes together using either the constructors explicitly, +as shown in the examples. We also provide some index operators: + +* `>>` - Connect nodes together to form a [`Sequential`][amltk.pipeline.components.Sequential] +* `&` - Connect nodes together to form a [`Join`][amltk.pipeline.components.Join] +* `|` - Connect nodes together to form a [`Choice`][amltk.pipeline.components.Choice] -::: amltk.pipeline.components - options: - members: false +There is also another short-hand that you may find useful to know: -### Node +* `{comp1, comp2, comp3}` - This will automatically be converted into a + [`Choice`][amltk.pipeline.Choice] between the given components. +* `(comp1, comp2, comp3)` - This will automatically be converted into a + [`Join`][amltk.pipeline.Join] between the given components. +* `[comp1, comp2, comp3]` - This will automatically be converted into a + [`Sequential`][amltk.pipeline.Sequential] between the given components. -::: amltk.pipeline.node - options: - members: false diff --git a/docs/reference/pipelines/spaces.md b/docs/reference/pipelines/spaces.md index a89a63f5..7a1953c6 100644 --- a/docs/reference/pipelines/spaces.md +++ b/docs/reference/pipelines/spaces.md @@ -32,14 +32,6 @@ from amltk._doc import doc_print; doc_print(print, c) # markdown-exec: hide What follow's below is a list of supported parsers you could pass `parser=` to extract a search space representation. 
-## ConfigSpace - -::: amltk.pipeline.parsers.configspace - options: - members: false - -## Optuna - -::: amltk.pipeline.parsers.optuna - options: - members: false +* [`ConfigSpace`][amltk.pipeline.parsers.configspace] - A parser for the + [ConfigSpace](https://automl.github.io/ConfigSpace/master/) library. +* [`Optuna`][amltk.pipeline.parsers.optuna] - A parser specifically for optuna. diff --git a/docs/reference/scheduling/events.md b/docs/reference/scheduling/events.md index 53f61ac8..58ff103e 100644 --- a/docs/reference/scheduling/events.md +++ b/docs/reference/scheduling/events.md @@ -1,5 +1,239 @@ ## Events +One of the primary ways to respond to `@events` emitted +with by a [`Task`][amltk.scheduling.Task] +the [`Scheduler`][amltk.scheduling.Scheduler] +is through use of a **callback**. -::: amltk.scheduling.events - options: - members: False +The reason for this is to enable an easier time for API's to utilize +multiprocessing and remote compute from the `Scheduler`, without having +to burden users with knowing the details of how to use multiprocessing. + +A callback subscribes to some event using a decorator but can also be done in +a functional style if preferred. The below example is based on the +event [`@scheduler.on_start`][amltk.scheduling.Scheduler.on_start] but +the same applies to all events. + +=== "Decorators" + + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_processes(1) + + @scheduler.on_start + def print_hello() -> None: + print("hello") + + scheduler.run() + from amltk._doc import doc_print; doc_print(print, scheduler, fontsize="small") # markdown-exec: hide + ``` + +=== "Functional" + + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_processes(1) + + def print_hello() -> None: + print("hello") + + scheduler.on_start(print_hello) + scheduler.run() + from amltk._doc import doc_print; doc_print(print, scheduler, fontsize="small") # markdown-exec: hide + ``` + +There are a number of ways to customize the behaviour of these callbacks, notably +to control how often they get called and when they get called. + +??? tip "Callback customization" + + + === "`on('event', repeat=...)`" + + This will cause the callback to be called `repeat` times successively. + This is most useful in combination with + [`@scheduler.on_start`][amltk.scheduling.Scheduler.on_start] to launch + a number of tasks at the start of the scheduler. + + ```python exec="true" source="material-block" html="true" hl_lines="11" + from amltk import Scheduler + + N_WORKERS = 2 + + def f(x: int) -> int: + return x * 2 + from amltk._doc import make_picklable; make_picklable(f) # markdown-exec: hide + + scheduler = Scheduler.with_processes(N_WORKERS) + task = scheduler.task(f) + + @scheduler.on_start(repeat=N_WORKERS) + def on_start(): + task.submit(1) + + scheduler.run() + from amltk._doc import doc_print; doc_print(print, scheduler, fontsize="small") # markdown-exec: hide + ``` + + === "`on('event', max_calls=...)`" + + Limit the number of times a callback can be called, after which, the callback + will be ignored. 
+ + ```python exec="true" source="material-block" html="True" hl_lines="13" + from asyncio import Future + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_processes(2) + + def expensive_function(x: int) -> int: + return x ** 2 + from amltk._doc import make_picklable; make_picklable(expensive_function) # markdown-exec: hide + + @scheduler.on_start + def submit_calculations() -> None: + scheduler.submit(expensive_function, 2) + + @scheduler.on_future_result(max_calls=3) + def print_result(future, result) -> None: + scheduler.submit(expensive_function, 2) + + scheduler.run() + from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide + ``` + + === "`on('event', when=...)`" + + A callable which takes no arguments and returns a `bool`. The callback + will only be called when the `when` callable returns `True`. + + Below is a rather contrived example, but it shows how we can use the + `when` parameter to control when the callback is called. + + ```python exec="true" source="material-block" html="True" hl_lines="8 12" + import random + from amltk.scheduling import Scheduler + + LOCALE = random.choice(["English", "German"]) + + scheduler = Scheduler.with_processes(1) + + @scheduler.on_start(when=lambda: LOCALE == "English") + def print_hello() -> None: + print("hello") + + @scheduler.on_start(when=lambda: LOCALE == "German") + def print_guten_tag() -> None: + print("guten tag") + + scheduler.run() + from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide + ``` + + === "`on('event', every=...)`" + + Only call the callback every `every` times the event is emitted. This + includes the first time it's called. + + ```python exec="true" source="material-block" html="True" hl_lines="6" + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_processes(1) + + # Print "hello" only every 2 times the scheduler starts. + @scheduler.on_start(every=2) + def print_hello() -> None: + print("hello") + + # Run the scheduler 5 times + scheduler.run() + scheduler.run() + scheduler.run() + scheduler.run() + scheduler.run() + from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide + ``` + +### Emitter, Subscribers and Events +This part of the documentation is not necessary to understand or use for AMLTK. People +wishing to build tools upon AMLTK may still find this a useful component to add to their +arsenal. + +The core of making this functionality work is the [`Emitter`][amltk.scheduling.events.Emitter]. +Its purpose is to have `@events` that can be emitted and subscribed to. Classes like the +[`Scheduler`][amltk.scheduling.Scheduler] and [`Task`][amltk.scheduling.Task] carry +around with them an `Emitter` to enable all of this functionality. + +Creating an `Emitter` is rather straight-forward, but we must also create +[`Events`][amltk.scheduling.events.Event] that people can subscribe to. + +```python +from amltk.scheduling import Emitter, Event +emitter = Emitter("my-emitter") + +event: Event[int] = Event("my-event") # (1)! + +@emitter.on(event) +def my_callback(x: int) -> None: + print(f"Got {x}!") + +emitter.emit(event, 42) # (2)! +``` + +1. The typing `#!python Event[int]` is used to indicate that the event will be emitting + an integer. This is not necessary, but it is useful for type-checking and + documentation. +2. The `#!python emitter.emit(event, 42)` is used to emit the event. 
This will call + all the callbacks registered for the event, i.e. `#!python my_callback()`. + +!!! warning "Independent Events" + + Given a single `Emitter` and a single instance of an `Event`, there is no way to + have different `@events` for callbacks. There are two options, both used extensively + in AMLTK. + + The first is to have different `Events` quite naturally, i.e. you distinguish + between different things that can happen. However, you often want to have different + objects emit the same `Event` but have different callbacks for each object. + + This makes most sense in the context of a `Task` the `Event` instances are shared as + class variables in the `Task` class, however a user likely want's to subscribe to + the `Event` for a specific instance of the `Task`. + + This is where the second option comes in, in which each object carries around its + own `Emitter` instance. This is how a user can subscribe to the same kind of `Event` + but individually for each `Task`. + + +However, to shield users from this and to create named access points for users to +subscribe to, we can use the [`Subscriber`][amltk.scheduling.events.Subscriber] class, +conveniently created by the [`Emitter.subscriber()`][amltk.scheduling.events.Emitter.subscriber] +method. + +```python +from amltk.scheduling import Emitter, Event +emitter = Emitter("my-emitter") + +class GPT: + + event: Event[str] = Event("my-event") + + def __init__(self) -> None: + self.on_answer: Subscriber[str] = emitter.subscriber(self.event) + + def ask(self, question: str) -> None: + emitter.emit(self.event, "hello world!") + +gpt = GPT() + +@gpt.on_answer +def print_answer(answer: str) -> None: + print(answer) + +gpt.ask("What is the conical way for an AI to greet someone?") +``` + +Typically these event based systems make little sense in a synchronous context, however +with the [`Scheduler`][amltk.scheduling.Scheduler] and [`Task`][amltk.scheduling.Task] +classes, they are used to enable a simple way to use multiprocessing and remote compute. diff --git a/docs/reference/scheduling/executors.md b/docs/reference/scheduling/executors.md index 5044903e..6e6d1e83 100644 --- a/docs/reference/scheduling/executors.md +++ b/docs/reference/scheduling/executors.md @@ -83,6 +83,9 @@ provide a robust and flexible framework for scheduling compute across workers. client = Client(...) executor = client.get_executor() scheduler = Scheduler(executor=executor) + + # Important to do if the program will continue! + client.close() ``` ### :simple-dask: `dask-jobqueue` diff --git a/docs/reference/scheduling/queue_monitor.md b/docs/reference/scheduling/queue_monitor.md index 42b72a37..92be0fb7 100644 --- a/docs/reference/scheduling/queue_monitor.md +++ b/docs/reference/scheduling/queue_monitor.md @@ -1,5 +1,64 @@ ## Queue Monitor +A [`QueueMonitor`][amltk.scheduling.queue_monitor.QueueMonitor] is a +monitor for the scheduler queue. -::: amltk.scheduling.queue_monitor - options: - members: False +This module contains a monitor for the scheduler queue. The monitor tracks the +queue state at every event emitted by the scheduler. The data can be converted +to a pandas DataFrame or plotted as a stacked barchart. + +!!! note "Monitoring Frequency" + + To prevent repeated polling, we sample the scheduler queue at every scheduler event. + This is because the queue is only modified upon one of these events. This means we + don't need to poll the queue at a fixed interval. 
However, if you need more fine + grained updates, you can add extra events/timings at which the monitor should + [`update()`][amltk.scheduling.queue_monitor.QueueMonitor.update]. + +!!! warning "Performance impact" + + If your tasks and callbacks are very fast (~sub 10ms), then the monitor has a + non-nelgible impact however for most use cases, this should not be a problem. + As anything, you should profile how much work the scheduler can get done, + with and without the monitor, to see if it is a problem for your use case. + +In the below example, we have a very fast running function that runs on repeat, +sometimes too fast for the scheduler to keep up, letting some futures buildup needing +to be processed. + +```python exec="true" source="material-block" result="python" session="queue-monitor" +import time +import matplotlib.pyplot as plt +from amltk.scheduling import Scheduler +from amltk.scheduling.queue_monitor import QueueMonitor + +def fast_function(x: int) -> int: + return x + 1 +from amltk._doc import make_picklable; make_picklable(fast_function) # markdown-exec: hide + +N_WORKERS = 2 +scheduler = Scheduler.with_processes(N_WORKERS) +monitor = QueueMonitor(scheduler) +task = scheduler.task(fast_function) + +@scheduler.on_start(repeat=N_WORKERS) +def start(): + task.submit(1) + +@task.on_result +def result(_, x: int): + if scheduler.running(): + task.submit(x) + +scheduler.run(timeout=1) +df = monitor.df() +print(df) +``` + +We can also [`plot()`][amltk.scheduling.queue_monitor.QueueMonitor.plot] the data as a +stacked barchart with a set interval. + +```python exec="true" source="material-block" html="true" session="queue-monitor" +fig, ax = plt.subplots() +monitor.plot(interval=(50, "ms")) +from io import StringIO; fig.tight_layout(); buffer = StringIO(); plt.savefig(buffer, format="svg"); print(buffer.getvalue()) # markdown-exec: hide +``` diff --git a/docs/reference/scheduling/scheduler.md b/docs/reference/scheduling/scheduler.md index d5a6a2f9..498c11aa 100644 --- a/docs/reference/scheduling/scheduler.md +++ b/docs/reference/scheduling/scheduler.md @@ -1,5 +1,284 @@ ## Scheduler +The [`Scheduler`][amltk.scheduling.Scheduler] uses +an [`Executor`][concurrent.futures.Executor], a builtin python native with +a `#!python submit(f, *args, **kwargs)` function to submit compute to +be compute else where, whether it be locally or remotely. -::: amltk.scheduling.scheduler - options: - members: False +The `Scheduler` is primarily used to dispatch compute to an `Executor` and +emit `@events`, which can trigger user callbacks. + +Typically you should not use the `Scheduler` directly for dispatching and +responding to computed functions, but rather use a [`Task`][amltk.scheduling.Task] + +??? note "Running in a Jupyter Notebook/Colab" + + If you are using a Jupyter Notebook, you likley need to use the following + at the top of your notebook: + + ```python + import nest_asyncio # Only necessary in Notebooks + nest_asyncio.apply() + + scheduler.run(...) + ``` + + This is due to the fact a notebook runs in an async context. If you do not + wish to use the above snippet, you can instead use: + + ```python + await scheduler.async_run(...) + ``` + +??? tip "Basic Usage" + + In this example, we create a scheduler that uses local processes as + workers. We then create a task that will run a function `fn` and submit it + to the scheduler. Lastly, a callback is registered to `@future-result` to print the + result when the compute is done. 
+ + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Scheduler + + def fn(x: int) -> int: + return x + 1 + from amltk._doc import make_picklable; make_picklable(fn) # markdown-exec: hide + + scheduler = Scheduler.with_processes(1) + + @scheduler.on_start + def launch_the_compute(): + scheduler.submit(fn, 1) + + @scheduler.on_future_result + def callback(future, result): + print(f"Result: {result}") + + scheduler.run() + from amltk._doc import doc_print; doc_print(print, scheduler) # markdown-exec: hide + ``` + + The last line in the previous example called + [`scheduler.run()`][amltk.scheduling.Scheduler.run] is what starts the scheduler + running, in which it will first emit the `@start` event. This triggered the + callback `launch_the_compute()` which submitted the function `fn` with the + arguments `#!python 1`. + + The scheduler then ran the compute and waited for it to complete, emitting the + `@future-result` event when it was done successfully. This triggered the callback + `callback()` which printed the result. + + At this point, there is no more compute happening and no more events to respond to + so the scheduler will halt. + +??? example "`@events`" + + === "Scheduler Status Events" + + When the scheduler enters some important state, it will emit an event + to let you know. + + === "`@start`" + + ::: amltk.scheduling.Scheduler.on_start + options: + show_root_heading: False + show_root_toc_entry: False + + === "`@finishing`" + + ::: amltk.scheduling.Scheduler.on_finishing + options: + show_root_heading: False + show_root_toc_entry: False + + === "`@finished`" + + ::: amltk.scheduling.Scheduler.on_finished + options: + show_root_heading: False + show_root_toc_entry: False + + === "`@stop`" + + ::: amltk.scheduling.Scheduler.on_stop + options: + show_root_heading: False + show_root_toc_entry: False + + === "`@timeout`" + + ::: amltk.scheduling.Scheduler.on_timeout + options: + show_root_heading: False + show_root_toc_entry: False + + === "`@empty`" + + ::: amltk.scheduling.Scheduler.on_empty + options: + show_root_heading: False + show_root_toc_entry: False + + === "Submitted Compute Events" + + When any compute goes through the `Scheduler`, it will emit an event + to let you know. You should however prefer to use a + [`Task`][amltk.scheduling.Task] as it will emit specific events + for the task at hand, and not all compute. + + === "`@future-submitted`" + + ::: amltk.scheduling.Scheduler.on_future_submitted + options: + show_root_heading: False + show_root_toc_entry: False + + === "`@future-result`" + + ::: amltk.scheduling.Scheduler.on_future_result + options: + show_root_heading: False + show_root_toc_entry: False + + === "`@future-exception`" + + ::: amltk.scheduling.Scheduler.on_future_exception + options: + show_root_heading: False + show_root_toc_entry: False + + === "`@future-done`" + + ::: amltk.scheduling.Scheduler.on_future_done + options: + show_root_heading: False + show_root_toc_entry: False + + === "`@future-cancelled`" + + ::: amltk.scheduling.Scheduler.on_future_cancelled + options: + show_root_heading: False + show_root_toc_entry: False + + +??? tip "Common usages of `run()`" + + There are various ways to [`run()`][amltk.scheduling.Scheduler.run] the + scheduler, notably how long it should run with `timeout=` and also how + it should react to any exception that may have occurred within the `Scheduler` + itself or your callbacks. 
+ + Please see the [`run()`][amltk.scheduling.Scheduler.run] API doc for more + details and features, however we show two common use cases of using the `timeout=` + parameter. + + You can render a live display using [`run(display=...)`][amltk.scheduling.Scheduler.run]. + This require [`rich`](https://github.com/Textualize/rich) to be installed. You + can install this with `#!bash pip install rich` or `#!bash pip install amltk[rich]`. + + + === "`run(timeout=...)`" + + You can tell the `Scheduler` to stop after a certain amount of time + with the `timeout=` argument to [`run()`][amltk.scheduling.Scheduler.run]. + + This will also trigger the `@timeout` event as seen in the `Scheduler` output. + + ```python exec="true" source="material-block" html="True" hl_lines="19" + import time + from asyncio import Future + + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_processes(1) + + def expensive_function() -> int: + time.sleep(0.1) + return 42 + from amltk._doc import make_picklable; make_picklable(expensive_function) # markdown-exec: hide + + @scheduler.on_start + def submit_calculations() -> None: + scheduler.submit(expensive_function) + + # This will endlessly loop the scheduler + @scheduler.on_future_done + def submit_again(future: Future) -> None: + if scheduler.running(): + scheduler.submit(expensive_function) + + scheduler.run(timeout=1) # End after 1 second + from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide + ``` + + === "`run(timeout=..., wait=False)`" + + By specifying that the `Scheduler` should not wait for ongoing tasks + to finish, the `Scheduler` will attempt to cancel and possibly terminate + any running tasks. + + ```python exec="true" source="material-block" html="True" + import time + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_processes(1) + + def expensive_function() -> None: + time.sleep(10) + + from amltk._doc import make_picklable; make_picklable(expensive_function) # markdown-exec: hide + + @scheduler.on_start + def submit_calculations() -> None: + scheduler.submit(expensive_function) + + scheduler.run(timeout=1, wait=False) # End after 1 second + from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide + ``` + + ??? info "Forcibly Terminating Workers" + + As an `Executor` does not provide an interface to forcibly + terminate workers, we provide `Scheduler(terminate=...)` as a custom + strategy for cleaning up a provided executor. It is not possible + to terminate running thread based workers, for example using + `ThreadPoolExecutor` and any Executor using threads to spawn + tasks will have to wait until all running tasks are finish + before python can close. + + It's likely `terminate` will trigger the `EXCEPTION` event for + any tasks that are running during the shutdown, **not*** + a cancelled event. This is because we use a + [`Future`][concurrent.futures.Future] + under the hood and these can not be cancelled once running. + However there is no guarantee of this and is up to how the + `Executor` handles this. + +??? example "Scheduling something to be run later" + + You can schedule some function to be run later using the + [`#!python scheduler.call_later()`][amltk.scheduling.Scheduler.call_later] method. + + !!! 
note + + This does not run the function in the background, it just schedules some + function to be called later, where you could perhaps then use submit to + scheduler a [`Task`][amltk.scheduling.Task] to run the function in the + background. + + ```python exec="true" source="material-block" result="python" + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_processes(1) + + def fn() -> int: + print("Ending now!") + scheduler.stop() + + @scheduler.on_start + def schedule_fn() -> None: + scheduler.call_later(1, fn) + + scheduler.run(end_on_empty=False) + ``` diff --git a/docs/reference/scheduling/task.md b/docs/reference/scheduling/task.md index f7fdd111..51b0ccbe 100644 --- a/docs/reference/scheduling/task.md +++ b/docs/reference/scheduling/task.md @@ -1,5 +1,88 @@ ## Tasks +A [`Task`][amltk.scheduling.task.Task] is a unit of work that can be scheduled by the +[`Scheduler`][amltk.scheduling.Scheduler]. -::: amltk.scheduling.task - options: - members: False +It is defined by its `function=` to call. Whenever a `Task` +has its [`submit()`][amltk.scheduling.task.Task.submit] method called, +the function will be dispatched to run by a `Scheduler`. + +When a task has returned, either successfully, or with an exception, +it will emit `@events` to indicate so. You can subscribe to these events +with callbacks and act accordingly. + + +??? example "`@events`" + + Check out the `@events` reference + for more on how to customize these callbacks. You can also take a look + at the API of [`on()`][amltk.scheduling.task.Task.on] for more information. + + === "`@on-result`" + + ::: amltk.scheduling.task.Task.on_result + options: + show_root_heading: False + show_root_toc_entry: False + + === "`@on-exception`" + + ::: amltk.scheduling.task.Task.on_exception + options: + show_root_heading: False + show_root_toc_entry: False + + === "`@on-done`" + + ::: amltk.scheduling.task.Task.on_done + options: + show_root_heading: False + show_root_toc_entry: False + + === "`@on-submitted`" + + ::: amltk.scheduling.task.Task.on_submitted + options: + show_root_heading: False + show_root_toc_entry: False + + === "`@on-cancelled`" + + ::: amltk.scheduling.task.Task.on_cancelled + options: + show_root_heading: False + show_root_toc_entry: False + +??? tip "Usage" + + The usual way to create a task is with + [`Scheduler.task()`][amltk.scheduling.scheduler.Scheduler.task], + where you provide the `function=` to call. + + ```python exec="true" source="material-block" html="true" + from amltk import Scheduler + from asyncio import Future + + def f(x: int) -> int: + return x * 2 + from amltk._doc import make_picklable; make_picklable(f) # markdown-exec: hide + + scheduler = Scheduler.with_processes(2) + task = scheduler.task(f) + + @scheduler.on_start + def on_start(): + task.submit(1) + + @task.on_result + def on_result(future: Future[int], result: int): + print(f"Task {future} returned {result}") + + scheduler.run() + from amltk._doc import doc_print; doc_print(print, scheduler) # markdown-exec: hide + ``` + + If you'd like to simply just call the original function, without submitting it to + the scheduler, you can always just call the task directly, i.e. `#!python task(1)`. + +You can also provide [`Plugins`][amltk.scheduling.plugins.Plugin] to the task, +to modify tasks, add functionality and add new events. 
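+
+As one illustration of what attaching a plugin can look like, the sketch below assumes a
+`Limiter` plugin is available under `amltk.scheduling.plugins` in your installed version;
+please check the plugin reference for the exact names and options.
+
+```python
+# A sketch of attaching a plugin to a task; the `Limiter` plugin and its
+# `max_calls=` argument are assumptions about the available plugin API.
+from amltk.scheduling import Scheduler
+from amltk.scheduling.plugins import Limiter
+
+def f(x: int) -> int:
+    return x * 2
+
+scheduler = Scheduler.with_processes(1)
+task = scheduler.task(f, plugins=Limiter(max_calls=2))  # refuse submissions after 2 calls
+
+@scheduler.on_start
+def on_start() -> None:
+    task.submit(1)
+
+scheduler.run()
+```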
diff --git a/docs/stylesheets/custom.css b/docs/stylesheets/custom.css index 7da3de53..a7782f17 100644 --- a/docs/stylesheets/custom.css +++ b/docs/stylesheets/custom.css @@ -1,110 +1,25 @@ -[data-md-color-scheme="default"] { - --doc-label-instance-attribute-fg-color: #0079ff; - --doc-label-property-fg-color: #00dfa2; - --doc-label-class-attribute-fg-color: #d1b619; - --doc-label-dataclass-fg-color: #ff0060; - - --doc-label-instance-attribute-bg-color: #0079ff1a; - --doc-label-property-bg-color: #00dfa21a; - --doc-label-class-attribute-bg-color: #d1b6191a; - --doc-label-dataclass-bg-color: #ff00601a; -} - -[data-md-color-scheme="slate"] { - --doc-label-instance-attribute-fg-color: #963fb8; - --doc-label-property-fg-color: #6d67e4; - --doc-label-class-attribute-fg-color: #46c2cb; - --doc-label-dataclass-fg-color: #f2f7a1; - - --doc-label-instance-attribute-bg-color: #963fb81a; - --doc-label-property-bg-color: #6d67e41a; - --doc-label-class-attribute-bg-color: #46c2cb1a; - --doc-label-dataclass-bg-color: #f2f7a11a; -} -:root { - --md-tooltip-width: 500px; -} - -.doc.doc-label.doc-label-instance-attribute code { - background-color: var(--doc-label-instance-attribute-bg-color); - color: var(--doc-label-instance-attribute-fg-color); -} -.doc.doc-label.doc-label-class-attribute code { - background-color: var(--doc-label-class-attribute-bg-color); - color: var(--doc-label-class-attribute-fg-color); -} -.doc.doc-label.doc-label-classmethod code { - background-color: var(--doc-label-class-attribute-bg-color); - color: var(--doc-label-classattribute-fg-color); -} -.doc.doc-label.doc-label-property code { - background-color: var(--doc-label-property-bg-color); - color: var(--doc-label-property-fg-color); -} -.doc.doc-label.doc-label-dataclass code { - background-color: var(--doc-label-dataclass-bg-color); - color: var(--doc-label-dataclass-fg-color); -} -.doc.doc-label code { - font-weight: bold; -} -.doc.doc-labels { - margin-left: 10px; - border-radius: 10px; - padding-top: 1px; - padding-bottom: 1px; - padding-right: 10px; - padding-left: 10px; -} - - -body[data-md-color-primary="black"] .excalidraw svg { - filter: invert(100%) hue-rotate(180deg); -} - -body[data-md-color-primary="black"] .excalidraw svg rect { - fill: transparent; -} - -h2.doc.doc-heading { - border-top: 2px solid rgba(0, 0, 0); - padding-top: 48px; -} - +/* If anything is breaking from this css, please feel free to +* remove it. 
+*/ -/* Highlight None inside code blocks */ +/* Highlight None with color inside code blocks */ code.highlight.language-python span.kc { color: var(--md-code-hl-keyword-color); } - - -h3.doc.doc-heading { - border-top: 1px solid rgba(0, 0, 0, 0.2); - padding-top: 48px; -} - -a.md-nav__link--passed { - padding-left: 10px; -} - -a.md-nav__link--active { - padding-left: 10px; - font-weight: bold; - border-left: .05rem solid var(--md-accent-fg-color); -} - -h3 .doc-heading code { - font-size: 16px; -} - -/* -.doc.doc-object.doc-class h2.doc.doc-heading { - text-align: center; +/* Make tool tip annotations wider */ +:root { + --md-tooltip-width: 500px; } -*/ - -.doc-heading code { - font-weight: normal; - font-family: "Roboto Mono", "SFMono-Regular", Consolas, "Courier New", Courier, - monospace; +/* api doc attribute cards */ +div.doc-class > div.doc-contents > div.doc-children > div.doc-object { + padding-right: 20px; + padding-left: 20px; + border-radius: 15px; + margin-top: 20px; + margin-bottom: 20px; + box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.2); + margin-right: 0px; + border-color: rgba(0, 0, 0, 0.2); + border-width: 1px; + border-style: solid; } diff --git a/examples/dask-jobqueue.py b/examples/dask-jobqueue.py index 84bd0dc1..5953c959 100644 --- a/examples/dask-jobqueue.py +++ b/examples/dask-jobqueue.py @@ -17,6 +17,7 @@ efficient for this workload with ~32 cores. """ +import traceback from typing import Any import openml @@ -101,12 +102,13 @@ def target_function(trial: Trial, _pipeline: Node) -> Trial.Report: sklearn_pipeline = _pipeline.configure(trial.config).build("sklearn") - with trial.begin(): - sklearn_pipeline.fit(X_train, y_train) - - if trial.exception: - trial.store({"exception.txt": f"{trial.exception}\n {trial.traceback}"}) - return trial.fail() + try: + with trial.profile("fit"): + sklearn_pipeline.fit(X_train, y_train) + except Exception as e: + tb = traceback.format_exc() + trial.store({"exception.txt": f"{e}\n {tb}"}) + return trial.fail(e, tb) with trial.profile("predictions"): train_predictions = sklearn_pipeline.predict(X_train) diff --git a/examples/hpo.py b/examples/hpo.py index aca758ee..c22788c7 100644 --- a/examples/hpo.py +++ b/examples/hpo.py @@ -72,8 +72,14 @@ def get_dataset( Sequential(name="Pipeline") >> Split( { - "categorical": [SimpleImputer(strategy="constant", fill_value="missing"), OneHotEncoder(drop="first")], - "numerical": Component(SimpleImputer, space={"strategy": ["mean", "median"]}), + "categorical": [ + SimpleImputer(strategy="constant", fill_value="missing"), + OneHotEncoder(drop="first"), + ], + "numerical": Component( + SimpleImputer, + space={"strategy": ["mean", "median"]}, + ), }, name="feature_preprocessing", ) @@ -102,34 +108,35 @@ def get_dataset( from sklearn.metrics import accuracy_score from amltk.optimization import Trial +from amltk.store import PathBucket -def target_function(trial: Trial, _pipeline: Node) -> Trial.Report: +def target_function( + trial: Trial, + _pipeline: Node, + data_bucket: PathBucket, +) -> Trial.Report: trial.store({"config.json": trial.config}) # Load in data with trial.profile("data-loading"): X_train, X_val, X_test, y_train, y_val, y_test = ( - trial.bucket["X_train.csv"].load(), - trial.bucket["X_val.csv"].load(), - trial.bucket["X_test.csv"].load(), - trial.bucket["y_train.npy"].load(), - trial.bucket["y_val.npy"].load(), - trial.bucket["y_test.npy"].load(), + data_bucket["X_train.csv"].load(), + data_bucket["X_val.csv"].load(), + data_bucket["X_test.csv"].load(), + 
data_bucket["y_train.npy"].load(), + data_bucket["y_val.npy"].load(), + data_bucket["y_test.npy"].load(), ) # Configure the pipeline with the trial config before building it. sklearn_pipeline = _pipeline.configure(trial.config).build("sklearn") - # Fit the pipeline, indicating when you want to start the trial timing and error - # catchnig. - with trial.begin(): - sklearn_pipeline.fit(X_train, y_train) - - # If an exception happened, we use `trial.fail` to indicate that the - # trial failed - if trial.exception: - trial.store({"exception.txt": f"{trial.exception}\n {trial.traceback}"}) - return trial.fail() + # Fit the pipeline, indicating when you want to start the trial timing + try: + with trial.profile("fit"): + sklearn_pipeline.fit(X_train, y_train) + except Exception as e: + return trial.fail(e) # Make our predictions with the model with trial.profile("predictions"): @@ -187,7 +194,8 @@ def target_function(trial: Trial, _pipeline: Node) -> Trial.Report: X_test, y_test = data["test"] bucket = PathBucket("example-hpo", clean=True, create=True) -bucket.store( +data_bucket = bucket / "data" +data_bucket.store( { "X_train.csv": X_train, "X_val.csv": X_val, @@ -200,6 +208,7 @@ def target_function(trial: Trial, _pipeline: Node) -> Trial.Report: print(bucket) print(dict(bucket)) +print(dict(data_bucket)) """ ### Setting up the Scheduler, Task and Optimizer We use the [`Scheduler.with_processes`][amltk.scheduling.Scheduler.with_processes] @@ -255,14 +264,14 @@ def target_function(trial: Trial, _pipeline: Node) -> Trial.Report: def launch_initial_tasks() -> None: """When we start, launch `n_workers` tasks.""" trial = optimizer.ask() - task.submit(trial, _pipeline=pipeline) + task.submit(trial, _pipeline=pipeline, data_bucket=data_bucket) """ When a [`Task`][amltk.Trial] returns and we get a report, i.e. with [`task.success()`][amltk.optimization.Trial.success] or [`task.fail()`][amltk.optimization.Trial.fail], the `task` will fire off the -callbacks registered with [`@on_result`][amltk.Task.on_result]. +callbacks registered with [`@result`][amltk.Task.on_result]. We can use these to add callbacks that get called when these events happen. Here we use it to update the optimizer with the report we got. @@ -300,12 +309,13 @@ def add_to_history(_, report: Trial.Report) -> None: a report. This will launch a new task as soon as one finishes. """ + @task.on_result def launch_another_task(*_: Any) -> None: """When we get a report, evaluate another trial.""" if scheduler.running(): trial = optimizer.ask() - task.submit(trial, _pipeline=pipeline) + task.submit(trial, _pipeline=pipeline, data_bucket=data_bucket) """ diff --git a/examples/hpo_with_ensembling.py b/examples/hpo_with_ensembling.py index c1f74518..aac8674d 100644 --- a/examples/hpo_with_ensembling.py +++ b/examples/hpo_with_ensembling.py @@ -37,6 +37,7 @@ from __future__ import annotations import shutil +import traceback from asyncio import Future from collections.abc import Callable from dataclasses import dataclass @@ -214,18 +215,18 @@ def target_function( pipeline = pipeline.configure(trial.config) # (2)! sklearn_pipeline = pipeline.build("sklearn") # - with trial.begin(): # (3)! - sklearn_pipeline.fit(X_train, y_train) - - if trial.exception: + try: + with trial.profile("fit"): # (3)! 
+ sklearn_pipeline.fit(X_train, y_train) + except Exception as e: + tb = traceback.format_exc() trial.store( { - "exception.txt": str(trial.exception), + "exception.txt": str(e), "config.json": dict(trial.config), - "traceback.txt": str(trial.traceback), + "traceback.txt": str(tb), }, ) - return trial.fail() # (4)! # Make our predictions with the model @@ -301,8 +302,7 @@ def create_ensemble( return Ensemble({}, [], {}) validation_predictions = { - report.name: report.retrieve("val_probabilities.npy", where=bucket) - for report in history + report.name: report.retrieve("val_probabilities.npy") for report in history } targets = bucket["y_val.npy"].load() @@ -321,10 +321,7 @@ def _score(_targets: np.ndarray, ensembled_probabilities: np.ndarray) -> float: seed=seed, # ) # - configs = { - name: history.find(name).retrieve("config.json", where=bucket) - for name in weights - } + configs = {name: history[name].retrieve("config.json") for name in weights} return Ensemble(weights=weights, trajectory=trajectory, configs=configs) @@ -360,14 +357,14 @@ def _score(_targets: np.ndarray, ensembled_probabilities: np.ndarray) -> float: seed=seed, ) # (4)! -task = scheduler.task(target_function) # (6)! -ensemble_task = scheduler.task(create_ensemble) # (7)! +task = scheduler.task(target_function) # (5)! +ensemble_task = scheduler.task(create_ensemble) # (6)! trial_history = History() ensembles: list[Ensemble] = [] -@scheduler.on_start # (8)! +@scheduler.on_start # (7)! def launch_initial_tasks() -> None: """When we start, launch `n_workers` tasks.""" trial = optimizer.ask() @@ -427,7 +424,7 @@ def run_last_ensemble_task() -> None: if __name__ == "__main__": - scheduler.run(timeout=5, wait=True) # (9)! + scheduler.run(timeout=5, wait=True) # (8)! print("Trial history:") history_df = trial_history.df() @@ -440,22 +437,21 @@ def run_last_ensemble_task() -> None: # 1. We use `#!python get_dataset()` defined earlier to load the # dataset. # 2. We use [`store()`][amltk.store.Bucket.store] to store the data in the bucket, with -# each key being the name of the file and the value being the data. +# each key being the name of the file and the value being the data. # 3. We use [`Scheduler.with_processes()`][amltk.scheduling.Scheduler.with_processes] # create a [`Scheduler`][amltk.scheduling.Scheduler] that runs everything # in a different process. You can of course use a different backend if you want. # 4. We use [`SMACOptimizer.create()`][amltk.optimization.optimizers.smac.SMACOptimizer.create] to create a # [`SMACOptimizer`][amltk.optimization.optimizers.smac.SMACOptimizer] given the space from the pipeline # to optimize over. -# 6. We create a [`Task`][amltk.scheduling.Task] that will run our objective, passing -# in the function to run and the scheduler for where to run it -# 7. We use [`task()`][amltk.scheduling.Task] to create a -# [`Task`][amltk.scheduling.Task] -# for the `create_ensemble` method above. This will also run in parallel with the hpo -# trials if using a non-sequential scheduling mode. -# 8. We use `@scheduler.on_start()` hook to register a +# 5. We create a [`Task`][amltk.scheduling.Task] that will run our objective, passing +# in the function to run and the scheduler for where to run it +# 6. We use [`task()`][amltk.scheduling.Task] to create a +# [`Task`][amltk.scheduling.Task] for the `create_ensemble` method above. +# This will also run in parallel with the hpo trials if using a non-sequential scheduling mode. +# 7. 
We use `@scheduler.on_start()` hook to register a # callback that will be called when the scheduler starts. We can use the # `repeat` argument to make sure it's called many times if we want. -# 9. We use [`Scheduler.run()`][amltk.scheduling.Scheduler.run] to run the scheduler. +# 8. We use [`Scheduler.run()`][amltk.scheduling.Scheduler.run] to run the scheduler. # Here we set it to run briefly for 5 seconds and wait for remaining tasks to finish # before continuing. diff --git a/examples/pytorch-example.py b/examples/pytorch-example.py new file mode 100644 index 00000000..abf535e6 --- /dev/null +++ b/examples/pytorch-example.py @@ -0,0 +1,263 @@ +"""Script for building and evaluating an example of PyTorch MLP model on MNIST dataset. +The script defines functions for constructing a neural network model from a pipeline, +training the model, and evaluating its performance. + +References: +- PyTorch MNIST example: https://github.com/pytorch/examples/blob/main/mnist/main.py +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +import torch.nn.functional as f +from torch import nn, optim +from torch.optim.lr_scheduler import StepLR +from torchvision import datasets, transforms + +from amltk import Choice, Component, Metric, Sequential +from amltk.optimization.optimizers.smac import SMACOptimizer +from amltk.pytorch import ( + MatchChosenDimensions, + MatchDimensions, + build_model_from_pipeline, +) + +if TYPE_CHECKING: + from amltk import Node, Trial + +from rich import print + + +def test( + model: nn.Module, + device: torch.device, + test_loader: torch.utils.data.DataLoader, +) -> tuple[float, float]: + """Evaluate the performance of the model on the test dataset. + + Args: + model (nn.Module): The model to be evaluated. + device (torch.device): The device to use for evaluation. + test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset. + + Returns: + tuple[float, float]: Test loss and accuracy. + """ + model.eval() + test_loss = 0.0 + correct = 0.0 + with torch.no_grad(): + for _test_data, _test_target in test_loader: + test_data, test_target = _test_data.to(device), _test_target.to(device) + output = model(test_data) + test_loss += f.nll_loss(output, test_target, reduction="sum").item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(test_target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + accuracy = 100.0 * correct / len(test_loader.dataset) + return float(test_loss), float(accuracy) + + +def eval_configuration( + trial: Trial, + pipeline: Node, + device: str = "cpu", # Change if you have a GPU + epochs: int = 1, # Fixed for now + lr: float = 0.1, # Fixed for now + gamma: float = 0.7, # Fixed for now + batch_size: int = 64, # Fixed for now + log_interval: int = 10, # Fixed for now +) -> Trial.Report: + """Evaluates a configuration within the given trial. + + This function trains a model based on the provided pipeline and hyperparameters, + evaluates its performance, and returns a report containing the evaluation results. + + Args: + trial: The trial object for storing trial-specific information. + pipeline: The pipeline defining the model architecture. + device: The device to use for training and evaluation (default is "cpu"). + epochs: The number of training epochs (default is 1). + lr: The learning rate for the optimizer (default is 0.1). + gamma: The gamma value for the learning rate scheduler (default is 0.7). + batch_size: The batch size for training and evaluation (default is 64). 
+ log_interval: The interval for logging training progress (default is 10). + + Returns: + Trial.Report: A report containing the evaluation results. + """ + trial.store({"config.json": pipeline.config}) + torch.manual_seed(trial.seed) + + train_loader = torch.utils.data.DataLoader( + datasets.MNIST( + "../data", + train=True, + download=False, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))], + ), + ), + batch_size=batch_size, + shuffle=True, + ) + test_loader = torch.utils.data.DataLoader( + datasets.MNIST( + "../data", + train=False, + download=False, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))], + ), + ), + batch_size=batch_size, + shuffle=True, + ) + + _device = torch.device(device) + print("Using device", _device) + + model = ( + pipeline.configure(trial.config) + .build(builder=build_model_from_pipeline) + .to(_device) + ) + + with trial.profile("training"): + optimizer = optim.Adadelta(model.parameters(), lr=lr) + lr_scheduler = StepLR(optimizer, step_size=1, gamma=gamma) + + for epoch in range(epochs): + for batch_idx, (_data, _target) in enumerate(train_loader): + optimizer.zero_grad() + data, target = _data.to(_device), _target.to(_device) + + output = model(data) + loss = f.nll_loss(output, target) + + loss.backward() + optimizer.step() + + if batch_idx % log_interval == 0: + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ), + ) + lr_scheduler.step() + + final_train_loss, final_train_acc = test(model, _device, train_loader) + final_test_loss, final_test_acc = test(model, _device, test_loader) + trial.summary["final_test_loss"] = final_test_loss + trial.summary["final_test_accuracy"] = final_test_acc + trial.summary["final_train_loss"] = final_train_loss + trial.summary["final_train_accuracy"] = final_train_acc + + return trial.success(accuracy=final_test_acc) + + +def main() -> None: + """Main function to orchestrate the model training and evaluation process. + + This function sets up the training environment, defines the search space + for hyperparameter optimization, and iteratively evaluates different + configurations using the SMAC optimizer. 
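+
+    The single ask/tell round performed below can be extended to a budgeted
+    loop. A minimal illustrative sketch, assuming the same `optimizer`,
+    `pipeline` and `eval_configuration` created in this script:
+
+    ```python
+    for _ in range(10):  # hypothetical budget of 10 trials
+        trial = optimizer.ask()
+        report = eval_configuration(trial, pipeline, device="cpu")
+        optimizer.tell(report)
+    ```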
+ + Returns: + None + """ + # Training settings + _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.device(_device) + + # Download the dataset + datasets.MNIST("../data", train=True, download=False) + datasets.MNIST("../data", train=False, download=False) + + # Define the pipeline with search space for hyperparameter optimization + pipeline: Sequential = Sequential( + Choice( + Sequential( + nn.Flatten(start_dim=1), + Component( + nn.Linear, + config={"in_features": 784, "out_features": 100}, + name="choice1-fc1", + ), + name="choice1", + ), + Sequential( + Component( + nn.Conv2d, + config={ + "in_channels": 1, # MNIST images are grayscale + "out_channels": 32, # Number of output channels (filters) + "kernel_size": (3, 3), # Size of the convolutional kernel + "stride": (1, 1), # Stride of the convolution + "padding": (1, 1), # Padding to add to the input + }, + name="choice2", + ), + nn.ReLU(), + nn.MaxPool2d(kernel_size=(2, 2)), + nn.Flatten(start_dim=1), + name="choice2", + ), + name="layer1", + ), + Component( + nn.Linear, + config={ + "in_features": MatchChosenDimensions( + choice_name="layer1", + choices={"choice1": 100, "choice2": 32 * 14 * 14}, + ), + "out_features": MatchDimensions("fc2", param="in_features"), + }, + name="fc1", + ), + Choice(nn.ReLU(), nn.Sigmoid(), name="activation"), + Component( + nn.Linear, + space={"in_features": (10, 50), "out_features": (10, 30)}, + name="fc2", + ), + Component( + nn.Linear, + config={ + "in_features": MatchDimensions("fc2", param="out_features"), + "out_features": 10, + }, + name="fc3", + ), + Component(nn.LogSoftmax, config={"dim": 1}), + name="my-mlp-pipeline", + ) + + # Define the metric for optimization + metric: Metric = Metric("accuracy", minimize=False, bounds=(0, 1)) + + # Initialize the SMAC optimizer + optimizer = SMACOptimizer.create( + space=pipeline, + metrics=metric, + seed=1, + bucket="pytorch-experiments", + ) + + # Iteratively evaluate different configurations using the optimizer + trial = optimizer.ask() + report = eval_configuration(trial, pipeline, device=_device) + optimizer.tell(report) + print(report) + + +if __name__ == "__main__": + main() diff --git a/examples/sklearn-hpo-cv.py b/examples/sklearn-hpo-cv.py new file mode 100644 index 00000000..04d6d6c4 --- /dev/null +++ b/examples/sklearn-hpo-cv.py @@ -0,0 +1,277 @@ +"""Random Search with CVEvaluation. + +This example demonstrates the [`CVEvaluation`][amltk.sklearn.CVEvaluation] class, +which builds a custom cross-validation task that can be used to evaluate +[`pipelines`](../guides/pipelines.md) with cross-validation, using +[`RandomSearch`][amltk.optimization.optimizers.random_search.RandomSearch]. 
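+
+At a glance, the evaluator and optimizer in this example are wired together
+roughly as follows. This is a condensed sketch of what `main()` below does;
+see the full script for the actual arguments used:
+
+```python
+evaluator = CVEvaluation(X, y, splitter="cv", n_splits=8, working_dir=working_dir)
+history = rf_pipeline.optimize(
+    target=evaluator.fn,
+    metric=Metric("accuracy", minimize=False, bounds=(0, 1)),
+    optimizer=RandomSearch,
+    max_trials=10,
+    working_dir=working_dir,
+)
+print(history.df())
+```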
+""" +from collections.abc import Mapping +from pathlib import Path +from typing import Any + +import numpy as np +import openml +import pandas as pd +from ConfigSpace import Categorical, Integer +from sklearn.ensemble import RandomForestClassifier +from sklearn.impute import SimpleImputer +from sklearn.metrics import get_scorer +from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder + +from amltk.optimization.optimizers.random_search import RandomSearch +from amltk.optimization.trial import Metric, Trial +from amltk.pipeline import Choice, Component, Node, Sequential, Split, request +from amltk.sklearn import CVEvaluation + + +def get_fold( + openml_task_id: int, + fold: int, +) -> tuple[ + pd.DataFrame, + pd.DataFrame, + pd.DataFrame | pd.Series, + pd.DataFrame | pd.Series, +]: + """Get the data for a specific fold of an OpenML task. + + Args: + openml_task_id: The OpenML task id. + fold: The fold number. + n_splits: The number of splits that will be applied. This is used + to resample training data such that enough at least instance for each class is present for + every stratified split. + seed: The random seed to use for reproducibility of resampling if necessary. + """ + task = openml.tasks.get_task( + openml_task_id, + download_splits=True, + download_data=True, + download_qualities=True, + download_features_meta_data=True, + ) + train_idx, test_idx = task.get_train_test_split_indices(fold=fold) + X, y = task.get_X_and_y(dataset_format="dataframe") # type: ignore + X_train, y_train = X.iloc[train_idx], y.iloc[train_idx] + X_test, y_test = X.iloc[test_idx], y.iloc[test_idx] + return X_train, X_test, y_train, y_test + + +preprocessing = Split( + { + "numerical": Component(SimpleImputer, space={"strategy": ["mean", "median"]}), + "categorical": [ + Component( + OrdinalEncoder, + config={ + "categories": "auto", + "handle_unknown": "use_encoded_value", + "unknown_value": -1, + "encoded_missing_value": -2, + }, + ), + Choice( + "passthrough", + Component( + OneHotEncoder, + space={"max_categories": (2, 20)}, + config={ + "categories": "auto", + "drop": None, + "sparse_output": False, + "handle_unknown": "infrequent_if_exist", + }, + ), + name="one_hot", + ), + ], + }, + name="preprocessing", +) + + +def rf_config_transform(config: Mapping[str, Any], _: Any) -> dict[str, Any]: + new_config = dict(config) + if new_config["class_weight"] == "None": + new_config["class_weight"] = None + return new_config + + +# NOTE: This space should not be used for evaluating how good this RF is +# vs other algorithms +rf_classifier = Component( + item=RandomForestClassifier, + config_transform=rf_config_transform, + space={ + "criterion": ["gini", "entropy"], + "max_features": Categorical( + "max_features", + list(np.logspace(0.1, 1, base=10, num=10) / 10), + ordered=True, + ), + "min_samples_split": Integer("min_samples_split", bounds=(2, 20), default=2), + "min_samples_leaf": Integer("min_samples_leaf", bounds=(1, 20), default=1), + "bootstrap": Categorical("bootstrap", [True, False], default=True), + "class_weight": ["balanced", "balanced_subsample", "None"], + "min_impurity_decrease": (1e-9, 1e-1), + }, + config={ + "random_state": request( + "random_state", + default=None, + ), # Will be provided later by the `Trial` + "n_estimators": 512, + "max_depth": None, + "min_weight_fraction_leaf": 0.0, + "max_leaf_nodes": None, + "warm_start": False, # False due to no iterative fit used here + "n_jobs": 1, + }, +) + +rf_pipeline = Sequential(preprocessing, rf_classifier, name="rf_pipeline") + + +def 
do_something_after_a_split_was_evaluated( + trial: Trial, + fold: int, + info: CVEvaluation.PostSplitInfo, +) -> CVEvaluation.PostSplitInfo: + return info + + +def do_something_after_a_complete_trial_was_evaluated( + report: Trial.Report, + pipeline: Node, + info: CVEvaluation.CompleteEvalInfo, +) -> Trial.Report: + return report + + +def main() -> None: + random_seed = 42 + openml_task_id = 31 # Adult dataset, classification + task_hint = "classification" + outer_fold_number = ( + 0 # Only run the first outer fold, wrap this in a loop if needs be, with a unique history file + # for each one + ) + optimizer_cls = RandomSearch + working_dir = Path("example-sklearn-hpo-cv").absolute() + results_to = working_dir / "results.parquet" + inner_fold_seed = random_seed + outer_fold_number + metric_definition = Metric( + "accuracy", + minimize=False, + bounds=(0, 1), + fn=get_scorer("accuracy"), + ) + + per_process_memory_limit = None # (4, "GB") # NOTE: May have issues on Mac + per_process_walltime_limit = None # (60, "s") + + debugging = False + if debugging: + max_trials = 1 + max_time = 30 + n_workers = 1 + # raise an error with traceback, something went wrong + on_trial_exception = "raise" + display = True + wait_for_all_workers_to_finish = True + else: + max_trials = 10 + max_time = 300 + n_workers = 4 + # Just mark the trial as fail and move on to the next one + on_trial_exception = "continue" + display = True + wait_for_all_workers_to_finish = False + + X, X_test, y, y_test = get_fold( + openml_task_id=openml_task_id, + fold=outer_fold_number, + ) + + # This object below is a highly customizable class to create a function that we can use for + # evaluating pipelines. + evaluator = CVEvaluation( + # Provide data, number of times to split, cross-validation and a hint of the task type + X, + y, + splitter="cv", + n_splits=8, + task_hint=task_hint, + # Seeding for reproducibility + random_state=inner_fold_seed, + # Provide test data to get test scores + X_test=X_test, + y_test=y_test, + # Record training scores + train_score=True, + # Where to store things + working_dir=working_dir, + # What to do when something goes wrong. + on_error="raise" if on_trial_exception == "raise" else "fail", + # Whether you want models to be store on disk under working_dir + store_models=False, + # A callback to be called at the end of each split + post_split=do_something_after_a_split_was_evaluated, + # Some callback that is called at the end of all fold evaluations + post_processing=do_something_after_a_complete_trial_was_evaluated, + # Whether the post_processing callback requires models will required models, i.e. + # to compute some bagged average over all fold models. If `False` will discard models eagerly + # to sasve sapce. + post_processing_requires_models=False, + # This handles edge cases related to stratified splitting when there are too + # few instances of a specific class. May wish to disable if your passing extra fit params + rebalance_if_required_for_stratified_splitting=True, + # Extra parameters requested by sklearn models/group splitters or metrics, + # such as `sample_weight` + params=None, + ) + + # Here we just use the `optimize` method to setup and run an optimization loop + # with `n_workers`. Please either look at the source code for `optimize` or + # refer to the `Scheduler` and `Optimizer` guide if you need more fine grained control. + # If you need to evaluate a certain configuraiton, you can create your own `Trial` object. 
+ # + # trial = Trial.create(name=...., info=None, config=..., bucket=..., seed=..., metrics=metric_def) + # report = evaluator.evaluate(trial, rf_pipeline) + # print(report) + # + history = rf_pipeline.optimize( + target=evaluator.fn, + metric=metric_definition, + optimizer=optimizer_cls, + seed=inner_fold_seed, + process_memory_limit=per_process_memory_limit, + process_walltime_limit=per_process_walltime_limit, + working_dir=working_dir, + max_trials=max_trials, + timeout=max_time, + display=display, + wait=wait_for_all_workers_to_finish, + n_workers=n_workers, + on_trial_exception=on_trial_exception, + ) + + df = history.df() + + # Assign some new information to the dataframe + df.assign( + outer_fold=outer_fold_number, + inner_fold_seed=inner_fold_seed, + task_id=openml_task_id, + max_trials=max_trials, + max_time=max_time, + optimizer=optimizer_cls.__name__, + n_workers=n_workers, + ) + print(df) + print(f"Saving dataframe of results to path: {results_to}") + df.to_parquet(results_to) + + +if __name__ == "__main__": + main() diff --git a/justfile b/justfile index 7a821b0e..3941c9b6 100644 --- a/justfile +++ b/justfile @@ -25,12 +25,29 @@ check: check-types: mypy src -# Launch the docs server locally and open the webpage -docs exec_doc_code="true" example="None" offline="false": +# Launch the docs, executing code blocks and examples +docs-full: python -m webbrowser -t "http://127.0.0.1:8000/" - AMLTK_DOC_RENDER_EXAMPLES={{example}} \ - AMLTK_DOCS_OFFLINNE={{offline}} \ - AMLTK_EXEC_DOCS={{exec_doc_code}} mkdocs serve --watch-theme --dirtyreload + AMLTK_DOC_RENDER_EXAMPLES=all \ + AMLTK_DOCS_OFFLINE=true \ + AMLTK_EXEC_DOCS=true \ + mkdocs serve --watch-theme + +# Launch the docs and execute code blocks +docs-code: + python -m webbrowser -t "http://127.0.0.1:8000/" + AMLTK_DOCS_OFFLINE=true \ + AMLTK_EXEC_DOCS=true \ + AMLTK_DOC_RENDER_EXAMPLES=false \ + mkdocs serve --watch-theme + +# Launch the docs but dont run code examples +docs: + python -m webbrowser -t "http://127.0.0.1:8000/" + AMLTK_DOCS_OFFLINE=true \ + AMLTK_EXEC_DOCS=false \ + AMLTK_DOC_RENDER_EXAMPLES=false \ + mkdocs serve --watch-theme # https://github.com/pawamoy/markdown-exec/issues/19 action: diff --git a/mkdocs.yml b/mkdocs.yml index f052af9e..37d6b2cc 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,3 +1,17 @@ +# This project uses mkdocs to generate the documentation. +# Specifically it uses the mkdocs-material theme, which provides a whole +# host of nice features and customization +# +# mkdocs: https://www.mkdocs.org/getting-started/#getting-started-with-mkdocs +# mkdocs-material: https://squidfunk.github.io/mkdocs-material/ +# +# Please refer to these links for more information on how to use mkdocs +# +# For serving the docs locally, you can take a look at the `justfile` at +# the root of this repository, it contains a few commands for generating the docs +# with different levels of execution. 
+# +# Please refer to individual sections for any additional notes site_name: "AutoML-Toolkit" repo_url: https://github.com/automl/amltk/ repo_name: automl/amltk @@ -12,10 +26,11 @@ theme: - content.code.annotate - content.code.copy - navigation.footer + - navigation.sections + - toc.follow + - toc.integrate - navigation.tabs - navigation.tabs.sticky - - navigation.expand - - toc.follow - header.autohide - search.suggest - search.highlight @@ -42,6 +57,10 @@ theme: name: Switch to dark mode +# The `mike` versioning provider +# https://github.com/jimporter/mike +# +# This is what allows us to create versioned docs in the github cli extra: version: provider: mike @@ -51,6 +70,9 @@ extra: - icon: fontawesome/brands/twitter link: https://twitter.com/automl_org +# We do have some extra custom css +# If for whatever reason you think this is breaking something, +# please feel free to remove it. extra_css: - stylesheets/custom.css @@ -90,6 +112,18 @@ markdown_extensions: emoji_index: !!python/name:material.extensions.emoji.twemoji emoji_generator: !!python/name:material.extensions.emoji.to_svg +# These are files that are run when serving the docs. +hooks: + # This prevents logging messages from polluting the doc build + - docs/hooks/cleanup_log_output.py + # This prevents markdown_exec (plugin) from executing code blocks + # dependant on environment variables. These env variables are + # automatically set with the `justfile` commands to build docs + - docs/hooks/disable_markdown_exec.py + # This hook simply prints the page being rendered for an easier time debugging + # any issues with code in docs + - docs/hooks/debug_which_page_is_being_rendered.py + plugins: - search - autorefs @@ -111,10 +145,10 @@ plugins: - mkdocstrings: default_handler: python enable_inventory: true - custom_templates: docs/_templates handlers: python: paths: [src] + # Extra objects which allow for linking to external docs import: - 'https://docs.python.org/3/objects.inv' - 'https://numpy.org/doc/stable/objects.inv' @@ -123,31 +157,37 @@ plugins: - 'https://scikit-learn.org/stable/objects.inv' - 'https://pytorch.org/docs/stable/objects.inv' - 'https://jobqueue.dask.org/en/latest/objects.inv' + # Please do not try to change these without having + # looked at all of the documentation and seeing if it + # causes the API docs to look weird anywhere. 
options: # https://mkdocstrings.github.io/python/usage/ docstring_section_style: spacy docstring_options: ignore_init_summary: true trim_doctest_flags: true + returns_multiple_items: false show_docstring_attributes: true show_docstring_description: true - show_root_heading: false - show_root_toc_entry: false + show_root_heading: true + show_root_toc_entry: true show_object_full_path: false + show_root_members_full_path: false + signature_crossrefs: true merge_init_into_class: true + show_symbol_type_heading: true + show_symbol_type_toc: true docstring_style: google + inherited_members: true show_if_no_docstring: false show_bases: true show_source: true - members_order: "source" - # Would like to set `group_by_category` to false - # https://github.com/mkdocstrings/mkdocstrings/issues/579 + members_order: "alphabetical" group_by_category: true show_signature: true - separate_signature: false - show_signature_annotations: false + separate_signature: true + show_signature_annotations: true filters: - "!^_[^_]" - - "_sample" # Kind of a hack to have this render a private method nav: - Home: "index.md" @@ -187,4 +227,5 @@ nav: # Auto generated with docs/api_generator.py - API: "api/" - Contributing: "contributing.md" + - What's New?: "changelog.md" diff --git a/pyproject.toml b/pyproject.toml index ac4b17b3..08c4988c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "amltk" -version = "1.8.0" +version = "1.12.1" dependencies = [ "typing_extensions", # Better typing "more_itertools", # Better iteration @@ -31,7 +31,7 @@ license = { file = "LICENSE" } [project.optional-dependencies] dev = ["amltk[doc, tooling, test, examples]"] tooling = ["commitizen", "pre-commit", "ruff", "mypy", "types-psutil", "types-pyyaml"] -test = ["pytest", "pytest-coverage", "pytest-cases", "amltk[examples]"] +test = ["pytest<8", "pytest-coverage", "pytest-cases", "amltk[examples]", "torch"] examples = ["openml", "amltk[optuna, smac, sklearn, rich, loky, dask, xgboost, wandb, pynisher, path_loaders]"] doc = [ "mkdocs", @@ -47,6 +47,7 @@ doc = [ "mike", "pillow", "cairosvg", + "black", # This allows mkdocstrings to format signatures in the docs ] # --- Optional user dependancies sklearn = ["scikit-learn", "threadpoolctl"] @@ -54,7 +55,7 @@ smac = ["smac>=2.0", "amltk[configspace]"] optuna = ["optuna"] configspace = ["configspace>=0.6"] loky = ["loky"] -dask = ["dask<=2023.4", "distributed"] +dask = ["dask", "distributed"] pynisher = ["pynisher>=1.0.10"] wandb = ["wandb"] threadpoolctl = ["threadpoolctl"] @@ -97,7 +98,7 @@ exclude_lines = [ [tool.commitizen] name = "cz_conventional_commits" -version = "1.8.0" +version = "1.12.1" update_changelog_on_bump = true version_files = ["pyproject.toml:version", "src/amltk/__version__.py"] changelog_start_rev = "1.0.0" @@ -106,8 +107,14 @@ changelog_start_rev = "1.0.0" [tool.ruff] target-version = "py310" line-length = 88 -show-source = true +output-format = "full" src = ["src", "tests", "examples"] + +[tool.ruff.lint] +# Extend what ruff is allowed to fix, even it it may break +# This is okay given we use it all the time and it ensures +# better practices. Would be dangerous if using for first +# time on established project. extend-safe-fixes = ["ALL"] # Allow unused variables when underscore-prefixed. 
@@ -176,6 +183,7 @@ ignore = [ "PLC1901", # "" can be simplified to be falsey "TCH003", # Move stdlib import into TYPE_CHECKING "B010", # Do not use `setattr` + "PD011", # Use .to_numpy() instead of .values (triggers on report.values) # These tend to be lighweight and confuse pyright ] @@ -203,9 +211,11 @@ exclude = [ ] # Exclude a variety of commonly ignored directories. -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "tests/*.py" = [ "S101", + "D101", + "D102", "D103", "ANN001", "ANN201", @@ -215,6 +225,7 @@ exclude = [ "PD901", # X is a bad variable name. (pandas) "TCH", "N803", + "C901", # Too complex ] "src/amltk/__version__.py" = ["D100"] "__init__.py" = ["I002"] @@ -222,7 +233,7 @@ exclude = [ "docs/*" = ["INP001"] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["amltk"] known-third-party = ["sklearn"] no-lines-before = ["future"] @@ -231,10 +242,10 @@ combine-as-imports = true extra-standard-library = ["typing_extensions"] force-wrap-aliases = true -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "google" -[tool.ruff.pylint] +[tool.ruff.lint.pylint] max-args = 10 # Changed from default of 5 [tool.mypy] diff --git a/src/amltk/__version__.py b/src/amltk/__version__.py index fc1822ec..5ce731e1 100644 --- a/src/amltk/__version__.py +++ b/src/amltk/__version__.py @@ -1,3 +1,3 @@ from __future__ import annotations -version = "1.8.0" +version = "1.12.1" diff --git a/src/amltk/_asyncm.py b/src/amltk/_asyncm.py index daa9fdda..ed0b9077 100644 --- a/src/amltk/_asyncm.py +++ b/src/amltk/_asyncm.py @@ -97,3 +97,12 @@ def context(self) -> tuple[str | None, BaseException | None]: A tuple of the message and exception. """ return self.msg, self.exception + + def __enter__(self) -> ContextEvent: + """Enter the context manager.""" + self.set() + return self + + def __exit__(self, *args: Any) -> None: + """Exit the context manager.""" + self.clear() diff --git a/src/amltk/_doc.py b/src/amltk/_doc.py index 59c045c3..da766aa1 100644 --- a/src/amltk/_doc.py +++ b/src/amltk/_doc.py @@ -1,5 +1,6 @@ from __future__ import annotations +import importlib import os from collections.abc import Callable from functools import lru_cache @@ -47,17 +48,17 @@ def link(obj: Any) -> str | None: return _try_get_link(fullname(obj)) -def make_picklable(thing: Any, name: str | None = None) -> None: +def make_picklable(thing: Any) -> None: """This is hack to make the examples code with schedulers work. Scheduler uses multiprocessing and multiprocessing requires that all objects passed to the scheduler are picklable. This is not the case for the classes/functions defined in the example code. 
""" - import __main__ + thing_module = thing.__module__ - _name = thing.__name__ if name is None else name - setattr(__main__, _name, thing) + _mod = importlib.import_module(thing_module) + setattr(_mod, thing.__name__, thing) def as_rich_svg( diff --git a/src/amltk/_richutil/__init__.py b/src/amltk/_richutil/__init__.py index 50e6a507..ae4bf870 100644 --- a/src/amltk/_richutil/__init__.py +++ b/src/amltk/_richutil/__init__.py @@ -1,6 +1,6 @@ from amltk._richutil.renderable import RichRenderable from amltk._richutil.renderers import Function, rich_make_column_selector -from amltk._richutil.util import df_to_table, richify +from amltk._richutil.util import df_to_table, is_jupyter, richify __all__ = [ "df_to_table", @@ -8,4 +8,5 @@ "RichRenderable", "Function", "rich_make_column_selector", + "is_jupyter", ] diff --git a/src/amltk/_richutil/util.py b/src/amltk/_richutil/util.py index 77d1cbee..5f8a0668 100644 --- a/src/amltk/_richutil/util.py +++ b/src/amltk/_richutil/util.py @@ -3,6 +3,7 @@ # where rich not being installed. from __future__ import annotations +import os from concurrent.futures import ProcessPoolExecutor from typing import TYPE_CHECKING, Any @@ -70,3 +71,25 @@ def df_to_table( table.add_row(str(index), *[str(cell) for cell in row]) return table + + +def is_jupyter() -> bool: + """Return True if running in a Jupyter environment.""" + # https://github.com/Textualize/rich/blob/fd981823644ccf50d685ac9c0cfe8e1e56c9dd35/rich/console.py#L518-L535 + try: + get_ipython # type: ignore[name-defined] # noqa: B018 + except NameError: + return False + ipython = get_ipython() # type: ignore[name-defined] # noqa: F821 + shell = ipython.__class__.__name__ + if ( + "google.colab" in str(ipython.__class__) + or os.getenv("DATABRICKS_RUNTIME_VERSION") + or shell == "ZMQInteractiveShell" + ): + return True # Jupyter notebook or qtconsole + + if shell == "TerminalInteractiveShell": + return False # Terminal running IPython + + return False # Other type (?) diff --git a/src/amltk/_util.py b/src/amltk/_util.py new file mode 100644 index 00000000..7825b110 --- /dev/null +++ b/src/amltk/_util.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import warnings +from collections.abc import Iterator +from contextlib import AbstractContextManager, ExitStack, contextmanager +from datetime import datetime +from typing import Any + +import pandas as pd + + +def threadpoolctl_heuristic(item_contained_in_node: Any | None) -> bool: + """Heuristic to determine if we should automatically set threadpoolctl. + + This is done by detecting if it's a scikit-learn `BaseEstimator` but this may + be extended in the future. + + !!! tip + + The reason to have this heuristic is that when running scikit-learn, or any + multithreaded model, in parallel, they will over subscribe to threads. This + causes a significant performance hit as most of the time is spent switching + thread contexts instead of work. This can be particularly bad for HPO where + we are evaluating multiple models in parallel on the same system. + + The recommened thread count is 1 per core with no additional information to + act upon. + + !!! todo + + This is potentially not an issue if running on multiple nodes of some cluster, + as they do not share logical cores and hence do not clash. + + Args: + item_contained_in_node: The item with which to base the heuristic on. + + Returns: + Whether we should automatically set threadpoolctl. 
+ """ + if item_contained_in_node is None or not isinstance(item_contained_in_node, type): + return False + + try: + # NOTE: sklearn depends on threadpoolctl so it will be installed. + from sklearn.base import BaseEstimator + + return issubclass(item_contained_in_node, BaseEstimator) + except ImportError: + return False + + +def parse_timestamp_object(timestamp: Any) -> datetime: + """Parse a timestamp object, erring if it can't be parsed. + + Args: + timestamp: The timestamp to parse. + + Returns: + The parsed timestamp or `None` if it could not be parsed. + """ + # Make sure we correctly set it's generated at if + # we can + match timestamp: + case datetime(): + return timestamp + case pd.Timestamp(): + return timestamp.to_pydatetime() + case float() | int(): + return datetime.fromtimestamp(timestamp) + case str(): + try: + return datetime.fromisoformat(timestamp) + except ValueError as e: + raise ValueError( + f"Could not parse `str` type timestamp for '{timestamp}'." + " \nPlease provide a valid isoformat timestamp, e.g." + "'2021-01-01T00:00:00.000000'.", + ) from e + case _: + raise TypeError(f"Could not parse {timestamp=} of type {type(timestamp)}.") + + +@contextmanager +def ignore_warnings() -> Iterator[None]: + """Ignore warnings for the duration of the context manager.""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + yield + + +@contextmanager +def mutli_context(*managers: AbstractContextManager) -> Iterator: + """Run multiple context managers at once.""" + with ExitStack() as stack: + yield [stack.enter_context(m) for m in managers] diff --git a/src/amltk/exceptions.py b/src/amltk/exceptions.py index 5c78fcd1..e2140538 100644 --- a/src/amltk/exceptions.py +++ b/src/amltk/exceptions.py @@ -5,11 +5,8 @@ import traceback from collections.abc import Callable, Iterable, Iterator -from typing import TYPE_CHECKING, Any, TypeVar -from typing_extensions import ParamSpec, override - -if TYPE_CHECKING: - from amltk.pipeline.node import Node +from typing import Any, TypeVar +from typing_extensions import ParamSpec R = TypeVar("R") E = TypeVar("E") @@ -68,6 +65,13 @@ def __init__(self, name: str) -> None: super().__init__(f"No integration found for {name}.") +class AutomaticParameterWarning(UserWarning): + """Raised when an "auto" parameter of a function is used + and triggers some behaviour which would be better explicitly + set. + """ + + class SchedulerNotRunningError(RuntimeError): """The scheduler is not running.""" @@ -95,18 +99,97 @@ class ComponentBuildError(TypeError): class DuplicateNamesError(ValueError): """Raised when duplicate names are found.""" - def __init__(self, node: Node) -> None: + +class AutomaticThreadPoolCTLWarning(AutomaticParameterWarning): + """Raised when automatic threadpoolctl is enabled.""" + + +class ImplicitMetricConversionWarning(UserWarning): + """A warning raised when a metric is implicitly converted to an sklearn scorer. + + This is raised when a metric is provided with a custom function and is + implicitly converted to an sklearn scorer. This may fail in some cases + and it is recommended to explicitly convert the metric to an sklearn + scorer with `make_scorer` and then pass it to the metric with + [`Metric(fn=...)`][amltk.optimization.Metric]. 
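+
+    For example, the explicit form recommended above looks roughly like this
+    (sketch using an assumed F1 metric):
+
+    ```python
+    from sklearn.metrics import f1_score, make_scorer
+
+    from amltk.optimization import Metric
+
+    metric = Metric("f1", minimize=False, bounds=(0, 1), fn=make_scorer(f1_score))
+    ```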
+ """ + + +class TaskTypeWarning(UserWarning): + """A warning raised about the task type.""" + + +class AutomaticTaskTypeInferredWarning(TaskTypeWarning, AutomaticParameterWarning): + """A warning raised when the task type is inferred from the target data.""" + + +class MismatchedTaskTypeWarning(TaskTypeWarning): + """A warning raised when inferred task type with `task_hint` does not + match the inferred task type from the target data. + """ + + +class TrialError(RuntimeError): + """An exception raised from a trial and it is meant to be raised directly + to the user. + """ + + +class CVEarlyStoppedError(RuntimeError): + """An exception raised when a CV evaluation is early stopped.""" + + +class MatchDimensionsError(KeyError): + """An exception raised for errors related to matching dimensions in a pipeline.""" + + def __init__(self, layer_name: str, param: str | None, *args: Any) -> None: + """Initialize the exception. + + Args: + layer_name: The name of the layer. + param: The parameter causing the error, if any. + *args: Additional arguments to pass to the exception. + """ + if param: + super().__init__( + f"Error in matching dimensions for layer '{layer_name}'. " + f"Parameter '{param}' not found in the configuration.", + *args, + ) + else: + super().__init__( + f"Error in matching dimensions for layer '{layer_name}'." + f" Configuration not found.", + *args, + ) + + +class MatchChosenDimensionsError(KeyError): + """An exception raised related to matching dimensions for chosen nodes.""" + + def __init__( + self, + choice_name: str, + chosen_node_name: str | None = None, + *args: Any, + ) -> None: """Initialize the exception. Args: - node: The node that has children with duplicate names. + choice_name: The name of the choice that caused the error. + chosen_node_name: The name of the chosen node if available. + *args: Additional arguments to pass to the exception. """ - super().__init__(node) - self.node = node - - @override - def __str__(self) -> str: - return ( - f"Duplicate names found in {self.node.name} and can't be handled." - f"\nnodes: {[n.name for n in self.node.nodes]}." - ) + if chosen_node_name: + message = ( + f"Error in matching dimensions for chosen node '{chosen_node_name}' " + f"of Choice '{choice_name}'. Make sure that the names for " + f"Choice and MatchChosenDimensions 'choices' parameters match." + ) + else: + message = ( + f"Choice name '{choice_name}' is not found in the chosen nodes." + f"Make sure that the names for Choice and " + f"MatchChosenDimensions 'choice_name' parameters match." + ) + super().__init__(message, *args) diff --git a/src/amltk/metalearning/dataset_distances.py b/src/amltk/metalearning/dataset_distances.py index cc99a61e..4d29fe79 100644 --- a/src/amltk/metalearning/dataset_distances.py +++ b/src/amltk/metalearning/dataset_distances.py @@ -1,92 +1,4 @@ -"""One common way to define how similar two datasets are is to compute some "similarity" -between them. This notion of "similarity" requires computing some features of a dataset -(**metafeatures**) first, such that we can numerically compute some distance function. - -Let's see how we can quickly compute the distance between some datasets with -[`dataset_distance()`][amltk.metalearning.dataset_distance]! 
- -```python exec="true" source="material-block" result="python" title="Dataset Distances P.1" session='dd' -import pandas as pd -import openml - -from amltk.metalearning import compute_metafeatures - -def get_dataset(dataset_id: int) -> tuple[pd.DataFrame, pd.Series]: - dataset = openml.datasets.get_dataset( - dataset_id, - download_data=True, - download_features_meta_data=False, - download_qualities=False, - ) - X, y, _, _ = dataset.get_data( - dataset_format="dataframe", - target=dataset.default_target_attribute, - ) - return X, y - -d31 = get_dataset(31) -d3 = get_dataset(3) -d4 = get_dataset(4) - -metafeatures_dict = { - "dataset_31": compute_metafeatures(*d31), - "dataset_3": compute_metafeatures(*d3), - "dataset_4": compute_metafeatures(*d4), -} - -metafeatures = pd.DataFrame(metafeatures_dict) -print(metafeatures) -``` - -Now we want to know which one of `#!python "dataset_3"` or `#!python "dataset_4"` is -more _similar_ to `#!python "dataset_31"`. - -```python exec="true" source="material-block" result="python" title="Dataset Distances P.2" session='dd' -from amltk.metalearning import dataset_distance - -target = metafeatures_dict.pop("dataset_31") -others = metafeatures_dict - -distances = dataset_distance(target, others, distance_metric="l2") -print(distances) -``` - -Seems like `#!python "dataset_3"` is some notion of closer to `#!python "dataset_31"` -than `#!python "dataset_4"`. However the scale of the metafeatures are not exactly all close. -For example, many lie between `#!python (0, 1)` but some like `instance_count` can completely -dominate the show. - -Lets repeat the computation but specify that we should apply a `#!python "minmax"` scaling -across the rows. - -```python exec="true" source="material-block" result="python" title="Dataset Distances P.3" session='dd' hl_lines="5" -distances = dataset_distance( - target, - others, - distance_metric="l2", - scaler="minmax" -) -print(distances) -``` - -Now `#!python "dataset_3"` is considered more similar but the difference between the two is a lot less -dramatic. In general, applying some scaling to values of different scales is required for metalearning. - -You can also use an [sklearn.preprocessing.MinMaxScaler][] or anything other scaler from scikit-learn -for that matter. - -```python exec="true" source="material-block" result="python" title="Dataset Distances P.3" session='dd' hl_lines="7" -from sklearn.preprocessing import MinMaxScaler - -distances = dataset_distance( - target, - others, - distance_metric="l2", - scaler=MinMaxScaler() -) -print(distances) -``` -""" # noqa: E501 +"""Calculating metadata distances.""" from __future__ import annotations import warnings diff --git a/src/amltk/metalearning/metafeatures.py b/src/amltk/metalearning/metafeatures.py index 6bea8a6c..75ee1bdd 100644 --- a/src/amltk/metalearning/metafeatures.py +++ b/src/amltk/metalearning/metafeatures.py @@ -1,140 +1,4 @@ -'''A [`MetaFeature`][amltk.metalearning.MetaFeature] is some -statistic about a dataset/task, that can be used to make datasets or -tasks more comparable, thus enabling meta-learning methods. - -Calculating meta-features of a dataset is quite straight foward. 
- -```python exec="true" source="material-block" result="python" title="Metafeatures" hl_lines="10" -import openml -from amltk.metalearning import compute_metafeatures - -dataset = openml.datasets.get_dataset( - 31, # credit-g - download_data=True, - download_features_meta_data=False, - download_qualities=False, -) -X, y, _, _ = dataset.get_data( - dataset_format="dataframe", - target=dataset.default_target_attribute, -) - -mfs = compute_metafeatures(X, y) - -print(mfs) -``` - -By default [`compute_metafeatures()`][amltk.metalearning.compute_metafeatures] will -calculate all the [`MetaFeature`][amltk.metalearning.MetaFeature] implemented, -iterating through their subclasses to do so. You can pass an explicit list -as well to `compute_metafeatures(X, y, features=[...])`. - -To implement your own is also quite straight forward: - -```python exec="true" source="material-block" result="python" title="Create Metafeature" hl_lines="10 11 12 13 14 15 16 17 18 19" -from amltk.metalearning import MetaFeature, compute_metafeatures -import openml - -dataset = openml.datasets.get_dataset( - 31, # credit-g - download_data=True, - download_features_meta_data=False, - download_qualities=False, -) -X, y, _, _ = dataset.get_data( - dataset_format="dataframe", - target=dataset.default_target_attribute, -) - -class TotalValues(MetaFeature): - - @classmethod - def compute( - cls, - x: pd.DataFrame, - y: pd.Series | pd.DataFrame, - dependancy_values: dict, - ) -> int: - return int(x.shape[0] * x.shape[1]) - -mfs = compute_metafeatures(X, y, features=[TotalValues]) -print(mfs) -``` - -As many metafeatures rely on pre-computed dataset statistics, and they do not -need to be calculated more than once, you can specify the dependancies of -a meta feature. When a metafeature would return something other than a single -value, i.e. a `dict` or a `pd.DataFrame`, we instead call those a -[`DatasetStatistic`][amltk.metalearning.DatasetStatistic]. These will -**not** be included in the result of [`compute_metafeatures()`][amltk.metalearning.compute_metafeatures]. -These `DatasetStatistic`s will only be calculated once on a call to `compute_metafeatures()` so -they can be re-used across all `MetaFeature`s that require that dependancy. 
- -```python exec="true" source="material-block" result="python" title="Metafeature Dependancy" hl_lines="10 11 12 13 14 15 16 17 18 19 20 23 26 35" -from amltk.metalearning import MetaFeature, DatasetStatistic, compute_metafeatures -import openml - -dataset = openml.datasets.get_dataset( - 31, # credit-g - download_data=True, - download_features_meta_data=False, - download_qualities=False, -) -X, y, _, _ = dataset.get_data( - dataset_format="dataframe", - target=dataset.default_target_attribute, -) - -class NAValues(DatasetStatistic): - """A mask of all NA values in a dataset""" - - @classmethod - def compute( - cls, - x: pd.DataFrame, - y: pd.Series | pd.DataFrame, - dependancy_values: dict, - ) -> pd.DataFrame: - return x.isna() - - -class PercentageNA(MetaFeature): - """The percentage of values missing""" - - dependencies = (NAValues,) - - @classmethod - def compute( - cls, - x: pd.DataFrame, - y: pd.Series | pd.DataFrame, - dependancy_values: dict, - ) -> int: - na_values = dependancy_values[NAValues] - n_na = na_values.sum().sum() - n_values = int(x.shape[0] * x.shape[1]) - return float(n_na / n_values) - -mfs = compute_metafeatures(X, y, features=[PercentageNA]) -print(mfs) -``` - -To view the description of a particular `MetaFeature`, you can call -[`.description()`][amltk.metalearning.DatasetStatistic.description] -on it. Otherwise you can access all of them in the following way: - -```python exec="true" source="tabbed-left" result="python" title="Metafeature Descriptions" hl_lines="4" -from pprint import pprint -from amltk.metalearning import metafeature_descriptions - -descriptions = metafeature_descriptions() -for name, description in descriptions.items(): - print("---") - print(name) - print("---") - print(" * " + description) -``` -''' # noqa: E501 +"""Metafeatures access.""" from __future__ import annotations import logging diff --git a/src/amltk/metalearning/portfolio.py b/src/amltk/metalearning/portfolio.py index 7a282a87..058e8469 100644 --- a/src/amltk/metalearning/portfolio.py +++ b/src/amltk/metalearning/portfolio.py @@ -1,115 +1,4 @@ -"""A portfolio in meta-learning is to a set (ordered or not) of configurations -that maximize some notion of coverage across datasets or tasks. -The intuition here is that this also means that any new dataset is also covered! - -Suppose we have the given performances of some configurations across some datasets. -```python exec="true" source="material-block" result="python" title="Initial Portfolio" -import pandas as pd - -performances = { - "c1": [90, 60, 20, 10], - "c2": [20, 10, 90, 20], - "c3": [10, 20, 40, 90], - "c4": [90, 10, 10, 10], -} -portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) -print(portfolio) -``` - -If we could only choose `#!python k=3` of these configurations on some new given dataset, which ones would -you choose and in what priority? -Here is where we can apply [`portfolio_selection()`][amltk.metalearning.portfolio_selection]! - -The idea is that we pick a subset of these algorithms that maximise some value of utility for -the portfolio. We do this by adding a single configuration from the entire set, 1-by-1 until -we reach `k`, beginning with the empty portfolio. - -Let's see this in action! 
- -```python exec="true" source="material-block" result="python" title="Portfolio Selection" hl_lines="12 13 14 15 16" -import pandas as pd -from amltk.metalearning import portfolio_selection - -performances = { - "c1": [90, 60, 20, 10], - "c2": [20, 10, 90, 20], - "c3": [10, 20, 40, 90], - "c4": [90, 10, 10, 10], -} -portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) - -selected_portfolio, trajectory = portfolio_selection( - portfolio, - k=3, - scaler="minmax" -) - -print(selected_portfolio) -print() -print(trajectory) -``` - -The trajectory tells us which configuration was added at each time stamp along with the utility -of the portfolio with that configuration added. However we havn't specified how _exactly_ we defined the -utility of a given portfolio. We could define our own function to do so: - -```python exec="true" source="material-block" result="python" title="Portfolio Selection Custom" hl_lines="12 13 14 20" -import pandas as pd -from amltk.metalearning import portfolio_selection - -performances = { - "c1": [90, 60, 20, 10], - "c2": [20, 10, 90, 20], - "c3": [10, 20, 40, 90], - "c4": [90, 10, 10, 10], -} -portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) - -def my_function(p: pd.DataFrame) -> float: - # Take the maximum score for each dataset and then take the mean across them. - return p.max(axis=1).mean() - -selected_portfolio, trajectory = portfolio_selection( - portfolio, - k=3, - scaler="minmax", - portfolio_value=my_function, -) - -print(selected_portfolio) -print() -print(trajectory) -``` - -This notion of reducing across all configurations for a dataset and then aggregating these is common -enough that we can also directly just define these operations and we will perform the rest. - -```python exec="true" source="material-block" result="python" title="Portfolio Selection With Reduction" hl_lines="17 18" -import pandas as pd -import numpy as np -from amltk.metalearning import portfolio_selection - -performances = { - "c1": [90, 60, 20, 10], - "c2": [20, 10, 90, 20], - "c3": [10, 20, 40, 90], - "c4": [90, 10, 10, 10], -} -portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) - -selected_portfolio, trajectory = portfolio_selection( - portfolio, - k=3, - scaler="minmax", - row_reducer=np.max, # This is actually the default - aggregator=np.mean, # This is actually the default -) - -print(selected_portfolio) -print() -print(trajectory) -``` -""" # noqa: E501 +"""Portfolio selection.""" from __future__ import annotations @@ -210,8 +99,9 @@ def portfolio_selection( The final portfolio The trajectory, where the entry is the value once added to the portfolio. 
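+
+    A minimal usage sketch (the performance values are illustrative):
+
+    ```python
+    import pandas as pd
+    from amltk.metalearning import portfolio_selection
+
+    performances = {
+        "c1": [90, 60, 20, 10],
+        "c2": [20, 10, 90, 20],
+        "c3": [10, 20, 40, 90],
+        "c4": [90, 10, 10, 10],
+    }
+    portfolio = pd.DataFrame(
+        performances,
+        index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"],
+    )
+
+    selected, trajectory = portfolio_selection(portfolio, k=3, scaler="minmax")
+    ```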
""" - if not (1 <= k < len(items)): - raise ValueError(f"k must be in [1, {len(items)=})") + n_items = len(items) if isinstance(items, dict) else items.shape[1] + if not (1 <= k < n_items): + raise ValueError(f"k must be in [1, {n_items=})") all_portfolio = pd.DataFrame(items) diff --git a/src/amltk/optimization/__init__.py b/src/amltk/optimization/__init__.py index 916dbadc..3bf0b373 100644 --- a/src/amltk/optimization/__init__.py +++ b/src/amltk/optimization/__init__.py @@ -1,5 +1,5 @@ from amltk.optimization.history import History -from amltk.optimization.metric import Metric +from amltk.optimization.metric import Metric, MetricCollection from amltk.optimization.optimizer import Optimizer from amltk.optimization.trial import Trial @@ -7,5 +7,6 @@ "Optimizer", "Trial", "Metric", + "MetricCollection", "History", ] diff --git a/src/amltk/optimization/evaluation.py b/src/amltk/optimization/evaluation.py new file mode 100644 index 00000000..4095e3af --- /dev/null +++ b/src/amltk/optimization/evaluation.py @@ -0,0 +1,5 @@ +"""Evaluation protocols for how a trial and a pipeline should be evaluated. + +TODO: Sorry +""" +from __future__ import annotations diff --git a/src/amltk/optimization/history.py b/src/amltk/optimization/history.py index a598e407..8a354da2 100644 --- a/src/amltk/optimization/history.py +++ b/src/amltk/optimization/history.py @@ -16,12 +16,7 @@ def target_function(trial: Trial) -> Trial.Report: y = trial.config["y"] trial.store({"config.json": trial.config}) - with trial.begin(): - loss = x**2 - y - - if trial.exception: - return trial.fail() - + loss = x**2 - y return trial.success(loss=loss) # ... usually obtained from an optimizer @@ -29,7 +24,8 @@ def target_function(trial: Trial) -> Trial.Report: history = History() for x, y in zip([1, 2, 3], [4, 5, 6]): - trial = Trial(name="some-unique-name", config={"x": x, "y": y}, bucket=bucket, metrics=[loss]) + name = f"trial_{x}_{y}" + trial = Trial.create(name=name, config={"x": x, "y": y}, bucket=bucket / name, metrics=[loss]) report = target_function(trial) history.add(report) @@ -45,7 +41,7 @@ def target_function(trial: Trial) -> Trial.Report: * [`groupby(key=...)`][amltk.optimization.History.groupby] - Groups the history by some key, e.g. `#!python history.groupby(lambda report: report.config["x"] < 5)` * [`sortby(key=...)`][amltk.optimization.History.sortby] - Sorts the history by some - key, e.g. `#!python history.sortby(lambda report: report.time.end)` + key, e.g. `#!python history.sortby(lambda report: report.profiles["trial"].time.end)` There is also some serialization capabilities built in, to allow you to store your reports and load them back in later: @@ -56,7 +52,7 @@ def target_function(trial: Trial) -> Trial.Report: a `pd.DataFrame`. You can also retrieve individual reports from the history by using their -name, e.g. `#!python history["some-unique-name"]` or iterate through +name, e.g. `#!python history.reports["some-unique-name"]` or iterate through the history with `#!python for report in history: ...`. 
""" # noqa: E501 from __future__ import annotations @@ -65,20 +61,20 @@ def target_function(trial: Trial) -> Trial.Report: from collections import defaultdict from collections.abc import Callable, Hashable, Iterable, Iterator from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Literal, TypeVar +from typing import TYPE_CHECKING, Literal, TypeVar, overload from typing_extensions import override import pandas as pd from amltk._functional import compare_accumulate from amltk._richutil import RichRenderable +from amltk.optimization.metric import Metric from amltk.optimization.trial import Trial from amltk.types import Comparable if TYPE_CHECKING: from rich.console import RenderableType - from amltk.optimization.metric import Metric T = TypeVar("T") CT = TypeVar("CT", bound=Comparable) @@ -103,16 +99,16 @@ class History(RichRenderable): metric = Metric("cost", minimize=True) trials = [ - Trial(name=f"trial_{i}", config={"x": i}, metrics=[metric]) + Trial.create(name=f"trial_{i}", config={"x": i}, metrics=[metric]) for i in range(10) ] history = History() for trial in trials: - with trial.begin(): - x = trial.config["x"] - report = trial.success(cost=x**2 - x*2 + 4) - history.add(report) + x = trial.config["x"] + report = trial.success(cost=x**2 - x*2 + 4) + history.add(report) + trial.bucket.rmdir() # markdown-exec: hide for report in history: print(f"{report.name=}, {report}") @@ -145,7 +141,7 @@ def from_reports(cls, reports: Iterable[Trial.Report]) -> History: history.add(reports) return history - def best(self, metric: str | None = None) -> Trial.Report: + def best(self, metric: str | Metric | None = None) -> Trial.Report: """Returns the best report in the history. Args: @@ -156,26 +152,33 @@ def best(self, metric: str | None = None) -> Trial.Report: Returns: The best report. """ - if metric is None: - if len(self.metrics) > 1: - raise ValueError( - "There are multiple metrics in the history, " - "please specify which metric to sort by.", - ) - - _metric_def = next(iter(self.metrics.values())) - _metric_name = _metric_def.name - else: - if metric not in self.metrics: - raise ValueError( - f"Metric {metric} not found in history. " - f"Available metrics: {list(self.metrics.keys())}", - ) - _metric_def = self.metrics[metric] - _metric_name = metric + match metric: + case None: + if len(self.metrics) > 1: + raise ValueError( + "There are multiple metrics in the history, " + "please specify which metric to sort by for best.", + ) + + _metric_def = next(iter(self.metrics.values())) + _metric_name = _metric_def.name + case str(): + if metric not in self.metrics: + raise ValueError( + f"Metric {metric} not found in history. " + f"Available metrics: {list(self.metrics.keys())}", + ) + _metric_def = self.metrics[metric] + _metric_name = metric + case Metric(): + _metric_def = metric + _metric_name = metric.name _by = min if _metric_def.minimize else max - return _by(self.reports, key=lambda r: r.metrics[_metric_name]) + return _by( + (r for r in self.reports if _metric_name in r.values), + key=lambda r: r.values[_metric_name], + ) def add(self, report: Trial.Report | Iterable[Trial.Report]) -> None: """Adds a report or reports to the history. 
@@ -185,16 +188,7 @@ def add(self, report: Trial.Report | Iterable[Trial.Report]) -> None: """ match report: case Trial.Report(): - for m in report.metric_values: - if (_m := self.metrics.get(m.name)) is not None: - if m.metric != _m: - raise ValueError( - f"Metric {m.name} has conflicting definitions:" - f"\n{m.metric} != {_m}", - ) - else: - self.metrics[m.name] = m.metric - + self.metrics.update(report.metrics) self.reports.append(report) self._lookup[report.name] = len(self.reports) - 1 case Iterable(): @@ -247,14 +241,14 @@ def df( from amltk.optimization import Trial, History, Metric metric = Metric("cost", minimize=True) - trials = [Trial(name=f"trial_{i}", config={"x": i}, metrics=[metric]) for i in range(10)] + trials = [Trial.create(name=f"trial_{i}", config={"x": i}, metrics=[metric]) for i in range(10)] history = History() for trial in trials: - with trial.begin(): - x = trial.config["x"] - report = trial.success(cost=x**2 - x*2 + 4) - history.add(report) + x = trial.config["x"] + report = trial.success(cost=x**2 - x*2 + 4) + history.add(report) + trial.bucket.rmdir() # markdown-exec: hide print(history.df()) ``` @@ -312,18 +306,18 @@ def filter(self, key: Callable[[Trial.Report], bool]) -> History: from amltk.optimization import Trial, History, Metric metric = Metric("cost", minimize=True) - trials = [Trial(name=f"trial_{i}", config={"x": i}, metrics=[metric]) for i in range(10)] + trials = [Trial.create(name=f"trial_{i}", config={"x": i}, metrics=[metric]) for i in range(10)] history = History() for trial in trials: - with trial.begin(): - x = trial.config["x"] - report = trial.success(cost=x**2 - x*2 + 4) - history.add(report) + x = trial.config["x"] + report = trial.success(cost=x**2 - x*2 + 4) + trial.bucket.rmdir() # markdown-exec: hide + history.add(report) - filtered_history = history.filter(lambda report: report.metrics["cost"] < 10) + filtered_history = history.filter(lambda report: report.values["cost"] < 10) for report in filtered_history: - cost = report.metrics["cost"] + cost = report.values["cost"] print(f"{report.name}, {cost=}, {report}") ``` @@ -345,17 +339,17 @@ def groupby( from amltk.optimization import Trial, History, Metric metric = Metric("cost", minimize=True) - trials = [Trial(name=f"trial_{i}", config={"x": i}, metrics=[metric]) for i in range(10)] + trials = [Trial.create(name=f"trial_{i}", config={"x": i}, metrics=[metric]) for i in range(10)] history = History() for trial in trials: - with trial.begin(): - x = trial.config["x"] - if x % 2 == 0: - report = trial.fail(cost=1_000) - else: - report = trial.success(cost=x**2 - x*2 + 4) - history.add(report) + x = trial.config["x"] + if x % 2 == 0: + report = trial.fail(cost=1_000) + else: + report = trial.success(cost=x**2 - x*2 + 4) + trial.bucket.rmdir() # markdown-exec: hide + history.add(report) for status, history in history.groupby("status").items(): print(f"{status=}, {len(history)=}") @@ -367,16 +361,16 @@ def groupby( from amltk.optimization import Trial, History, Metric metric = Metric("cost", minimize=True) - trials = [Trial(name=f"trial_{i}", config={"x": i}, metrics=[metric]) for i in range(10)] + trials = [Trial.create(name=f"trial_{i}", config={"x": i}, metrics=[metric]) for i in range(10)] history = History() for trial in trials: - with trial.begin(): - x = trial.config["x"] - report = trial.fail(cost=x) - history.add(report) + x = trial.config["x"] + report = trial.fail(cost=x) + history.add(report) + trial.bucket.rmdir() # markdown-exec: hide - for below_5, history in 
history.groupby(lambda r: r.metrics["cost"] < 5).items(): + for below_5, history in history.groupby(lambda r: r.values["cost"] < 5).items(): print(f"{below_5=}, {len(history)=}") ``` @@ -401,11 +395,10 @@ def incumbents( self, key: Callable[[Trial.Report, Trial.Report], bool] | str, *, - sortby: Callable[[Trial.Report], Comparable] - | str = lambda report: report.time.end, + sortby: Callable[[Trial.Report], Comparable] | str = lambda r: r.reported_at, reverse: bool | None = None, ffill: bool = False, - ) -> list[Trial.Report]: + ) -> History: """Returns a trace of the incumbents, where only the report that is better than the previous best report is kept. @@ -413,21 +406,21 @@ def incumbents( from amltk.optimization import Trial, History, Metric metric = Metric("cost", minimize=True) - trials = [Trial(name=f"trial_{i}", config={"x": i}, metrics=[metric]) for i in range(10)] + trials = [Trial.create(name=f"trial_{i}", config={"x": i}, metrics=[metric]) for i in range(10)] history = History() for trial in trials: - with trial.begin(): - x = trial.config["x"] - report = trial.success(cost=x**2 - x*2 + 4) - history.add(report) + x = trial.config["x"] + report = trial.success(cost=x**2 - x*2 + 4) + history.add(report) + trial.bucket.rmdir() # markdown-exec: hide incumbents = ( history - .incumbents("cost", sortby=lambda r: r.time.end) + .incumbents("cost", sortby=lambda r: r.reported_at) ) for report in incumbents: - print(f"{report.metrics=}, {report.config=}") + print(f"{report.values=}, {report.config=}") ``` Args: @@ -460,33 +453,35 @@ def incumbents( case str(): metric = self.metrics[key] __op = operator.lt if metric.minimize else operator.gt # type: ignore - op = lambda r1, r2: __op(r1.metrics[key], r2.metrics[key]) + op = lambda r1, r2: __op(r1.values[key], r2.values[key]) case _: op = key sorted_reports = self.sortby(sortby, reverse=reverse) - return list(compare_accumulate(sorted_reports, op=op, ffill=ffill)) + return History.from_reports( + compare_accumulate(sorted_reports, op=op, ffill=ffill), + ) def sortby( self, key: Callable[[Trial.Report], Comparable] | str, *, reverse: bool | None = None, - ) -> list[Trial.Report]: + ) -> History: """Sorts the history by a key and returns a sorted History. ```python exec="true" source="material-block" result="python" title="sortby" hl_lines="15" from amltk.optimization import Trial, History, Metric metric = Metric("cost", minimize=True) - trials = [Trial(name=f"trial_{i}", config={"x": i}, metrics=[metric]) for i in range(10)] + trials = [Trial.create(name=f"trial_{i}", config={"x": i}, metrics=[metric]) for i in range(10)] history = History() for trial in trials: - with trial.begin(): - x = trial.config["x"] - report = trial.success(cost=x**2 - x*2 + 4) - history.add(report) + x = trial.config["x"] + report = trial.success(cost=x**2 - x*2 + 4) + history.add(report) + trial.bucket.rmdir() # markdown-exec: hide trace = ( history @@ -495,39 +490,47 @@ def sortby( ) for report in trace: - print(f"{report.metrics}, {report}") + print(f"{report.values}, {report}") ``` Args: key: The key to sort by. If given a str, it will sort by - the value of that key in the `.metrics` and also filter + the value of that key in the `.values` and also filter out anything that does not contain this key. reverse: Whether to sort in some given order. By default (`None`), if given a metric key, the reports with the best metric values will be sorted first. If given a `#!python Callable`, the reports with the smallest values will be sorted first. 
Using - `reverse=True` will always reverse this order, while - `reverse=False` will always preserve it. + `reverse=True/False` will apply to python's + [`sorted()`][sorted]. Returns: A sorted list of reports """ # noqa: E501 # If given a str, filter out anything that doesn't have that key if isinstance(key, str): - history = self.filter(lambda report: key in report.metric_names) - sort_key: Callable[[Trial.Report], Comparable] = lambda r: r.metrics[key] - reverse = ( - reverse if reverse is not None else (not self.metrics[key].minimize) - ) + history = self.filter(lambda report: key in report.values) + sort_key = lambda r: r.values[key] + reverse = reverse if reverse is not None else not self.metrics[key].minimize else: history = self sort_key = key + # Default is False reverse = False if reverse is None else reverse - return sorted(history.reports, key=sort_key, reverse=reverse) + return History.from_reports( + sorted(history.reports, key=sort_key, reverse=reverse), + ) + + @overload + def __getitem__(self, key: int | str) -> Trial.Report: + ... + + @overload + def __getitem__(self, key: slice) -> Trial.Report: + ... - @override def __getitem__( # type: ignore self, key: int | str | slice, diff --git a/src/amltk/optimization/metric.py b/src/amltk/optimization/metric.py index 9d5a462b..885b28d7 100644 --- a/src/amltk/optimization/metric.py +++ b/src/amltk/optimization/metric.py @@ -1,38 +1,25 @@ -"""A [`Metric`][amltk.optimization.Metric] to let optimizers know how to -handle numeric values properly. - -A `Metric` is defined by a `.name: str` and whether it is better to `.minimize: bool` -the metric. Further, you can specify `.bounds: tuple[lower, upper]` which can -help optimizers and other code know how to treat metrics. - -To easily convert between [`loss`][amltk.optimization.Metric.Value.loss], -[`score`][amltk.optimization.Metric.Value.score] of a -a value in a [`Metric.Value`][amltk.optimization.Metric.Value] object. +"""The metric definition.""" +from __future__ import annotations -If the metric is bounded, you can also get the -[`distance_to_optimal`][amltk.optimization.Metric.Value.distance_to_optimal] -which is the distance to the optimal value. +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec +from typing_extensions import Self, override -```python exec="true" source="material-block" result="python" -from amltk.optimization import Metric +import numpy as np -acc = Metric("accuracy", minimize=False, bounds=(0.0, 1.0)) +if TYPE_CHECKING: + from sklearn.metrics._scorer import _MultimetricScorer, _Scorer -acc_value = acc.as_value(0.9) -print(f"Cost: {acc_value.distance_to_optimal}") # Distance to optimal. 
-print(f"Loss: {acc_value.loss}") # Something that can be minimized -print(f"Score: {acc_value.score}") # Something that can be maximized -``` -""" -from __future__ import annotations +P = ParamSpec("P") -from dataclasses import dataclass, field -from typing_extensions import Self, override +SklearnResponseMethods = Literal["predict", "predict_proba", "decision_function"] @dataclass(frozen=True) -class Metric: +class Metric(Generic[P]): """A metric with a given name, optimal direction, and possible bounds.""" name: str @@ -44,6 +31,16 @@ class Metric: bounds: tuple[float, float] | None = field(kw_only=True, default=None) """The bounds of the metric, if any.""" + fn: Callable[P, float] | None = field(kw_only=True, default=None, compare=False) + """A function to attach to this metric to be used within a trial.""" + + class Comparison(str, Enum): + """The comparison between two values.""" + + BETTER = "better" + WORSE = "worse" + EQUAL = "equal" + def __post_init__(self) -> None: if self.bounds is not None: lower, upper = self.bounds @@ -65,12 +62,61 @@ def __post_init__(self) -> None: " Must be a valid Python identifier.", ) + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> float: + """Call the associated function with this metric.""" + if self.fn is None: + raise ValueError( + f"Metric {self.name} does not have a function to call." + " Please provide a function to `Metric(fn=...)` if you" + " want to call this metric like this.", + ) + return self.fn(*args, **kwargs) + + def as_scorer( + self, + *, + response_method: ( + SklearnResponseMethods | Sequence[SklearnResponseMethods] | None + ) = None, + **scorer_kwargs: Any, + ) -> _Scorer: + """Convert a metric to a sklearn scorer. + + Args: + response_method: The response method to use for the scorer. + This can be a single method or an iterable of methods. + scorer_kwargs: Additional keyword arguments to pass to the + scorer during the call. Forwards to [`sklearn.metrics.make_scorer`][]. + + Returns: + The sklearn scorer. + """ + from sklearn.metrics import get_scorer, make_scorer + + match self.fn: + case None: + try: + return get_scorer(self.name) + except ValueError as e: + raise ValueError( + f"Could not find scorer for {self.name}." 
+ " Please provide a function to `Metric(fn=...)`.", + ) from e + case fn: + return make_scorer( + fn, + greater_is_better=not self.minimize, + response_method=response_method, + **scorer_kwargs, + ) + @override def __str__(self) -> str: parts = [self.name] if self.bounds is not None: parts.append(f"[{self.bounds[0]}, {self.bounds[1]}]") parts.append(f"({'minimize' if self.minimize else 'maximize'})") + return " ".join(parts) @classmethod @@ -108,95 +154,144 @@ def from_str(cls, s: str) -> Self: return cls(name=name, minimize=minimize, bounds=bounds) @property - def worst(self) -> Metric.Value: + def worst(self) -> float: """The worst possible value of the metric.""" - if self.bounds: - v = self.bounds[1] if self.minimize else self.bounds[0] - return self.as_value(v) + if self.bounds is not None: + return self.bounds[1] if self.minimize else self.bounds[0] - v = float("inf") if self.minimize else float("-inf") - return self.as_value(v) + return float("inf") if self.minimize else float("-inf") @property - def optimal(self) -> Metric.Value: + def optimal(self) -> float: """The optimal value of the metric.""" if self.bounds: - v = self.bounds[0] if self.minimize else self.bounds[1] - return self.as_value(v) - v = float("-inf") if self.minimize else float("inf") - return self.as_value(v) - - def as_value(self, value: float | int) -> Metric.Value: - """Convert a value to an metric value.""" - return Metric.Value(metric=self, value=float(value)) - - def __call__(self, value: float | int) -> Metric.Value: - """Convert a value to an metric value.""" - return Metric.Value(metric=self, value=float(value)) - - @dataclass(frozen=True, order=True) - class Value: - """A recorded value of an metric.""" - - metric: Metric = field(compare=False, hash=True) - """The metric.""" - - value: float = field(compare=True, hash=True) - """The value of the metric.""" - - @property - def minimize(self) -> bool: - """Whether to minimize or maximize the metric.""" - return self.metric.minimize - - @property - def bounds(self) -> tuple[float, float] | None: - """Whether to minimize or maximize the metric.""" - return self.metric.bounds - - @property - def name(self) -> str: - """The name of the metric.""" - return self.metric.name - - @property - def loss(self) -> float: - """Convert a value to a loss.""" - if self.minimize: - return float(self.value) - return -float(self.value) - - @property - def score(self) -> float: - """Convert a value to a score.""" - if self.minimize: - return -float(self.value) - return float(self.value) - - @property - def distance_to_optimal(self) -> float | None: - """The distance to the optimal value, using the bounds if possible.""" - match self.bounds: - case None: - return None - case (lower, upper) if lower <= self.value <= upper: - if self.minimize: - return abs(self.value - lower) - return abs(self.value - upper) - case (lower, upper): - raise ValueError(f"Value {self.value} is not within {self.bounds=}") - - return None - - def __float__(self) -> float: - """Convert a value to a float.""" - return float(self.value) - - @override - def __eq__(self, __value: object) -> bool: - """Check if two values are equal.""" - if isinstance(__value, Metric.Value): - return self.value == __value.value - if isinstance(__value, float | int): - return self.value == float(__value) - return NotImplemented + return self.bounds[0] if self.minimize else self.bounds[1] + + return float("-inf") if self.minimize else float("inf") + + def distance_to_optimal(self, v: float) -> float: + """The distance to the 
optimal value, using the bounds if possible.""" + match self.bounds: + case None: + raise ValueError( + f"Metric {self.name} is unbounded, can not compute distance" + " to optimal.", + ) + case (lower, upper) if lower <= v <= upper: + if self.minimize: + return abs(v - lower) + return abs(v - upper) + case (lower, upper): + raise ValueError(f"Value {v} is not within {self.bounds=}") + case _: + raise ValueError(f"Invalid {self.bounds=}") + + def normalized_loss(self, v: float) -> float: + """The normalized loss of a value if possible. + + If both sides of the bounds are finite, we can normalize the value + to be between 0 and 1. + """ + match self.bounds: + # If both sides are finite, we can 0-1 normalize + case (lower, upper) if not np.isinf(lower) and not np.isinf(upper): + cost = (v - lower) / (upper - lower) + cost = 1 - cost if self.minimize is False else cost + # No bounds or one unbounded bound, we can't normalize + case _: + cost = v if self.minimize else -v + + return cost + + def loss(self, v: float, /) -> float: + """Convert a value to a loss.""" + return float(v) if self.minimize else -float(v) + + def score(self, v: float, /) -> float: + """Convert a value to a score.""" + return -float(v) if self.minimize else float(v) + + def compare(self, v1: float, v2: float) -> Metric.Comparison: + """Check if `v1` is better than `v2`.""" + minimize = self.minimize + if v1 == v2: + return Metric.Comparison.EQUAL + if v1 > v2: + return Metric.Comparison.WORSE if minimize else Metric.Comparison.BETTER + + # v1 < v2 + return Metric.Comparison.BETTER if minimize else Metric.Comparison.WORSE + + +@dataclass(frozen=True, kw_only=True) +class MetricCollection(Mapping[str, Metric]): + """A collection of metrics.""" + + metrics: Mapping[str, Metric] = field(default_factory=dict) + """The metrics in this collection.""" + + @override + def __getitem__(self, key: str) -> Metric: + return self.metrics[key] + + @override + def __len__(self) -> int: + return len(self.metrics) + + @override + def __iter__(self) -> Iterator[str]: + return iter(self.metrics) + + def as_sklearn_scorer( + self, + *, + response_methods: ( + Mapping[str, SklearnResponseMethods | Sequence[SklearnResponseMethods]] + | None + ) = None, + scorer_kwargs: Mapping[str, Mapping[str, Any]] | None = None, + raise_exc: bool = True, + ) -> _MultimetricScorer: + """Convert this collection to a sklearn scorer.""" + from sklearn.metrics._scorer import _MultimetricScorer + + rms = response_methods or {} + skwargs = scorer_kwargs or {} + + scorers = { + k: v.as_scorer(response_method=rms.get(k), **skwargs.get(k, {})) + for k, v in self.items() + } + return _MultimetricScorer(scorers=scorers, raise_exc=raise_exc) + + def optimums(self) -> Mapping[str, float]: + """The optimums of the metrics.""" + return {k: v.optimal for k, v in self.items()} + + def worsts(self) -> Mapping[str, float]: + """The worsts of the metrics.""" + return {k: v.worst for k, v in self.items()} + + @classmethod + def from_empty(cls) -> MetricCollection: + """Create an empty metric collection.""" + return cls(metrics={}) + + @classmethod + def from_collection( + cls, + metrics: Metric | Iterable[Metric] | Mapping[str, Metric], + ) -> MetricCollection: + """Create a metric collection from an iterable of metrics.""" + match metrics: + case Metric(): + return cls(metrics={metrics.name: metrics}) + case Mapping(): + return MetricCollection(metrics={m.name: m for m in metrics.values()}) + case Iterable(): + return cls(metrics={m.name: m for m in metrics}) # type: ignore + case 
_: + raise TypeError( + f"Expected a Metric, Iterable[Metric], or Mapping[str, Metric]." + f" Got {type(metrics)} instead.", + ) diff --git a/src/amltk/optimization/optimizer.py b/src/amltk/optimization/optimizer.py index 309a182f..44bb546b 100644 --- a/src/amltk/optimization/optimizer.py +++ b/src/amltk/optimization/optimizer.py @@ -15,24 +15,40 @@ """ from __future__ import annotations +import logging from abc import abstractmethod -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence from datetime import datetime -from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeVar +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Concatenate, + Generic, + ParamSpec, + Protocol, + TypeVar, + overload, +) +from typing_extensions import Self from more_itertools import all_unique +from amltk.optimization.metric import MetricCollection from amltk.store.paths.path_bucket import PathBucket if TYPE_CHECKING: from amltk.optimization.metric import Metric from amltk.optimization.trial import Trial from amltk.pipeline import Node + from amltk.types import Seed I = TypeVar("I") # noqa: E741 P = ParamSpec("P") ParserOutput = TypeVar("ParserOutput") +logger = logging.getLogger(__name__) + class Optimizer(Generic[I]): """An optimizer protocol. @@ -41,7 +57,7 @@ class Optimizer(Generic[I]): `tell` to inform the optimizer of the report from that trial. """ - metrics: Sequence[Metric] + metrics: MetricCollection """The metrics to optimize.""" bucket: PathBucket @@ -66,7 +82,7 @@ def __init__( f"Got {metrics} with names {[metric.name for metric in metrics]}", ) - self.metrics = metrics + self.metrics = MetricCollection.from_collection(metrics) self.bucket = ( bucket if bucket is not None @@ -81,10 +97,23 @@ def tell(self, report: Trial.Report[I]) -> None: report: The report for a trial """ + @overload + @abstractmethod + def ask(self, n: int) -> Iterable[Trial[I]]: + ... + + @overload @abstractmethod - def ask(self) -> Trial[I]: + def ask(self, n: None = None) -> Trial[I]: + ... + + @abstractmethod + def ask(self, n: int | None = None) -> Trial[I] | Iterable[Trial[I]]: """Ask the optimizer for a trial to evaluate. + Args: + n: The number of trials to ask for. If `None`, ask for a single trial. + Returns: A config to sample. """ @@ -102,3 +131,91 @@ def preferred_parser( """ return None + + @classmethod + @abstractmethod + def create( + cls, + *, + space: Node, + metrics: Metric | Sequence[Metric], + bucket: str | Path | PathBucket | None = None, + seed: Seed | None = None, + ) -> Self: + """Create this optimizer. + + !!! note + + Subclasses should override this with more specific configuration + but these arguments should be all that's necessary to create the optimizer. + + Args: + space: The space to optimize over. + bucket: The bucket for where to store things related to the trial. + metrics: The metrics to optimize. + seed: The seed to use for the optimizer. + + Returns: + The optimizer. + """ + + class CreateSignature(Protocol): + """A Protocol which defines the keywords required to create an + optimizer with deterministic behavior at a desired location. + + This protocol matches the `Optimizer.create` classmethod, however we also + allow any function which accepts the keyword arguments to create an + Optimizer. 
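The reworked `Metric` above moves the helpers that previously lived on `Metric.Value` onto the metric itself (`loss`, `score`, `distance_to_optimal`, `normalized_loss`, `compare`), and `MetricCollection` wraps several metrics behind a plain `Mapping` interface. A minimal sketch of the new surface, based only on the signatures in this diff; the values in the comments are illustrative:

```python
from amltk.optimization import Metric
from amltk.optimization.metric import MetricCollection

acc = Metric("accuracy", minimize=False, bounds=(0.0, 1.0))
runtime = Metric("runtime", minimize=True)  # unbounded

# Conversions now take the raw float directly, no Metric.Value wrapper
print(acc.loss(0.9))                  # -0.9, something to minimize
print(acc.score(0.9))                 # 0.9, something to maximize
print(acc.distance_to_optimal(0.9))   # 0.1, requires bounds
print(acc.normalized_loss(0.9))       # 0.1, 0-1 normalized since both bounds are finite
print(acc.compare(0.9, 0.8))          # Metric.Comparison.BETTER

# A MetricCollection behaves like a Mapping[str, Metric]
metrics = MetricCollection.from_collection([acc, runtime])
print(metrics.optimums())             # {"accuracy": 1.0, "runtime": -inf}
print(metrics.worsts())               # {"accuracy": 0.0, "runtime": inf}
```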
+ """ + + def __call__( + self, + *, + space: Node, + metrics: Metric | Sequence[Metric], + bucket: PathBucket | None = None, + seed: Seed | None = None, + ) -> Optimizer: + """A function which creates an optimizer for node.optimize should + accept the following keyword arguments. + + Args: + space: The node to optimize + metrics: The metrics to optimize + bucket: The bucket to store the results in + seed: The seed to use for the optimization + """ + ... + + @classmethod + def _get_known_importable_optimizer_classes(cls) -> Iterator[type[Optimizer]]: + """Get all developer known optimizer classes. This is used for defaults. + + Do not rely on this functionality and prefer to give concrete optimizers to + functionality requiring one. This is intended for convenience of particular + quickstart methods. + """ + # NOTE: We can't use the `Optimizer.__subclasses__` method as the optimizers + # are not imported by any other module initially and so they do no exist + # until imported. Hence this manual iteration. For now, we be explicit and + # only if the optimizer list grows should we consider dynamic importing. + try: + from amltk.optimization.optimizers.smac import SMACOptimizer + + yield SMACOptimizer + except ImportError as e: + logger.debug("Failed to import SMACOptimizer", exc_info=e) + + try: + from amltk.optimization.optimizers.optuna import OptunaOptimizer + + yield OptunaOptimizer + except ImportError as e: + logger.debug("Failed to import OptunaOptimizer", exc_info=e) + + try: + from amltk.optimization.optimizers.neps import NEPSOptimizer + + yield NEPSOptimizer + except ImportError as e: + logger.debug("Failed to import NEPSOptimizer", exc_info=e) diff --git a/src/amltk/optimization/optimizers/neps.py b/src/amltk/optimization/optimizers/neps.py index b735712d..e2304be8 100644 --- a/src/amltk/optimization/optimizers/neps.py +++ b/src/amltk/optimization/optimizers/neps.py @@ -63,14 +63,16 @@ def target_function(trial: Trial, pipeline: Pipeline) -> Trial.Report: X_train, X_test, y_train, y_test = train_test_split(X, y) clf = pipeline.configure(trial.config).build("sklearn") - with trial.begin(): - clf.fit(X_train, y_train) - y_pred = clf.predict(X_test) - accuracy = accuracy_score(y_test, y_pred) - loss = 1 - accuracy - return trial.success(loss=loss, accuracy=accuracy) - - return trial.fail() + with trial.profile("trial"): + try: + clf.fit(X_train, y_train) + y_pred = clf.predict(X_test) + accuracy = accuracy_score(y_test, y_pred) + loss = 1 - accuracy + return trial.success(loss=loss, accuracy=accuracy) + except Exception as e: + return trial.fail(e) + from amltk._doc import make_picklable; make_picklable(target_function) # markdown-exec: hide pipeline = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) @@ -119,19 +121,18 @@ def add_to_history(_, report: Trial.Report): import logging import shutil -from collections.abc import Mapping, Sequence +from collections.abc import Iterable, Mapping, Sequence from copy import deepcopy from dataclasses import dataclass from datetime import datetime from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol, overload from typing_extensions import override import metahyper.api from ConfigSpace import ConfigurationSpace from metahyper import instance_from_map -from more_itertools import first_true from neps.optimizers import SearcherMapping from neps.search_spaces.parameter import Parameter from neps.search_spaces.search_space import 
SearchSpace, pipeline_space_from_configspace @@ -140,6 +141,7 @@ def add_to_history(_, report: Trial.Report): from amltk.optimization import Optimizer, Trial from amltk.pipeline import Node from amltk.pipeline.parsers.configspace import parser as configspace_parser +from amltk.profiling import Timer from amltk.store import PathBucket if TYPE_CHECKING: @@ -249,8 +251,9 @@ def __init__( self, *, space: SearchSpace, - loss_metric: Metric, + loss_metric: Metric | Sequence[Metric], cost_metric: Metric | None = None, + time_profile: str | None = None, optimizer: BaseOptimizer, working_dir: Path, seed: Seed | None = None, @@ -262,6 +265,8 @@ def __init__( space: The space to use. loss_metric: The metric to optimize. cost_metric: The cost metric to use. Only certain NePs optimizers support + time_profile: The profile from which to get the timing of + the trial from. optimizer: The optimizer to use. seed: The seed to use for the trials (and not optimizers). working_dir: The directory to use for the trials. @@ -289,6 +294,7 @@ def __init__( self.working_dir = working_dir self.loss_metric = loss_metric self.cost_metric = cost_metric + self.time_profile = time_profile self.optimizer_state_file = self.working_dir / "optimizer_state.yaml" self.base_result_directory = self.working_dir / "results" @@ -297,8 +303,9 @@ def __init__( self.working_dir.mkdir(parents=True, exist_ok=True) self.base_result_directory.mkdir(parents=True, exist_ok=True) + @override @classmethod - def create( # noqa: PLR0913 + def create( cls, *, space: ( @@ -307,8 +314,9 @@ def create( # noqa: PLR0913 | Mapping[str, ConfigurationSpace | Parameter] | Node ), - metrics: Metric, + metrics: Metric | Sequence[Metric], cost_metric: Metric | None = None, + time_profile: str | None = None, bucket: PathBucket | str | Path | None = None, searcher: str | BaseOptimizer = "default", working_dir: str | Path = "neps", @@ -330,6 +338,7 @@ def create( # noqa: PLR0913 cost_metric: The cost metric to use. Only certain NePs optimizers support this. + time_profile: What profiler to take the time end from. seed: The seed to use for the trials. !!! warning @@ -387,17 +396,33 @@ def create( # noqa: PLR0913 seed=seed, loss_metric=metrics, cost_metric=cost_metric, + time_profile=time_profile, optimizer=searcher, working_dir=working_dir, ) + @overload + def ask(self, n: int) -> Iterable[Trial[NEPSTrialInfo]]: + ... + + @overload + def ask(self, n: None = None) -> Trial[NEPSTrialInfo]: + ... + @override - def ask(self) -> Trial[NEPSTrialInfo]: + def ask( + self, + n: int | None = None, + ) -> Trial[NEPSTrialInfo] | Iterable[Trial[NEPSTrialInfo]]: """Ask the optimizer for a new config. Returns: The trial info for the new config. 
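`ask()` now takes an optional `n` across all optimizers: `None` keeps the old single-trial behaviour, while an integer yields an iterable of trials (NePS currently just loops one trial at a time, as the TODO in its `ask()` notes). A hedged sketch of a driver loop consuming the batched form; `optimizer` and `target_function` are placeholders standing in for any concrete optimizer and evaluation function:

```python
# `optimizer` is any amltk Optimizer built via its `create()` classmethod and
# `target_function(trial) -> Trial.Report` is your own evaluation; neither is
# part of this diff.
for trial in optimizer.ask(n=8):       # Iterable[Trial] when n is given
    report = target_function(trial)
    optimizer.tell(report)

single_trial = optimizer.ask()         # a plain Trial when n is None
```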
""" + # TODO: Ask neps people if there's a good way to batch sample rather than 1 by 1 + if n is not None: + return (self.ask(n=None) for _ in range(n)) + with self.optimizer.using_state(self.optimizer_state_file, self.serializer): ( config_id, @@ -433,13 +458,13 @@ def ask(self) -> Trial[NEPSTrialInfo]: case cost_metric: metrics = [self.loss_metric, cost_metric] - trial = Trial( + trial = Trial.create( name=info.name, config=info.config, info=info, seed=self.seed, - bucket=self.bucket, - metrics=metrics, + bucket=self.bucket / info.name, + metrics={m.name: m for m in metrics}, ) logger.debug(f"Asked for trial {trial.name}") return trial @@ -456,30 +481,29 @@ def tell(self, report: Trial.Report[NEPSTrialInfo]) -> None: assert info is not None # Get a metric result - metric_result = first_true( - report.metric_values, - pred=lambda value: value.metric.name == self.loss_metric.name, - default=self.loss_metric.worst, - ) - - # Convert metric result to a minimization loss - neps_loss: float - if (_loss := metric_result.distance_to_optimal) is not None: - neps_loss = _loss + loss = report.values.get(self.loss_metric.name, self.loss_metric.worst) + normalized_loss = self.loss_metric.normalized_loss(loss) + + result: dict[str, Any] = {"loss": normalized_loss} + + metadata: dict[str, Any] + if ( + self.time_profile is not None + and (profile := report.profiles.get(self.time_profile)) is not None + ): + if profile.time.unit is not Timer.Unit.SECONDS: + raise NotImplementedError( + "NePs only supports seconds as the time unit", + ) + metadata = {"time_end": profile.time.end} else: - neps_loss = metric_result.loss - - result: dict[str, Any] = {"loss": neps_loss} - metadata: dict[str, Any] = {"time_end": report.time.end} + # TODO: I'm not sure "time_end" is requried but probably for some optimizers + metadata = {} if self.cost_metric is not None: - cost_metric: Metric = self.cost_metric - _cost = first_true( - report.metric_values, - pred=lambda value: value.metric.name == cost_metric.name, - default=self.cost_metric.worst, - ) - cost = _cost.value + cost = report.values.get(self.cost_metric.name, self.cost_metric.worst) + + # We don't normalize here result["cost"] = cost # If it's a budget aware optimizer diff --git a/src/amltk/optimization/optimizers/optuna.py b/src/amltk/optimization/optimizers/optuna.py index 011317a8..8128f7f7 100644 --- a/src/amltk/optimization/optimizers/optuna.py +++ b/src/amltk/optimization/optimizers/optuna.py @@ -52,12 +52,15 @@ def target_function(trial: Trial, pipeline: Pipeline) -> Trial.Report: X_train, X_test, y_train, y_test = train_test_split(X, y) clf = pipeline.configure(trial.config).build("sklearn") - with trial.begin(): - clf.fit(X_train, y_train) - y_pred = clf.predict(X_test) - return trial.success(accuracy=accuracy_score(y_test, y_pred)) + with trial.profile("trial"): + try: + clf.fit(X_train, y_train) + y_pred = clf.predict(X_test) + accuracy = accuracy_score(y_test, y_pred) + return trial.success(accuracy=accuracy) + except Exception as e: + return trial.fail(e) - return trial.fail() from amltk._doc import make_picklable; make_picklable(target_function) # markdown-exec: hide pipeline = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) @@ -99,16 +102,16 @@ def add_to_history(_, report: Trial.Report): Sorry! 
""" # noqa: E501 + from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, overload from typing_extensions import Self, override import optuna -from more_itertools import first_true from optuna.samplers import BaseSampler, NSGAIISampler, TPESampler from optuna.study import Study, StudyDirection from optuna.trial import ( @@ -189,9 +192,9 @@ def __init__( super().__init__(bucket=bucket, metrics=metrics) self.seed = amltk.randomness.as_int(seed) self.study = study - self.metrics = metrics self.space = space + @override @classmethod def create( cls, @@ -242,19 +245,6 @@ def create( case bucket: bucket = bucket # noqa: PLW0127 - match metrics: - case Metric(minimize=minimize): - direction = ( - StudyDirection.MINIMIZE if minimize else StudyDirection.MAXIMIZE - ) - study = optuna.create_study(direction=direction, **kwargs) - case metrics: - directions = [ - StudyDirection.MINIMIZE if m.minimize else StudyDirection.MAXIMIZE - for m in metrics - ] - study = optuna.create_study(directions=directions, **kwargs) - if sampler is None: sampler_seed = amltk.randomness.as_int(seed) match metrics: @@ -263,27 +253,56 @@ def create( case metrics: sampler = NSGAIISampler(seed=sampler_seed) # from `create_study()` - return cls(study=study, metrics=metrics, space=space, bucket=bucket, seed=seed) + match metrics: + case Metric(minimize=minimize): + direction = [ + StudyDirection.MINIMIZE if minimize else StudyDirection.MAXIMIZE, + ] + case metrics: + direction = [ + StudyDirection.MINIMIZE if m.minimize else StudyDirection.MAXIMIZE + for m in metrics + ] + + return cls( + study=optuna.create_study(directions=direction, sampler=sampler, **kwargs), + metrics=metrics, + space=space, + bucket=bucket, + seed=seed, + ) + + @overload + def ask(self, n: int) -> Iterable[Trial[OptunaTrial]]: + ... + + @overload + def ask(self, n: None = None) -> Trial[OptunaTrial]: + ... @override - def ask(self) -> Trial[OptunaTrial]: + def ask( + self, + n: int | None = None, + ) -> Trial[OptunaTrial] | Iterable[Trial[OptunaTrial]]: """Ask the optimizer for a new config. Returns: The trial info for the new config. 
""" - optuna_trial: optuna.Trial = self.study.ask(self.space) + if n is not None: + return (self.ask(n=None) for _ in range(n)) + optuna_trial = self.space.get_trial(self.study) config = optuna_trial.params trial_number = optuna_trial.number unique_name = f"{trial_number=}" - metrics = [self.metrics] if isinstance(self.metrics, Metric) else self.metrics - return Trial( + return Trial.create( name=unique_name, seed=self.seed, config=config, info=optuna_trial, - bucket=self.bucket, - metrics=metrics, + bucket=self.bucket / unique_name, + metrics=self.metrics, ) @override @@ -301,31 +320,15 @@ def tell(self, report: Trial.Report[OptunaTrial]) -> None: # NOTE: Can't tell any values if the trial crashed or failed self.study.tell(trial=trial, state=TrialState.FAIL) case Trial.Status.SUCCESS: - match self.metrics: - case [metric]: - metric_value: Metric.Value = first_true( - report.metric_values, - pred=lambda m: m.metric == metric, - default=metric.worst, - ) - self.study.tell( - trial=trial, - state=TrialState.COMPLETE, - values=metric_value.value, - ) - case metrics: - # NOTE: We need to make sure that there sorted in the order - # that Optuna expects, with any missing metrics filled in - _lookup = {v.metric.name: v for v in report.metric_values} - values = [ - _lookup.get(metric.name, metric.worst).value - for metric in metrics - ] - self.study.tell( - trial=trial, - state=TrialState.COMPLETE, - values=values, - ) + values = { + name: report.values.get(name, metric.worst) + for name, metric in self.metrics.items() + } + v: list[float] = list(values.values()) + if len(v) == 1: + self.study.tell(trial=trial, state=TrialState.COMPLETE, values=v[0]) + else: + self.study.tell(trial=trial, state=TrialState.COMPLETE, values=v) @override @classmethod diff --git a/src/amltk/optimization/optimizers/random_search.py b/src/amltk/optimization/optimizers/random_search.py new file mode 100644 index 00000000..8ef3dd60 --- /dev/null +++ b/src/amltk/optimization/optimizers/random_search.py @@ -0,0 +1,150 @@ +"""An optimizer that uses ConfigSpace for random search.""" +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from datetime import datetime +from pathlib import Path +from typing import TYPE_CHECKING, Literal, overload +from typing_extensions import override + +from amltk.optimization import Metric, Optimizer, Trial +from amltk.pipeline import Node +from amltk.randomness import as_int, randuid +from amltk.store import PathBucket + +if TYPE_CHECKING: + from typing_extensions import Self + + from ConfigSpace import ConfigurationSpace + + from amltk.types import Seed + + +class RandomSearch(Optimizer[None]): + """An optimizer that uses ConfigSpace for random search.""" + + def __init__( + self, + *, + space: ConfigurationSpace, + bucket: PathBucket | None = None, + metrics: Metric | Sequence[Metric], + seed: Seed | None = None, + ) -> None: + """Initialize the optimizer. + + Args: + space: The search space to search over. + bucket: The bucket given to trials generated by this optimizer. + metrics: The metrics to optimize. Unused for RandomSearch. + seed: The seed to use for the optimization. 
+ """ + metrics = metrics if isinstance(metrics, Sequence) else [metrics] + super().__init__(metrics=metrics, bucket=bucket) + seed = as_int(seed) + space.seed(seed) + self._counter = 0 + self.seed = seed + self.space = space + + @override + @classmethod + def create( + cls, + *, + space: ConfigurationSpace | Node, + metrics: Metric | Sequence[Metric], + bucket: PathBucket | str | Path | None = None, + seed: Seed | None = None, + ) -> Self: + """Create a random search optimizer. + + Args: + space: The node to optimize + metrics: The metrics to optimize + bucket: The bucket to store the results in + seed: The seed to use for the optimization + """ + seed = as_int(seed) + match bucket: + case None: + bucket = PathBucket( + f"{cls.__name__}-{datetime.now().isoformat()}", + ) + case str() | Path(): + bucket = PathBucket(bucket) + case bucket: + bucket = bucket # noqa: PLW0127 + + if isinstance(space, Node): + space = space.search_space(parser=cls.preferred_parser()) + + return cls( + space=space, + seed=seed, + bucket=bucket, + metrics=metrics, + ) + + @overload + def ask(self, n: int) -> Iterable[Trial[None]]: + ... + + @overload + def ask(self, n: None = None) -> Trial[None]: + ... + + @override + def ask( + self, + n: int | None = None, + ) -> Trial[None] | Iterable[Trial[None]]: + """Ask the optimizer for a new config. + + Args: + n: The number of configs to ask for. If `None`, ask for a single config. + + + Returns: + The trial info for the new config. + """ + if n is None: + configs = [self.space.sample_configuration()] + else: + configs = self.space.sample_configuration(n) + + trials: list[Trial[None]] = [] + for config in configs: + self._counter += 1 + randuid_seed = self.seed + self._counter + unique_name = f"trial-{randuid(4, seed=randuid_seed)}-{self._counter}" + trial: Trial[None] = Trial.create( + name=unique_name, + config=dict(config), + info=None, + seed=self.seed, + bucket=self.bucket / unique_name, + metrics=self.metrics, + ) + trials.append(trial) + + if n is None: + return trials[0] + + return trials + + @override + def tell(self, report: Trial.Report[None]) -> None: + """Tell the optimizer about the result of a trial. + + Does nothing for random search. + + Args: + report: The report of the trial. 
+ """ + + @override + @classmethod + def preferred_parser(cls) -> Literal["configspace"]: + """The preferred parser for this optimizer.""" + return "configspace" diff --git a/src/amltk/optimization/optimizers/smac.py b/src/amltk/optimization/optimizers/smac.py index b2ac2b8c..c02f7045 100644 --- a/src/amltk/optimization/optimizers/smac.py +++ b/src/amltk/optimization/optimizers/smac.py @@ -46,11 +46,14 @@ def target_function(trial: Trial, pipeline: Node) -> Trial.Report: X_train, X_test, y_train, y_test = train_test_split(X, y) clf = pipeline.configure(trial.config).build("sklearn") - with trial.begin(): - clf.fit(X_train, y_train) - y_pred = clf.predict(X_test) - accuracy = accuracy_score(y_test, y_pred) - return trial.success(accuracy=accuracy) + with trial.profile("trial"): + try: + clf.fit(X_train, y_train) + y_pred = clf.predict(X_test) + accuracy = accuracy_score(y_test, y_pred) + return trial.success(accuracy=accuracy) + except Exception as e: + return trial.fail(e) return trial.fail() from amltk._doc import make_picklable; make_picklable(target_function) # markdown-exec: hide @@ -91,14 +94,12 @@ def add_to_history(_, report: Trial.Report): from __future__ import annotations import logging -from collections.abc import Mapping, Sequence +from collections.abc import Iterable, Mapping, Sequence from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, overload from typing_extensions import override -import numpy as np -from more_itertools import first_true from pynisher import MemoryLimitException, TimeoutException from smac import HyperparameterOptimizationFacade, MultiFidelityFacade, Scenario from smac.runhistory import ( @@ -109,6 +110,7 @@ def add_to_history(_, report: Trial.Report): from amltk.optimization import Metric, Optimizer, Trial from amltk.pipeline import Node +from amltk.profiling import Profile from amltk.randomness import as_int from amltk.store import PathBucket @@ -140,6 +142,7 @@ def __init__( bucket: PathBucket | None = None, metrics: Metric | Sequence[Metric], fidelities: Mapping[str, FidT] | None = None, + time_profile: str | None = None, ) -> None: """Initialize the optimizer. @@ -148,17 +151,21 @@ def __init__( bucket: The bucket given to trials generated by this optimizer. metrics: The metrics to optimize. fidelities: The fidelities to use, if any. + time_profile: The profile to use to get time information to the + optimizer. Must use `trial.profile(time_profile)` in your + target function then. """ # We need to very that the scenario is correct incase user pass in # their own facade construction - assert self.crash_cost(metrics) == facade.scenario.crash_cost + assert list(self.crash_costs(metrics).values()) == facade.scenario.crash_cost metrics = metrics if isinstance(metrics, Sequence) else [metrics] super().__init__(metrics=metrics, bucket=bucket) self.facade = facade - self.metrics = metrics self.fidelities = fidelities + self.time_profile = time_profile + @override @classmethod def create( cls, @@ -166,6 +173,7 @@ def create( space: ConfigurationSpace | Node, metrics: Metric | Sequence[Metric], bucket: PathBucket | str | Path | None = None, + time_profile: str | None = None, deterministic: bool = True, seed: Seed | None = None, fidelities: Mapping[str, FidT] | None = None, @@ -179,6 +187,9 @@ def create( space: The config space to optimize. metrics: The metrics to optimize. bucket: The bucket given to trials generated by this optimizer. 
+ time_profile: The profile to use to get time information to the + optimizer. Must use `trial.profile(time_profile)` in your + target function then. deterministic: Whether the function your optimizing is deterministic, given a seed and config. seed: The seed to use for the optimizer. @@ -224,7 +235,7 @@ def create( seed=seed, min_budget=min_budget, max_budget=max_budget, - crash_cost=cls.crash_cost(metrics), + crash_cost=list(cls.crash_costs(metrics).values()), ) facade_cls = MultiFidelityFacade else: @@ -234,7 +245,7 @@ def create( output_directory=bucket.path / "smac3_output", deterministic=deterministic, objectives=metric_names, - crash_cost=cls.crash_cost(metrics), + crash_cost=list(cls.crash_costs(metrics).values()), ) facade_cls = HyperparameterOptimizationFacade @@ -247,15 +258,39 @@ def create( scenario=scenario, ), ) - return cls(facade=facade, fidelities=fidelities, bucket=bucket, metrics=metrics) + return cls( + facade=facade, + fidelities=fidelities, + bucket=bucket, + metrics=metrics, + time_profile=time_profile, + ) + + @overload + def ask(self, n: int) -> Iterable[Trial[SMACTrialInfo]]: + ... + + @overload + def ask(self, n: None = None) -> Trial[SMACTrialInfo]: + ... @override - def ask(self) -> Trial[SMACTrialInfo]: + def ask( + self, + n: int | None = None, + ) -> Trial[SMACTrialInfo] | Iterable[Trial[SMACTrialInfo]]: """Ask the optimizer for a new config. + Args: + n: The number of configs to ask for. If `None`, ask for a single config. + + Returns: The trial info for the new config. """ + if n is not None: + return (self.ask(n=None) for _ in range(n)) + smac_trial_info = self.facade.ask() config = smac_trial_info.config budget = smac_trial_info.budget @@ -273,13 +308,13 @@ def ask(self) -> Trial[SMACTrialInfo]: config_id = self.facade.runhistory.config_ids[config] unique_name = f"{config_id=}_{seed=}_{budget=}_{instance=}" - trial: Trial[SMACTrialInfo] = Trial( + trial: Trial[SMACTrialInfo] = Trial.create( name=unique_name, config=dict(config), info=smac_trial_info, seed=seed, fidelities=trial_fids, - bucket=self.bucket, + bucket=self.bucket / unique_name, metrics=self.metrics, ) logger.debug(f"Asked for trial {trial.name}") @@ -294,50 +329,60 @@ def tell(self, report: Trial.Report[SMACTrialInfo]) -> None: """ assert report.trial.info is not None - cost: float | list[float] - match self.metrics: - case [metric]: # Single obj - val: Metric.Value = first_true( - report.metric_values, - pred=lambda m: m.metric == metric, - default=metric.worst, - ) - cost = self.cost(val) - case metrics: - # NOTE: We need to make sure that there sorted in the order - # that SMAC expects, with any missing metrics filled in - _lookup = {v.metric.name: v for v in report.metric_values} - cost = [ - self.cost(_lookup.get(metric.name, metric.worst)) - for metric in metrics - ] - - logger.debug(f"Telling report for trial {report.trial.name}") + costs: dict[str, float] = {} + for name, metric in self.metrics.items(): + value = report.values.get(metric.name) + if value is None: + if report.status == Trial.Status.SUCCESS: + raise ValueError( + f"Could not find metric '{metric.name}' in report values." + " Make sure you use `trial.success()` in your target function." 
+ " So that we can report the metric value to SMAC.", + ) + value = metric.worst + + costs[name] = metric.normalized_loss(value) + + logger.debug(f"Reporting for trial {report.trial.name} with costs: {costs}") + + cost = next(iter(costs.values())) if len(costs) == 1 else list(costs.values()) # If we're successful, get the cost and times and report them params: dict[str, Any] match report.status: - case Trial.Status.SUCCESS: - params = { - "time": report.time.duration, - "starttime": report.time.start, - "endtime": report.time.end, - "cost": cost, - "status": StatusType.SUCCESS, - } - case Trial.Status.FAIL: - params = { - "time": report.time.duration, - "starttime": report.time.start, - "endtime": report.time.end, - "cost": cost, - "status": StatusType.CRASHED, - } + case Trial.Status.SUCCESS | Trial.Status.FAIL: + smac_status = ( + StatusType.SUCCESS + if report.status == Trial.Status.SUCCESS + else StatusType.CRASHED + ) + params = {"cost": cost, "status": smac_status} case Trial.Status.CRASHED | Trial.Status.UNKNOWN: - params = { - "cost": cost, - "status": StatusType.CRASHED, - } + params = {"cost": cost, "status": StatusType.CRASHED} + + if self.time_profile: + profile = report.trial.profiles.get(self.time_profile) + match profile: + # If it was a success, we kind of expect there to have been this + # timing. Otherwise, for failure we don't necessarily expect it. + case None if report.status in Trial.Status.SUCCESS: + raise ValueError( + f"Could not find profile '{self.time_profile}' in trial" + " as specified by `time_profile` during construction." + " Make sure you use `with trial.profile(time_profile):`" + " in your target function. So that we can report the" + " timing information to SMAC.", + ) + case Profile.Interval(time=timer): + params.update( + { + "time": timer.duration, + "starttime": timer.start, + "endtime": timer.end, + }, + ) + case None: + pass match report.exception: case None: @@ -368,32 +413,18 @@ def preferred_parser(cls) -> Literal["configspace"]: """The preferred parser for this optimizer.""" return "configspace" - @overload - @classmethod - def crash_cost(cls, metric: Metric) -> float: - ... - - @overload @classmethod - def crash_cost(cls, metric: Sequence[Metric]) -> list[float]: - ... - - @classmethod - def crash_cost(cls, metric: Metric | Sequence[Metric]) -> float | list[float]: + def crash_costs(cls, metric: Metric | Iterable[Metric]) -> dict[str, float]: """Get the crash cost for a metric for SMAC.""" match metric: - case Metric(bounds=(lower, upper)): # Bounded metrics - return abs(upper - lower) - case Metric(): # Unbounded metric - return np.inf - case metrics: - return [cls.crash_cost(m) for m in metrics] - - @classmethod - def cost(cls, value: Metric.Value) -> float: - """Get the cost for a metric value for SMAC.""" - match value.distance_to_optimal: - case None: # If we can't compute the distance, use the loss - return value.loss - case distance: # If we can compute the distance, use that - return distance + case Metric(): + return {metric.name: metric.normalized_loss(metric.worst)} + case Iterable(): + return { + metric.name: metric.normalized_loss(metric.worst) + for metric in metric + } + case _: + raise TypeError( + f"Expected a Metric, Mapping, or Iterable of Metrics. 
Got {metric}", + ) diff --git a/src/amltk/optimization/trial.py b/src/amltk/optimization/trial.py index 806831ac..4ceeeed5 100644 --- a/src/amltk/optimization/trial.py +++ b/src/amltk/optimization/trial.py @@ -1,35 +1,16 @@ -"""A [`Trial`][amltk.optimization.Trial] is -typically the output of -[`Optimizer.ask()`][amltk.optimization.Optimizer.ask], indicating -what the optimizer would like to evaluate next. We provide a host -of convenience methods attached to the `Trial` to make it easy to -save results, store artifacts, and more. - -Paired with the `Trial` is the [`Trial.Report`][amltk.optimization.Trial.Report], -class, providing an easy way to report back to the optimizer's -[`tell()`][amltk.optimization.Optimizer.tell] with -a simple [`trial.success(cost=...)`][amltk.optimization.Trial.success] or -[`trial.fail(cost=...)`][amltk.optimization.Trial.fail] call.. - -### Trial - -::: amltk.optimization.trial.Trial - options: - members: False - -### Report - -::: amltk.optimization.trial.Trial.Report - options: - members: False - -""" +"""The Trial and Report class.""" from __future__ import annotations import copy import logging -import traceback -from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +import traceback as traceback_module +from collections.abc import ( + Hashable, + Iterable, + Iterator, + Mapping, + MutableMapping, +) from contextlib import contextmanager from dataclasses import dataclass, field from datetime import datetime @@ -43,9 +24,10 @@ from amltk._functional import dict_get_not_none, mapping_select, prefix_keys from amltk._richutil.renderable import RichRenderable -from amltk.optimization.metric import Metric +from amltk._util import parse_timestamp_object +from amltk.optimization.metric import Metric, MetricCollection from amltk.profiling import Memory, Profile, Profiler, Timer -from amltk.store import Bucket, PathBucket +from amltk.store import PathBucket if TYPE_CHECKING: from rich.console import RenderableType @@ -71,159 +53,7 @@ @dataclass(kw_only=True) class Trial(RichRenderable, Generic[I]): - """A [`Trial`][amltk.optimization.Trial] encapsulates some configuration - that needs to be evaluated. Typically, this is what is generated by an - [`Optimizer.ask()`][amltk.optimization.Optimizer.ask] call. - - ??? tip "Usage" - - To begin a trial, you can use the - [`trial.begin()`][amltk.optimization.Trial.begin], which will catch - exceptions/traceback and profile the block of code. - - If all went smooth, your trial was successful and you can use - [`trial.success()`][amltk.optimization.Trial.success] to generate - a success [`Report`][amltk.optimization.Trial.Report], typically - passing what your chosen optimizer expects, e.g., `"loss"` or `"cost"`. - - If your trial failed, you can instead use the - [`trial.fail()`][amltk.optimization.Trial.fail] to generate a - failure [`Report`][amltk.optimization.Trial.Report], where - any caught exception will be attached to it. Each - [`Optimizer`][amltk.optimization.Optimizer] will take care of what to do - from here. - - ```python exec="true" source="material-block" html="true" - from amltk.optimization import Trial, Metric - from amltk.store import PathBucket - - cost = Metric("cost", minimize=True) - - def target_function(trial: Trial) -> Trial.Report: - x = trial.config["x"] - y = trial.config["y"] - - with trial.begin(): - cost = x**2 - y - - if trial.exception: - return trial.fail() - - return trial.success(cost=cost) - - # ... 
usually obtained from an optimizer - trial = Trial(name="some-unique-name", config={"x": 1, "y": 2}, metrics=[cost]) - - report = target_function(trial) - print(report.df()) - ``` - - - What you can return with [`trial.success()`][amltk.optimization.Trial.success] - or [`trial.fail()`][amltk.optimization.Trial.fail] depends on the - [`metrics`][amltk.optimization.Trial.metrics] of the trial. Typically, - an optimizer will provide the trial with the list of metrics. - - ??? tip "Metrics" - - ::: amltk.optimization.metric.Metric - options: - members: False - - Some important properties are that they have a unique - [`.name`][amltk.optimization.Trial.name] given the optimization run, - a candidate [`.config`][amltk.optimization.Trial.config] to evaluate, - a possible [`.seed`][amltk.optimization.Trial.seed] to use, - and an [`.info`][amltk.optimization.Trial.info] object, which is the optimizer - specific information, if required by you. - - !!! tip "Reporting success (or failure)" - - When using the [`success()`][amltk.optimization.trial.Trial.success] - or [`fail()`][amltk.optimization.trial.Trial.success] method, make sure to - provide values for all metrics specified in the - [`.metrics`][amltk.optimization.Trial.metrics] attribute. Usually these are - set by the optimizer generating the `Trial`. - - Each metric has a unique name, and it's crucial to use the correct names when - reporting success, otherwise an error will occur. - - ??? example "Reporting success for metrics" - - For example: - - ```python exec="true" result="python" source="material-block" - from amltk.optimization import Trial, Metric - - # Gotten from some optimizer usually, i.e. via `optimizer.ask()` - trial = Trial( - name="example_trial", - config={"param": 42}, - metrics=[Metric(name="accuracy", minimize=False)] - ) - - # Incorrect usage (will raise an error) - try: - report = trial.success(invalid_metric=0.95) - except ValueError as error: - print(error) - - # Correct usage - report = trial.success(accuracy=0.95) - ``` - - If using [`Plugins`][amltk.scheduling.plugins.Plugin], they may insert - some extra objects in the [`.extra`][amltk.optimization.Trial.extras] dict. - - To profile your trial, you can wrap the logic you'd like to check with - [`trial.begin()`][amltk.optimization.Trial.begin], which will automatically - catch any errors, record the traceback, and profile the block of code, in - terms of time and memory. - - You can access the profiled time and memory using the - [`.time`][amltk.optimization.Trial.time] and - [`.memory`][amltk.optimization.Trial.memory] attributes. - If you've [`profile()`][amltk.optimization.Trial.profile]'ed any other intervals, - you can access them by name through - [`trial.profiles`][amltk.optimization.Trial.profiles]. - Please see the [`Profiler`][amltk.profiling.profiler.Profiler] - for more. - - ??? example "Profiling with a trial." - - ```python exec="true" source="material-block" result="python" title="profile" - from amltk.optimization import Trial - - trial = Trial(name="some-unique-name", config={}) - - # ... somewhere where you've begun your trial. - with trial.profile("some_interval"): - for work in range(100): - pass - - print(trial.profiler.df()) - ``` - - You can also record anything you'd like into the - [`.summary`][amltk.optimization.Trial.summary], a plain `#!python dict` - or use [`trial.store()`][amltk.optimization.Trial.store] to store artifacts - related to the trial. - - ??? tip "What to put in `.summary`?" - - For large items, e.g. 
predictions or models, these are highly advised to - [`.store()`][amltk.optimization.Trial.store] to disk, especially if using - a `Task` for multiprocessing. - - Further, if serializing the report using the - [`report.df()`][amltk.optimization.Trial.Report.df], - returning a single row, - or a [`History`][amltk.optimization.History] - with [`history.df()`][amltk.optimization.History.df] for a dataframe consisting - of many of the reports, then you'd likely only want to store things - that are scalar and can be serialised to disk by a pandas DataFrame. - - """ + """The trial class.""" name: str """The unique name of the trial.""" @@ -231,115 +61,148 @@ def target_function(trial: Trial) -> Trial.Report: config: Mapping[str, Any] """The config of the trial provided by the optimizer.""" - bucket: PathBucket = field( - default_factory=lambda: PathBucket("unknown-trial-bucket"), - ) + bucket: PathBucket """The bucket to store trial related output to.""" - info: I | None = field(default=None, repr=False) + info: I | None = field(repr=False) """The info of the trial provided by the optimizer.""" - metrics: Sequence[Metric] = field(default_factory=list) - """The metrics associated with the trial.""" + metrics: MetricCollection + """The metrics associated with the trial. + + You can access the metrics by name, e.g. `#!python trial.metrics["loss"]`. + """ + + created_at: datetime + """When the trial was created.""" seed: int | None = None """The seed to use if suggested by the optimizer.""" - fidelities: dict[str, Any] | None = None + fidelities: Mapping[str, Any] """The fidelities at which to evaluate the trial, if any.""" - time: Timer.Interval = field(repr=False, default_factory=Timer.na) - """The time taken by the trial, once ended.""" - - memory: Memory.Interval = field(repr=False, default_factory=Memory.na) - """The memory used by the trial, once ended.""" - - profiler: Profiler = field( - repr=False, - default_factory=lambda: Profiler(memory_unit="B", time_kind="wall"), - ) + profiler: Profiler = field(repr=False) """A profiler for this trial.""" - summary: dict[str, Any] = field(default_factory=dict) + summary: MutableMapping[str, Any] """The summary of the trial. These are for summary statistics of a trial and are single values.""" - exception: BaseException | None = field(repr=True, default=None) - """The exception raised by the trial, if any.""" - - traceback: str | None = field(repr=False, default=None) - """The traceback of the exception, if any.""" - - storage: set[Any] = field(default_factory=set) + storage: set[Any] """Anything stored in the trial, the elements of the list are keys that can be used to retrieve them later, such as a Path. """ - extras: dict[str, Any] = field(default_factory=dict) + extras: MutableMapping[str, Any] """Any extras attached to the trial.""" - @property - def profiles(self) -> Mapping[str, Profile.Interval]: - """The profiles of the trial.""" - return self.profiler.profiles - - @contextmanager - def begin( - self, - time: Timer.Kind | Literal["wall", "cpu", "process"] | None = None, - memory_unit: Memory.Unit | Literal["B", "KB", "MB", "GB"] | None = None, - ) -> Iterator[None]: - """Begin the trial with a `contextmanager`. 
+ @classmethod + def create( # noqa: PLR0913 + cls, + name: str, + config: Mapping[str, Any] | None = None, + *, + metrics: Metric | Iterable[Metric] | Mapping[str, Metric] | None = None, + info: I | None = None, + seed: int | None = None, + fidelities: Mapping[str, Any] | None = None, + created_at: datetime | None = None, + profiler: Profiler | None = None, + bucket: str | Path | PathBucket | None = None, + summary: MutableMapping[str, Any] | None = None, + storage: set[Hashable] | None = None, + extras: MutableMapping[str, Any] | None = None, + ) -> Trial[I]: + """Create a trial. - Will begin timing the trial in the `with` block, attaching the profiled time and memory - to the trial once completed, under `.profile.time` and `.profile.memory` attributes. + Args: + name: The name of the trial. + metrics: The metrics of the trial. + config: The config of the trial. + info: The info of the trial. + seed: The seed of the trial. + fidelities: The fidelities of the trial. + bucket: The bucket of the trial. + created_at: When the trial was created. + profiler: The profiler of the trial. + summary: The summary of the trial. + storage: The storage of the trial. + extras: The extras of the trial. - If an exception is raised, it will be attached to the trial under `.exception` - with the traceback attached to the actual error message, such that it can - be pickled and sent back to the main process loop. + Returns: + The trial. + """ + return Trial( + name=name, + metrics=( + MetricCollection.from_collection(metrics) + if metrics is not None + else MetricCollection() + ), + profiler=( + profiler + if profiler is not None + else Profiler(memory_unit="B", time_kind="wall") + ), + config=config if config is not None else {}, + info=info, + seed=seed, + created_at=created_at if created_at is not None else datetime.now(), + fidelities=fidelities if fidelities is not None else {}, + bucket=( + bucket + if isinstance(bucket, PathBucket) + else ( + PathBucket(bucket) + if bucket is not None + else PathBucket(f"trial-{name}-{datetime.now().isoformat()}") + ) + ), + summary=summary if summary is not None else {}, + storage=storage if storage is not None else set(), + extras=extras if extras is not None else {}, + ) - ```python exec="true" source="material-block" result="python" title="begin" hl_lines="5" - from amltk.optimization import Trial + @property + def profiles(self) -> Mapping[str, Profile.Interval]: + """The profiles of the trial. - trial = Trial(name="trial", config={"x": 1}) + These are indexed by the name of the profile indicated by: - with trial.begin(): - # Do some work - pass + ```python + with trial.profile("key_to_index"): + # ... - print(trial.memory) - print(trial.time) + profile = trial.profiles["key_to_index"] ``` - ```python exec="true" source="material-block" result="python" title="begin-fail" hl_lines="5" - from amltk.optimization import Trial - - trial = Trial(name="trial", config={"x": -1}) - - with trial.begin(): - raise ValueError("x must be positive") + The values are a + [`Profile.Interval`][amltk.profiling.profiler.Profile.Interval], + which contain a + [`Memory.Interval`][amltk.profiling.memory.Memory.Interval] + and a + [`Timer.Interval`][amltk.profiling.timing.Timer.Interval]. + Please see the respective documentation for more. 
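`Trial.create(...)` replaces direct construction and fills in defaults for the bucket, profiler, metrics and `created_at`, while error handling moves from `with trial.begin():` to an ordinary `try`/`except` around whatever you choose to `trial.profile(...)`. A hedged sketch combining the pieces above; the metric and the squared-`x` workload are illustrative only:

```python
from amltk.optimization import Metric, Trial

cost = Metric("cost", minimize=True)
trial = Trial.create(name="demo-trial", config={"x": 3}, metrics=[cost])

try:
    with trial.profile("fit"):             # interval lands in trial.profiles["fit"]
        value = trial.config["x"] ** 2
    report = trial.success(cost=value)
except Exception as e:
    trial.dump_exception(e)                # writes "exception.txt" to the trial bucket
    report = trial.fail(e)

print(report.values)
trial.bucket.rmdir()                       # clean up the auto-created bucket
```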
+ """ + return self.profiler.profiles - print(trial.exception) - print(trial.traceback) - print(trial.memory) - print(trial.time) - ``` + def dump_exception( + self, + exception: BaseException, + *, + name: str | None = None, + ) -> None: + """Dump an exception to the trial. Args: - time: The timer kind to use for the trial. Defaults to the default - timer kind of the profiler. - memory_unit: The memory unit to use for the trial. Defaults to the - default memory unit of the profiler. - """ # noqa: E501 - with self.profiler(name="trial", memory_unit=memory_unit, time_kind=time): - try: - yield - except Exception as error: # noqa: BLE001 - self.exception = error - self.traceback = traceback.format_exc() - finally: - self.time = self.profiler["trial"].time - self.memory = self.profiler["trial"].memory + exception: The exception to dump. + name: The name of the file to dump to. If `None`, will be `"exception"`. + """ + fname = name if name is not None else "exception" + traceback = "".join(traceback_module.format_tb(exception.__traceback__)) + msg = f"{traceback}\n{exception.__class__.__name__}: {exception}" + self.store({f"{fname}.txt": msg}) @contextmanager def profile( @@ -359,13 +222,14 @@ def profile( from amltk.optimization import Trial import time - trial = Trial(name="trial", config={"x": 1}) + trial = Trial.create(name="trial", config={"x": 1}) with trial.profile("some_interval"): # Do some work time.sleep(1) print(trial.profiler["some_interval"].time) + trial.bucket.rmdir() # markdown-exec: hide ``` Args: @@ -394,13 +258,11 @@ def success(self, **metrics: float | int) -> Trial.Report[I]: loss_metric = Metric("loss", minimize=True) - trial = Trial(name="trial", config={"x": 1}, metrics=[loss_metric]) - - with trial.begin(): - # Do some work - report = trial.success(loss=1) + trial = Trial.create(name="trial", config={"x": 1}, metrics=[loss_metric]) + report = trial.success(loss=1) print(report) + trial.bucket.rmdir() # markdown-exec: hide ``` Args: @@ -410,35 +272,37 @@ def success(self, **metrics: float | int) -> Trial.Report[I]: Returns: The report of the trial. """ # noqa: E501 - _recorded_values: list[Metric.Value] = [] - for _metric in self.metrics: - if (raw_value := metrics.get(_metric.name)) is not None: - _recorded_values.append(_metric.as_value(raw_value)) + values: dict[str, float] = {} + + for metric_def in self.metrics.values(): + if (reported_value := metrics.get(metric_def.name)) is not None: + values[metric_def.name] = reported_value else: raise ValueError( - f"Cannot report success without {self.metrics=}." - f" Please provide a value for the metric '{_metric.name}'." - f"\nPlease provide '{_metric.name}' as `trial.success(" - f"{_metric.name}=value)` or rename your metric to" - f'`Metric(name="{{provided_key}}", minimize={_metric.minimize}, ' - f"bounds={_metric.bounds})`", + f" Please provide a value for the metric '{metric_def.name}' as " + " this is one of the metrics of the trial. " + f"\n Try `trial.success({metric_def.name}=value, ...)`.", ) # Need to check if anything extra was reported! - extra = set(metrics.keys()) - {metric.name for metric in self.metrics} + extra = set(metrics.keys()) - self.metrics.keys() if extra: raise ValueError( - f"Cannot report success with extra metrics: {extra=}." - f"\nOnly {self.metrics=} are allowed.", + f"Cannot report `success()` with extra metrics: {extra=}." + f"\nOnly metrics {list(self.metrics)} as these are the metrics" + " provided for this trial." 
+ "\nTo record other numerics, use `trial.summary` instead.", ) - return Trial.Report( - trial=self, - status=Trial.Status.SUCCESS, - metric_values=tuple(_recorded_values), - ) + return Trial.Report(trial=self, status=Trial.Status.SUCCESS, values=values) - def fail(self, **metrics: float | int) -> Trial.Report[I]: + def fail( + self, + exception: Exception | None = None, + traceback: str | None = None, + /, + **metrics: float | int, + ) -> Trial.Report[I]: """Generate a failure report. !!! note "Non specifed metrics" @@ -452,37 +316,45 @@ def fail(self, **metrics: float | int) -> Trial.Report[I]: from amltk.optimization import Trial, Metric loss = Metric("loss", minimize=True, bounds=(0, 1_000)) - trial = Trial(name="trial", config={"x": 1}, metrics=[loss]) + trial = Trial.create(name="trial", config={"x": 1}, metrics=[loss]) - with trial.begin(): + try: raise ValueError("This is an error") # Something went wrong + except Exception as error: + report = trial.fail(error) - if trial.exception: # You can check for an exception of the trial here - report = trial.fail() - - print(report.metrics) + print(report.values) print(report) + trial.bucket.rmdir() # markdown-exec: hide ``` Returns: The result of the trial. """ - _recorded_values: list[Metric.Value] = [] - for _metric in self.metrics: - if (raw_value := metrics.get(_metric.name)) is not None: - _recorded_values.append(_metric.as_value(raw_value)) - else: - _recorded_values.append(_metric.worst) + if exception is not None and traceback is None: + traceback = traceback_module.format_exc() + + # Need to check if anything extra was reported! + extra = set(metrics.keys()) - self.metrics.keys() + if extra: + raise ValueError( + f"Cannot report `fail()` with extra metrics: {extra=}." + f"\nOnly metrics {list(self.metrics)} as these are the metrics" + " provided for this trial." + "\nTo record other numerics, use `trial.summary` instead.", + ) return Trial.Report( trial=self, status=Trial.Status.FAIL, - metric_values=tuple(_recorded_values), + exception=exception, + traceback=traceback, + values=metrics, ) def crashed( self, - exception: BaseException | None = None, + exception: Exception, traceback: str | None = None, ) -> Trial.Report[I]: """Generate a crash report. @@ -509,51 +381,27 @@ def crashed( Returns: The report of the trial. """ - if exception is None and self.exception is None: - raise RuntimeError( - "Cannot generate a crash report without an exception." - " Please provide an exception or use `with trial.begin():` to start" - " the trial.", - ) - - self.exception = exception if exception else self.exception - self.traceback = traceback if traceback else self.traceback + if traceback is None: + traceback = "".join(traceback_module.format_tb(exception.__traceback__)) return Trial.Report( trial=self, status=Trial.Status.CRASHED, - metric_values=tuple(metric.worst for metric in self.metrics), + exception=exception, + traceback=traceback, ) - def store( - self, - items: Mapping[str, T], - *, - where: ( - str | Path | Bucket | Callable[[str, Mapping[str, T]], None] | None - ) = None, - ) -> None: + def store(self, items: Mapping[str, T]) -> None: """Store items related to the trial. 
```python exec="true" source="material-block" result="python" title="store" hl_lines="5" from amltk.optimization import Trial from amltk.store import PathBucket - trial = Trial(name="trial", config={"x": 1}, bucket=PathBucket("results")) + trial = Trial.create(name="trial", config={"x": 1}, bucket=PathBucket("my-trial")) trial.store({"config.json": trial.config}) - - print(trial.storage) - ``` - - You could also specify `where=` exactly to store the thing - - ```python exec="true" source="material-block" result="python" title="store-bucket" hl_lines="7" - from amltk.optimization import Trial - - trial = Trial(name="trial", config={"x": 1}) - trial.store({"config.json": trial.config}, where="./results") - print(trial.storage) + trial.bucket.rmdir() # markdown-exec: hide ``` Args: @@ -561,48 +409,12 @@ def store( to the item itself.If using a `str`, `Path` or `PathBucket`, the keys of the items should be a valid filename, including the correct extension. e.g. `#!python {"config.json": trial.config}` - - where: Where to store the items. - - * If `None`, will use the bucket attached to the `Trial` if any, - otherwise it will raise an error. - - * If a `str` or `Path`, will store - a bucket will be created at the path, and the items will be - stored in a sub-bucket with the name of the trial. - - * If a `Bucket`, will store the items **in a sub-bucket** with the - name of the trial. - - * If a `Callable`, will call the callable with the name of the - trial and the key-valued pair of items to store. """ # noqa: E501 - method: Bucket - match where: - case None: - method = self.bucket - method.sub(self.name).store(items) - case str() | Path(): - method = PathBucket(where, create=True) - method.sub(self.name).store(items) - case Bucket(): - method = where - method.sub(self.name).store(items) - case _: - # Leave it up to supplied method - where(self.name, items) - + self.bucket.store(items) # Add the keys to storage - self.storage.update(items.keys()) + self.storage.update(items) - def delete_from_storage( - self, - items: Iterable[str], - *, - where: ( - str | Path | Bucket | Callable[[str, Iterable[str]], dict[str, bool]] | None - ) = None, - ) -> dict[str, bool]: + def delete_from_storage(self, items: Iterable[str]) -> dict[str, bool]: """Delete items related to the trial. ```python exec="true" source="material-block" result="python" title="delete-storage" hl_lines="6" @@ -610,64 +422,25 @@ def delete_from_storage( from amltk.store import PathBucket bucket = PathBucket("results") - trial = Trial(name="trial", config={"x": 1}, info={}, bucket=bucket) - - trial.store({"config.json": trial.config}) - trial.delete_from_storage(items=["config.json"]) - - print(trial.storage) - ``` - - You could also create a Bucket and use that instead. - - ```python exec="true" source="material-block" result="python" title="delete-storage-bucket" hl_lines="9" - from amltk.optimization import Trial - from amltk.store import PathBucket - - bucket = PathBucket("results") - trial = Trial(name="trial", config={"x": 1}, bucket=bucket) + trial = Trial.create(name="trial", config={"x": 1}, info={}, bucket=bucket) trial.store({"config.json": trial.config}) trial.delete_from_storage(items=["config.json"]) print(trial.storage) + trial.bucket.rmdir() # markdown-exec: hide ``` Args: items: The items to delete, an iterable of keys - where: Where the items are stored - - * If `None`, will use the bucket attached to the `Trial` if any, - otherwise it will raise an error. 
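With the `where=` parameter removed, `store()` and `delete_from_storage()` both go through the trial's own bucket and keep `trial.storage` in sync. A rough usage sketch, assuming a throwaway `PathBucket` directory:

```python
from amltk.optimization import Trial
from amltk.store import PathBucket

trial = Trial.create(
    name="demo-trial",
    config={"x": 1},
    bucket=PathBucket("demo-results"),
)

# Items are keyed directly into `trial.bucket` now.
trial.store({"config.json": trial.config})
print(trial.storage)                           # {'config.json'}

# Deletion also updates `trial.storage` and reports what was removed.
removed = trial.delete_from_storage(["config.json"])
print(removed)                                 # e.g. {'config.json': True}
print(trial.storage)                           # set()

trial.bucket.rmdir()
```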
- - * If a `str` or `Path`, will lookup a bucket at the path, - and the items will be deleted from a sub-bucket with the name of the trial. - - * If a `Bucket`, will delete the items in a sub-bucket with the - name of the trial. - - * If a `Callable`, will call the callable with the name of the - trial and the keys of the items to delete. Should a mapping from - the key to whether it was deleted or not. Returns: A dict from the key to whether it was deleted or not. """ # noqa: E501 # If not a Callable, we convert to a path bucket - method: Bucket - match where: - case None: - method = self.bucket - case str() | Path(): - method = PathBucket(where, create=False) - case Bucket(): - method = where - case _: - # Leave it up to supplied method - return where(self.name, items) - - sub_bucket = method.sub(self.name) - return sub_bucket.remove(items) + removed = self.bucket.remove(items) + self.storage.difference_update(items) + return removed def copy(self) -> Self: """Create a copy of the trial. @@ -678,38 +451,16 @@ def copy(self) -> Self: return copy.deepcopy(self) @overload - def retrieve( - self, - key: str, - *, - where: str | Path | Bucket[str, Any] | None = ..., - check: None = None, - ) -> Any: + def retrieve(self, key: str, *, check: None = None) -> Any: ... @overload - def retrieve( - self, - key: str, - *, - where: str | Path | Bucket[str, Any] | None = ..., - check: type[R], - ) -> R: + def retrieve(self, key: str, *, check: type[R]) -> R: ... - def retrieve( - self, - key: str, - *, - where: str | Path | Bucket[str, Any] | None = None, - check: type[R] | None = None, - ) -> R | Any: + def retrieve(self, key: str, *, check: type[R] | None = None) -> R | Any: """Retrieve items related to the trial. - !!! note "Same argument for `where=`" - - Use the same argument for `where=` as you did for `store()`. - ```python exec="true" source="material-block" result="python" title="retrieve" hl_lines="7" from amltk.optimization import Trial from amltk.store import PathBucket @@ -717,49 +468,19 @@ def retrieve( bucket = PathBucket("results") # Create a trial, normally done by an optimizer - trial = Trial(name="trial", config={"x": 1}, bucket=bucket) + trial = Trial.create(name="trial", config={"x": 1}, bucket=bucket) trial.store({"config.json": trial.config}) config = trial.retrieve("config.json") print(config) - ``` - - You could also manually specify where something get's stored and retrieved - - ```python exec="true" source="material-block" result="python" title="retrieve-bucket" hl_lines="11" - - from amltk.optimization import Trial - from amltk.store import PathBucket - - path = "./config_path" - - trial = Trial(name="trial", config={"x": 1}) - - trial.store({"config.json": trial.config}, where=path) - - config = trial.retrieve("config.json", where=path) - print(config) - import shutil; shutil.rmtree(path) # markdown-exec: hide + trial.bucket.rmdir() # markdown-exec: hide ``` Args: key: The key of the item to retrieve as said in `.storage`. check: If provided, will check that the retrieved item is of the - provided type. If not, will raise a `TypeError`. This - is only used if `where=` is a `str`, `Path` or `Bucket`. - - where: Where to retrieve the items from. - - * If `None`, will use the bucket attached to the `Trial` if any, - otherwise it will raise an error. - - * If a `str` or `Path`, will store - a bucket will be created at the path, and the items will be - retrieved from a sub-bucket with the name of the trial. 
- - * If a `Bucket`, will retrieve the items from a sub-bucket with the - name of the trial. + provided type. If not, will raise a `TypeError`. Returns: The retrieved item. @@ -768,20 +489,7 @@ def retrieve( TypeError: If `check=` is provided and the retrieved item is not of the provided type. """ # noqa: E501 - # If not a Callable, we convert to a path bucket - method: Bucket[str, Any] - match where: - case None: - method = self.bucket - case str(): - method = PathBucket(where, create=True) - case Path(): - method = PathBucket(where, create=True) - case Bucket(): - method = where - - # Store in a sub-bucket - return method.sub(self.name)[key].load(check=check) + return self.bucket[key].load(check=check) def attach_extra(self, name: str, plugin_item: Any) -> None: """Attach a plugin item to the trial. @@ -792,12 +500,11 @@ def attach_extra(self, name: str, plugin_item: Any) -> None: """ self.extras[name] = plugin_item - def rich_renderables(self) -> Iterable[RenderableType]: # noqa: C901 + def rich_renderables(self) -> Iterable[RenderableType]: """The renderables for rich for this report.""" from rich.panel import Panel from rich.pretty import Pretty from rich.table import Table - from rich.text import Text items: list[RenderableType] = [] table = Table.grid(padding=(0, 1), expand=False) @@ -829,12 +536,6 @@ def rich_renderables(self) -> Iterable[RenderableType]: # noqa: C901 if any(self.storage): table.add_row("storage", Pretty(self.storage)) - if self.exception: - table.add_row("exception", Text(str(self.exception), style="bold red")) - - if self.traceback: - table.add_row("traceback", Text(self.traceback, style="bold red")) - for name, profile in self.profiles.items(): table.add_row("profile:" + name, Pretty(profile)) @@ -894,44 +595,7 @@ def __rich__(self) -> Text: @dataclass class Report(RichRenderable, Generic[I2]): - """The [`Trial.Report`][amltk.optimization.Trial.Report] encapsulates - a [`Trial`][amltk.optimization.Trial], its status and any metrics/exceptions - that may have occured. - - Typically you will not create these yourself, but instead use - [`trial.success()`][amltk.optimization.Trial.success] or - [`trial.fail()`][amltk.optimization.Trial.fail] to generate them. - - ```python exec="true" source="material-block" result="python" - from amltk.optimization import Trial, Metric - - loss = Metric("loss", minimize=True) - - trial = Trial(name="trial", config={"x": 1}, metrics=[loss]) - - with trial.begin(): - # Do some work - # ... - report: Trial.Report = trial.success(loss=1) - - print(report.df()) - ``` - - These reports are used to report back metrics to an - [`Optimizer`][amltk.optimization.Optimizer] - with [`Optimizer.tell()`][amltk.optimization.Optimizer.tell] but can also be - stored for your own uses. - - You can access the original trial with the - [`.trial`][amltk.optimization.Trial.Report.trial] attribute, and the - [`Status`][amltk.optimization.Trial.Status] of the trial with the - [`.status`][amltk.optimization.Trial.Report.status] attribute. - - You may also want to check out the [`History`][amltk.optimization.History] class - for storing a collection of `Report`s, allowing for an easier time to convert - them to a dataframe or perform some common Hyperparameter optimization parsing - of metrics. 
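The `check=` parameter is kept through this simplification; a small sketch of type-checked retrieval, assuming the stored JSON loads back as a plain `dict`:

```python
from amltk.optimization import Trial
from amltk.store import PathBucket

trial = Trial.create(
    name="demo-trial",
    config={"x": 1},
    bucket=PathBucket("demo-results"),
)
trial.store({"config.json": trial.config})

# `check=` verifies the type of the loaded object, raising TypeError otherwise.
config = trial.retrieve("config.json", check=dict)
print(config)

trial.bucket.rmdir()
```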
- """ + """The report generated from a `Trial`.""" trial: Trial[I2] """The trial that was run.""" @@ -939,32 +603,23 @@ class Report(RichRenderable, Generic[I2]): status: Trial.Status """The status of the trial.""" - metrics: dict[str, float] = field(init=False) - """The metric values of the trial.""" + reported_at: datetime = field(default_factory=datetime.now) + """When this Report was generated. - metric_values: tuple[Metric.Value, ...] = field(default_factory=tuple) - """The metrics of the trial, linked to the metrics.""" - - metric_defs: dict[str, Metric] = field(init=False) - """A lookup to the metric definitions""" - - metric_names: tuple[str, ...] = field(init=False) - """The names of the metrics.""" + This will primarily be `None` if there was no corresponding key + when loading this report from a serialized form, such as + with [`from_df()`][amltk.optimization.Trial.Report.from_df] + or [`from_dict()`][amltk.optimization.Trial.Report.from_dict]. + """ - def __post_init__(self) -> None: - self.metrics = {value.name: value.value for value in self.metric_values} - self.metric_names = tuple(metric.name for metric in self.metric_values) - self.metric_defs = {v.metric.name: v.metric for v in self.metric_values} + exception: BaseException | None = None + """The exception reported if any.""" - @property - def exception(self) -> BaseException | None: - """The exception of the trial, if any.""" - return self.trial.exception + traceback: str | None = field(repr=False, default=None) + """The traceback reported if any.""" - @property - def traceback(self) -> str | None: - """The traceback of the trial, if any.""" - return self.trial.traceback + values: Mapping[str, float] = field(default_factory=dict) + """The reported metric values of the trial.""" @property def name(self) -> str: @@ -976,13 +631,18 @@ def config(self) -> Mapping[str, Any]: """The config of the trial.""" return self.trial.config + @property + def metrics(self) -> MetricCollection: + """The metrics of the trial.""" + return self.trial.metrics + @property def profiles(self) -> Mapping[str, Profile.Interval]: """The profiles of the trial.""" return self.trial.profiles @property - def summary(self) -> dict[str, Any]: + def summary(self) -> MutableMapping[str, Any]: """The summary of the trial.""" return self.trial.summary @@ -991,16 +651,6 @@ def storage(self) -> set[str]: """The storage of the trial.""" return self.trial.storage - @property - def time(self) -> Timer.Interval: - """The time of the trial.""" - return self.trial.time - - @property - def memory(self) -> Memory.Interval: - """The memory of the trial.""" - return self.trial.memory - @property def bucket(self) -> PathBucket: """The bucket attached to the trial.""" @@ -1043,74 +693,34 @@ def df( "exception": str(self.exception) if self.exception else "NA", "traceback": str(self.traceback) if self.traceback else "NA", "bucket": str(self.bucket.path), + "created_at": self.trial.created_at, + "reported_at": self.reported_at, } if metrics: - for value in self.metric_values: - items[f"metric:{value.metric}"] = value.value + for metric_name, value in self.values.items(): + metric_def = self.metrics[metric_name] + items[f"metric:{metric_def}"] = value if summary: items.update(**prefix_keys(self.trial.summary, "summary:")) if configs: items.update(**prefix_keys(self.trial.config, "config:")) if profiles: for name, profile in sorted(self.profiles.items(), key=lambda x: x[0]): - # We log this one seperatly - if name == "trial": - items.update(profile.to_dict()) - else: - 
items.update(profile.to_dict(prefix=f"profile:{name}")) + items.update(profile.to_dict(prefix=f"profile:{name}")) return pd.DataFrame(items, index=[0]).convert_dtypes().set_index("name") @overload - def retrieve( - self, - key: str, - *, - where: str | Path | Bucket[str, Any] | None = ..., - check: None = None, - ) -> Any: + def retrieve(self, key: str, *, check: None = None) -> Any: ... @overload - def retrieve( - self, - key: str, - *, - where: str | Path | Bucket[str, Any] | None = ..., - check: type[R], - ) -> R: + def retrieve(self, key: str, *, check: type[R]) -> R: ... - def retrieve( - self, - key: str, - *, - where: str | Path | Bucket[str, Any] | None = None, - check: type[R] | None = None, - ) -> R | Any: + def retrieve(self, key: str, *, check: type[R] | None = None) -> R | Any: """Retrieve items related to the trial. - !!! note "Same argument for `where=`" - - Use the same argument for `where=` as you did for `store()`. - - ```python exec="true" source="material-block" result="python" title="retrieve" hl_lines="7" - from amltk.optimization import Trial - from amltk.store import PathBucket - - bucket = PathBucket("results") - trial = Trial(name="trial", config={"x": 1}, bucket=bucket) - - trial.store({"config.json": trial.config}) - with trial.begin(): - report = trial.success() - - config = report.retrieve("config.json") - print(config) - ``` - - You could also create a Bucket and use that instead. - ```python exec="true" source="material-block" result="python" title="retrieve-bucket" hl_lines="11" from amltk.optimization import Trial @@ -1118,33 +728,20 @@ def retrieve( bucket = PathBucket("results") - trial = Trial(name="trial", config={"x": 1}, bucket=bucket) + trial = Trial.create(name="trial", config={"x": 1}, bucket=bucket) trial.store({"config.json": trial.config}) - - with trial.begin(): - report = trial.success() + report = trial.success() config = report.retrieve("config.json") print(config) + trial.bucket.rmdir() # markdown-exec: hide ``` Args: key: The key of the item to retrieve as said in `.storage`. check: If provided, will check that the retrieved item is of the - provided type. If not, will raise a `TypeError`. This - is only used if `where=` is a `str`, `Path` or `Bucket`. - where: Where to retrieve the items from. - - * If `None`, will use the bucket attached to the `Trial` if any, - otherwise it will raise an error. - - * If a `str` or `Path`, will store - a bucket will be created at the path, and the items will be - retrieved from a sub-bucket with the name of the trial. - - * If a `Bucket`, will retrieve the items from a sub-bucket with the - name of the trial. + provided type. If not, will raise a `TypeError`. Returns: The retrieved item. @@ -1153,21 +750,15 @@ def retrieve( TypeError: If `check=` is provided and the retrieved item is not of the provided type. """ # noqa: E501 - return self.trial.retrieve(key, where=where, check=check) + return self.trial.retrieve(key, check=check) - def store( - self, - items: Mapping[str, T], - *, - where: ( - str | Path | Bucket | Callable[[str, Mapping[str, T]], None] | None - ) = None, - ) -> None: + def store(self, items: Mapping[str, T]) -> None: """Store items related to the trial. 
- See: [`Trial.store()`][amltk.optimization.trial.Trial.store] + See Also: + * [`Trial.store()`][amltk.optimization.trial.Trial.store] """ - self.trial.store(items, where=where) + self.trial.store(items) @classmethod def from_df(cls, df: pd.DataFrame | pd.Series) -> Trial.Report: @@ -1221,22 +812,9 @@ def from_dict(cls, d: Mapping[str, Any]) -> Trial.Report: # on serialization to keep the order, which is not ideal either. # May revisit this if we need to raw_metrics: dict[str, float] = mapping_select(d, "metric:") - _intermediate = { + metrics: dict[Metric, float | None] = { Metric.from_str(name): value for name, value in raw_metrics.items() } - metrics: dict[Metric, Metric.Value] = { - metric: metric.as_value(value) - for metric, value in _intermediate.items() - } - - _trial_profile_items = { - k: v for k, v in d.items() if k.startswith(("memory:", "time:")) - } - if any(_trial_profile_items): - trial_profile = Profile.from_dict(_trial_profile_items) - profiles["trial"] = trial_profile - else: - trial_profile = Profile.na() exception = d.get("exception") traceback = d.get("traceback") @@ -1253,37 +831,56 @@ def from_dict(cls, d: Mapping[str, Any]) -> Trial.Report: else: bucket = PathBucket(f"uknown_trial_bucket-{datetime.now().isoformat()}") - trial: Trial[None] = Trial( + created_at_timestamp = d.get("created_at") + if created_at_timestamp is None: + raise ValueError( + "Cannot load report from dict without a 'created_at' field.", + ) + created_at = parse_timestamp_object(created_at_timestamp) + + trial: Trial = Trial.create( name=d["name"], config=mapping_select(d, "config:"), info=None, # We don't save this to disk so we load it back as None bucket=bucket, seed=trial_seed, fidelities=mapping_select(d, "fidelities:"), - time=trial_profile.time, - memory=trial_profile.memory, profiler=Profiler(profiles=profiles), - metrics=list(metrics.keys()), + metrics=metrics.keys(), + created_at=created_at, summary=mapping_select(d, "summary:"), - exception=exception, - traceback=traceback, + storage=set(mapping_select(d, "storage:").values()), + extras=mapping_select(d, "extras:"), ) + _values: dict[str, float] = { + m.name: v + for m, v in metrics.items() + if (v is not None and not pd.isna(v)) + } + status = Trial.Status(dict_get_not_none(d, "status", "unknown")) - _values: dict[str, float] = {m.name: r.value for m, r in metrics.items()} - if status == Trial.Status.SUCCESS: - return trial.success(**_values) - - if status == Trial.Status.FAIL: - return trial.fail(**_values) - - if status == Trial.Status.CRASHED: - return trial.crashed( - exception=Exception("Unknown status.") - if trial.exception is None - else None, + match status: + case Trial.Status.SUCCESS: + report = trial.success(**_values) + case Trial.Status.FAIL: + exc = Exception(exception) if exception else None + tb = str(traceback) if traceback else None + report = trial.fail(exc, tb, **_values) + case Trial.Status.CRASHED: + exc = Exception(exception) if exception else Exception("Unknown") + tb = str(traceback) if traceback else None + report = trial.crashed(exc, tb) + case Trial.Status.UNKNOWN | _: + report = trial.crashed(exception=Exception("Unknown status.")) + + timestamp = d.get("reported_at") + if timestamp is None: + raise ValueError( + "Cannot load report from dict without a 'reported_at' field.", ) + report.reported_at = parse_timestamp_object(timestamp) - return trial.crashed(exception=Exception("Unknown status.")) + return report def rich_renderables(self) -> Iterable[RenderableType]: """The renderables for rich for this 
report.""" diff --git a/src/amltk/pipeline/components.py b/src/amltk/pipeline/components.py index 45555acf..0dfce584 100644 --- a/src/amltk/pipeline/components.py +++ b/src/amltk/pipeline/components.py @@ -1,82 +1,19 @@ -"""You can use the various different node types to build a pipeline. - -You can connect these nodes together using either the constructors explicitly, -as shown in the examples. We also provide some index operators: - -* `>>` - Connect nodes together to form a [`Sequential`][amltk.pipeline.components.Sequential] -* `&` - Connect nodes together to form a [`Join`][amltk.pipeline.components.Join] -* `|` - Connect nodes together to form a [`Choice`][amltk.pipeline.components.Choice] - -There is also another short-hand that you may find useful to know: - -* `{comp1, comp2, comp3}` - This will automatically be converted into a - [`Choice`][amltk.pipeline.Choice] between the given components. -* `(comp1, comp2, comp3)` - This will automatically be converted into a - [`Join`][amltk.pipeline.Join] between the given components. -* `[comp1, comp2, comp3]` - This will automatically be converted into a - [`Sequential`][amltk.pipeline.Sequential] between the given components. - -For each of these components we will show examples using -the [`#! "sklearn"` builder][amltk.pipeline.builders.sklearn.build] - -The components are: - -### Component - -::: amltk.pipeline.components.Component - options: - members: false - -### Sequential - -::: amltk.pipeline.components.Sequential - options: - members: false - -### Choice - -::: amltk.pipeline.components.Choice - options: - members: false - -### Split - -::: amltk.pipeline.components.Split - options: - members: false - -### Join - -::: amltk.pipeline.components.Join - options: - members: false - -### Fixed - -::: amltk.pipeline.components.Fixed - options: - members: false - -### Searchable - -::: amltk.pipeline.components.Searchable - options: - members: false -""" # noqa: E501 +"""The provided subclasses of a [`Node`][amltk.pipeline.node.Node] +that can be used can be assembled into a pipeline. +""" from __future__ import annotations import inspect from collections.abc import Callable, Iterator, Mapping, Sequence -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload from typing_extensions import Self, override -from more_itertools import all_unique, first_true +from more_itertools import first_true from amltk._functional import entity_name, mapping_select from amltk.exceptions import ( ComponentBuildError, - DuplicateNamesError, NoChoiceMadeError, NodeNotFoundError, ) @@ -159,120 +96,84 @@ def as_node( # noqa: PLR0911 @dataclass(init=False, frozen=True, eq=True) -class Join(Node[Item, Space]): - """[`Join`][amltk.pipeline.Join] together different parts of the pipeline. - - This indicates the different children in - [`.nodes`][amltk.pipeline.Node.nodes] should act in tandem with one - another, for example, concatenating the outputs of the various members of the - `Join`. 
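As the long module docstring (with its operator and shorthand examples) is collapsed into the one-liner above, a compact sketch of how the node subclasses defined in the rest of this diff combine may help; the scaler and estimator choices are illustrative only.

```python
from amltk.pipeline import Component, Fixed, Sequential
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler

# A non-configurable step (Fixed) followed by a searchable one (Component),
# chained in order by a Sequential.
pipeline = Sequential(
    Fixed(StandardScaler()),
    Component(RandomForestClassifier, space={"n_estimators": (10, 100)}),
    name="demo_pipeline",
)

space = pipeline.search_space("configspace")
print(space)
```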
- - ```python exec="true" source="material-block" html="true" - from amltk.pipeline import Join, Component - from sklearn.decomposition import PCA - from sklearn.feature_selection import SelectKBest - - pca = Component(PCA, space={"n_components": (1, 3)}) - kbest = Component(SelectKBest, space={"k": (1, 3)}) - - join = Join(pca, kbest, name="my_feature_union") - from amltk._doc import doc_print; doc_print(print, join) # markdown-exec: hide +class Component(Node[Item, Space]): + """A [`Component`][amltk.pipeline.Component] of the pipeline with + a possible item and **no children**. - space = join.search_space("configspace") - from amltk._doc import doc_print; doc_print(print, space) # markdown-exec: hide + This is the basic building block of most pipelines, it accepts + as it's [`item=`][amltk.pipeline.node.Node.item] some function that will be + called with [`build_item()`][amltk.pipeline.components.Component.build_item] to + build that one part of the pipeline. - pipeline = join.build("sklearn") - from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide - ``` + When [`build_item()`][amltk.pipeline.Component.build_item] is called, whatever + the config of the component is at that time, will be used to construct the item. - You may also just join together nodes using an infix operator `&` if you prefer: + A common pattern is to use a [`Component`][amltk.pipeline.Component] to + wrap a constructor, specifying the [`space=`][amltk.pipeline.node.Node.space] + and [`config=`][amltk.pipeline.node.Node.config] to be used when building the + item. ```python exec="true" source="material-block" html="true" - from amltk.pipeline import Join, Component - from sklearn.decomposition import PCA - from sklearn.feature_selection import SelectKBest - - pca = Component(PCA, space={"n_components": (1, 3)}) - kbest = Component(SelectKBest, space={"k": (1, 3)}) - - # Can not parametrize or name the join - join = pca & kbest - from amltk._doc import doc_print; doc_print(print, join) # markdown-exec: hide + from amltk.pipeline import Component + from sklearn.ensemble import RandomForestClassifier - # With a parametrized join - join = ( - Join(name="my_feature_union") & pca & kbest + rf = Component( + RandomForestClassifier, + config={"max_depth": 3}, + space={"n_estimators": (10, 100)} ) - item = join.build("sklearn") - print(item._repr_html_()) # markdown-exec: hide - ``` - - Whenever some other node sees a tuple, i.e. `(comp1, comp2, comp3)`, this - will automatically be converted into a `Join`. 
- - ```python exec="true" source="material-block" html="true" - from amltk.pipeline import Sequential, Component - from sklearn.decomposition import PCA - from sklearn.feature_selection import SelectKBest - from sklearn.ensemble import RandomForestClassifier + from amltk._doc import doc_print; doc_print(print, rf) # markdown-exec: hide - pca = Component(PCA, space={"n_components": (1, 3)}) - kbest = Component(SelectKBest, space={"k": (1, 3)}) + config = {"n_estimators": 50} # Sample from some space or something + configured_rf = rf.configure(config) - # Can not parametrize or name the join - join = Sequential( - (pca, kbest), - RandomForestClassifier(n_estimators=5), - name="my_feature_union", - ) - from amltk._doc import doc_print; doc_print(print, join) # markdown-exec: hide + estimator = configured_rf.build_item() + from amltk._doc import doc_print; doc_print(print, estimator) # markdown-exec: hide ``` - Like all [`Node`][amltk.pipeline.node.Node]s, a `Join` accepts an explicit - [`name=`][amltk.pipeline.node.Node.name], - [`item=`][amltk.pipeline.node.Node.item], - [`config=`][amltk.pipeline.node.Node.config], - [`space=`][amltk.pipeline.node.Node.space], - [`fidelities=`][amltk.pipeline.node.Node.fidelities], - [`config_transform=`][amltk.pipeline.node.Node.config_transform] and - [`meta=`][amltk.pipeline.node.Node.meta]. - See Also: * [`Node`][amltk.pipeline.node.Node] """ - nodes: tuple[Node, ...] - """The nodes that this node leads to.""" + item: Callable[..., Item] + """A node which constructs an item in the pipeline.""" - RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="#7E6B8F") + nodes: tuple[()] + """A component has no children.""" - _NODES_INIT: ClassVar = "args" + RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="#E6AF2E") + + _NODES_INIT: ClassVar = None def __init__( self, - *nodes: Node | NodeLike, + item: Callable[..., Item], + *, name: str | None = None, - item: Item | Callable[[Item], Item] | None = None, config: Config | None = None, space: Space | None = None, fidelities: Mapping[str, Any] | None = None, config_transform: Callable[[Config, Any], Config] | None = None, meta: Mapping[str, Any] | None = None, ): - """See [`Node`][amltk.pipeline.node.Node] for details.""" - _nodes = tuple(as_node(n) for n in nodes) - if not all_unique(_nodes, key=lambda node: node.name): - raise ValueError( - f"Can't handle nodes they do not all contain unique names, {nodes=}." - "\nAll nodes must have a unique name. Please provide a `name=` to them", - ) - - if name is None: - name = f"Join-{randuid(8)}" + """Initialize a component. + Args: + item: The item attached to this node. + name: The name of the node. If not specified, the name will be + generated from the item. + config: The configuration for this node. + space: The search space for this node. This will be used when + [`search_space()`][amltk.pipeline.node.Node.search_space] is called. + fidelities: The fidelities for this node. + config_transform: A function that transforms the `config=` parameter + during [`configure(config)`][amltk.pipeline.node.Node.configure] + before return the new configured node. Useful for times where + you need to combine multiple parameters into one. + meta: Any meta information about this node. 
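A short sketch of the `Component` behaviour described here: the node name falls back to the wrapped item's name, and keyword arguments to `build_item()` are layered over `config=` (read from the `build_item()` body that follows). The printed name is an assumption based on `entity_name()`.

```python
from amltk.pipeline import Component
from sklearn.ensemble import RandomForestClassifier

rf = Component(RandomForestClassifier, config={"max_depth": 3})

# With no `name=` given, the node name is derived from the wrapped item.
print(rf.name)   # presumably "RandomForestClassifier"

# Keyword arguments to build_item() are merged on top of the node's config.
estimator = rf.build_item(n_estimators=25)
print(estimator.max_depth, estimator.n_estimators)  # 3 25
```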
+ """ super().__init__( - *_nodes, - name=name, + name=name if name is not None else entity_name(item), item=item, config=config, space=space, @@ -281,146 +182,98 @@ def __init__( meta=meta, ) - @override - def __and__(self, other: Node | NodeLike) -> Join: - other_node = as_node(other) - if any(other_node.name == node.name for node in self.nodes): - raise ValueError( - f"Can't handle node with name '{other_node.name} as" - f" there is already a node called '{other_node.name}' in {self.name}", - ) - - nodes = (*tuple(as_node(n) for n in self.nodes), other_node) - return self.mutate(name=self.name, nodes=nodes) - - -@dataclass(init=False, frozen=True, eq=True) -class Choice(Node[Item, Space]): - """A [`Choice`][amltk.pipeline.Choice] between different subcomponents. - - This indicates that a choice should be made between the different children in - [`.nodes`][amltk.pipeline.Node.nodes], usually done when you - [`configure()`][amltk.pipeline.node.Node.configure] with some `config` from - a [`search_space()`][amltk.pipeline.node.Node.search_space]. - - ```python exec="true" source="material-block" html="true" - from amltk.pipeline import Choice, Component - from sklearn.ensemble import RandomForestClassifier - from sklearn.neural_network import MLPClassifier - - rf = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) - mlp = Component(MLPClassifier, space={"activation": ["logistic", "relu", "tanh"]}) - - estimator_choice = Choice(rf, mlp, name="estimator") - from amltk._doc import doc_print; doc_print(print, estimator_choice) # markdown-exec: hide - - space = estimator_choice.search_space("configspace") - from amltk._doc import doc_print; doc_print(print, space) # markdown-exec: hide - - config = space.sample_configuration() - from amltk._doc import doc_print; doc_print(print, config) # markdown-exec: hide - - configured_choice = estimator_choice.configure(config) - from amltk._doc import doc_print; doc_print(print, configured_choice) # markdown-exec: hide - - chosen_estimator = configured_choice.chosen() - from amltk._doc import doc_print; doc_print(print, chosen_estimator) # markdown-exec: hide - - estimator = chosen_estimator.build_item() - from amltk._doc import doc_print; doc_print(print, estimator) # markdown-exec: hide - ``` + def build_item(self, **kwargs: Any) -> Item: + """Build the item attached to this component. - You may also just add nodes to a `Choice` using an infix operator `|` if you prefer: + Args: + **kwargs: Any additional arguments to pass to the item - ```python exec="true" source="material-block" html="true" - from amltk.pipeline import Choice, Component - from sklearn.ensemble import RandomForestClassifier - from sklearn.neural_network import MLPClassifier + Returns: + Item + The built item + """ + config = self.config or {} + try: + return self.item(**{**config, **kwargs}) + except TypeError as e: + new_msg = f"Failed to build `{self.item=}` with `{self.config=}`.\n" + if any(kwargs): + new_msg += f"Extra {kwargs=} were also provided.\n" + new_msg += ( + "If the item failed to initialize, a common reason can be forgetting" + " to call `configure()` on the `Component` or the pipeline it is in or" + " not calling `build()`/`build_item()` on the **returned** value of" + " `configure()`.\n" + "Reasons may also include not having fully specified the `config`" + " initially, it having not being configured fully from `configure()`" + " or from misspecfying parameters in the `space`." 
+ ) + raise ComponentBuildError(new_msg) from e - rf = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) - mlp = Component(MLPClassifier, space={"activation": ["logistic", "relu", "tanh"]}) - estimator_choice = ( - Choice(name="estimator") | mlp | rf - ) - from amltk._doc import doc_print; doc_print(print, estimator_choice) # markdown-exec: hide - ``` +@dataclass(init=False, frozen=True, eq=True) +class Searchable(Node[None, Space]): # type: ignore + """A [`Searchable`][amltk.pipeline.Searchable] + node of the pipeline which just represents a search space, no item attached. - Whenever some other node sees a set, i.e. `{comp1, comp2, comp3}`, this - will automatically be converted into a `Choice`. + While not usually applicable to pipelines you want to build, this node + is useful for creating a search space, especially if the real pipeline you + want to optimize can not be built directly. For example, if you are optimize + a script, you may wish to use a `Searchable` to represent the search space + of that script. ```python exec="true" source="material-block" html="true" - from amltk.pipeline import Choice, Component, Sequential - from sklearn.ensemble import RandomForestClassifier - from sklearn.neural_network import MLPClassifier - from sklearn.impute import SimpleImputer - - rf = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) - mlp = Component(MLPClassifier, space={"activation": ["logistic", "relu", "tanh"]}) + from amltk.pipeline import Searchable - pipeline = Sequential( - SimpleImputer(fill_value=0), - {mlp, rf}, - name="my_pipeline", - ) - from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide + script_space = Searchable({"mode": ["orange", "blue", "red"], "n": (10, 100)}) + from amltk._doc import doc_print; doc_print(print, script_space) # markdown-exec: hide ``` - Like all [`Node`][amltk.pipeline.node.Node]s, a `Choice` accepts an explicit - [`name=`][amltk.pipeline.node.Node.name], - [`item=`][amltk.pipeline.node.Node.item], - [`config=`][amltk.pipeline.node.Node.config], - [`space=`][amltk.pipeline.node.Node.space], - [`fidelities=`][amltk.pipeline.node.Node.fidelities], - [`config_transform=`][amltk.pipeline.node.Node.config_transform] and - [`meta=`][amltk.pipeline.node.Node.meta]. - - !!! warning "Order of nodes" - - The given nodes of a choice are always ordered according - to their name, so indexing `choice.nodes` may not be reliable - if modifying the choice dynamically. - - Please use `choice["name"]` to access the nodes instead. - See Also: * [`Node`][amltk.pipeline.node.Node] """ # noqa: E501 - nodes: tuple[Node, ...] - """The nodes that this node leads to.""" + item: None = None + """A searchable has no item.""" - RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="#FF4500") - _NODES_INIT: ClassVar = "args" + nodes: tuple[()] = () + """A searchable has no children.""" + + RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="light_steel_blue") + + _NODES_INIT: ClassVar = None def __init__( self, - *nodes: Node | NodeLike, + space: Space | None = None, + *, name: str | None = None, - item: Item | Callable[[Item], Item] | None = None, config: Config | None = None, - space: Space | None = None, fidelities: Mapping[str, Any] | None = None, config_transform: Callable[[Config, Any], Config] | None = None, meta: Mapping[str, Any] | None = None, ): - """See [`Node`][amltk.pipeline.node.Node] for details.""" - _nodes: tuple[Node, ...] 
= tuple( - sorted((as_node(n) for n in nodes), key=lambda n: n.name), - ) - if not all_unique(_nodes, key=lambda node: node.name): - raise ValueError( - f"Can't handle nodes as we can not generate a __choice__ for {nodes=}." - "\nAll nodes must have a unique name. Please provide a `name=` to them", - ) + """Initialize a choice. + Args: + space: The search space for this node. This will be used when + [`search_space()`][amltk.pipeline.node.Node.search_space] is called. + name: The name of the node. If not specified, a random one will + be generated. + config: The configuration for this node. Useful for setting some + default values. + fidelities: The fidelities for this node. + config_transform: A function that transforms the `config=` parameter + during [`configure(config)`][amltk.pipeline.node.Node.configure] + before return the new configured node. Useful for times where + you need to combine multiple parameters into one. + meta: Any meta information about this node. + """ if name is None: - name = f"Choice-{randuid(8)}" + name = f"Searchable-{randuid(8)}" super().__init__( - *_nodes, name=name, - item=item, config=config, space=space, fidelities=fidelities, @@ -428,158 +281,87 @@ def __init__( meta=meta, ) - @override - def __or__(self, other: Node | NodeLike) -> Choice: - other_node = as_node(other) - if any(other_node.name == node.name for node in self.nodes): - raise ValueError( - f"Can't handle node with name '{other_node.name} as" - f" there is already a node called '{other_node.name}' in {self.name}", - ) - nodes = tuple( - sorted( - [as_node(n) for n in self.nodes] + [other_node], - key=lambda n: n.name, - ), - ) - return self.mutate(name=self.name, nodes=nodes) +@dataclass(init=False, frozen=True, eq=True) +class Fixed(Node[Item, None]): # type: ignore + """A [`Fixed`][amltk.pipeline.Fixed] part of the pipeline that + represents something that can not be configured and used directly as is. - def chosen(self) -> Node: - """The chosen branch. + It consists of an [`.item`][amltk.pipeline.node.Node.item] that is fixed, + non-configurable and non-searchable. It also has no children. - Returns: - The chosen branch - """ - match self.config: - case {"__choice__": choice}: - chosen = first_true( - self.nodes, - pred=lambda node: node.name == choice, - default=None, - ) - if chosen is None: - raise NodeNotFoundError(choice, self.name) + This is useful for representing parts of the pipeline that are fixed, for example + if you have a pipeline that is a `Sequential` of nodes, but you want to + fix the first component to be a `PCA` with `n_components=3`, you can use a `Fixed` + to represent that. - return chosen - case _: - raise NoChoiceMadeError(self.name) + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Component, Fixed, Sequential + from sklearn.ensemble import RandomForestClassifier + from sklearn.decomposition import PCA - @override - def configure( - self, - config: Config, - *, - prefixed_name: bool | None = None, - transform_context: Any | None = None, - params: Mapping[str, Any] | None = None, - ) -> Self: - """Configure this node and anything following it with the given config. + rf = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) + pca = Fixed(PCA(n_components=3)) - !!! 
note "Configuring a choice" + pipeline = Sequential(pca, rf, name="my_pipeline") + from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide + ``` - For a Choice, if the config has a `__choice__` key, then only the node - chosen will be configured. The others will not be configured at all and - their config will be discarded. + See Also: + * [`Node`][amltk.pipeline.node.Node] + """ - Args: - config: The configuration to apply - prefixed_name: Whether items in the config are prefixed by the names - of the nodes. - * If `None`, the default, then `prefixed_name` will be assumed to - be `True` if this node has a next node or if the config has - keys that begin with this nodes name. - * If `True`, then the config will be searched for items prefixed - by the name of the node (and subsequent chained nodes). - * If `False`, then the config will be searched for items without - the prefix, i.e. the config keys are exactly those matching - this nodes search space. - transform_context: Any context to give to `config_transform=` of individual - nodes. - params: The params to match any requests when configuring this node. - These will match against any ParamRequests in the config and will - be used to fill in any missing values. + item: Item + """The fixed item that this node represents.""" - Returns: - The configured node - """ - # Get the config for this node - match prefixed_name: - case True: - config = mapping_select(config, f"{self.name}:") - case False: - pass - case None if any(k.startswith(f"{self.name}:") for k in config): - config = mapping_select(config, f"{self.name}:") - case None: - pass + space: None = None + """A fixed node has no search space.""" - _kwargs: dict[str, Any] = {} + fidelities: None = None + """A fixed node has no search space.""" - # Configure all the branches if exists - # This part is what differs for a Choice - if len(self.nodes) > 0: - choice_made = config.get("__choice__", None) - if choice_made is not None: - matching_child = first_true( - self.nodes, - pred=lambda node: node.name == choice_made, - default=None, - ) - if matching_child is None: - raise ValueError( - f"Can not find matching child for choice {self.name} with child" - f" {choice_made}." - "\nPlease check the config and ensure that the choice is one of" - f" {[n.name for n in self.nodes]}." 
- f"\nThe config recieved at this choice node was {config=}.", - ) + config: None = None + """A fixed node has no config.""" - # We still iterate over all of them just to ensure correct ordering - nodes = tuple( - node.copy() - if node.name != choice_made - else matching_child.configure( - config, - prefixed_name=True, - transform_context=transform_context, - params=params, - ) - for node in self.nodes - ) - _kwargs["nodes"] = nodes - else: - nodes = tuple( - node.configure( - config, - prefixed_name=True, - transform_context=transform_context, - params=params, - ) - for node in self.nodes - ) - _kwargs["nodes"] = nodes + config_transform: None = None + """A fixed node has no config so no transform.""" - this_config = { - hp: v - for hp, v in config.items() - if ( - ":" not in hp - and not any(hp.startswith(f"{node.name}") for node in self.nodes) - ) - } - if self.config is not None: - this_config = {**self.config, **this_config} + nodes: tuple[()] = () + """A fixed node has no children.""" - this_config = dict(self._fufill_param_requests(this_config, params=params)) + RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="#56351E") - if self.config_transform is not None: - this_config = dict(self.config_transform(this_config, transform_context)) + _NODES_INIT: ClassVar = None - if len(this_config) > 0: - _kwargs["config"] = dict(this_config) + def __init__( # noqa: D417 + self, + item: Item, + *, + name: str | None = None, + config: None = None, + space: None = None, + fidelities: None = None, + config_transform: None = None, + meta: Mapping[str, Any] | None = None, + ): + """Initialize a fixed node. - return self.mutate(**_kwargs) + Args: + item: The item attached to this node. Will be fixed and can not + be configured. + name: The name of the node. If not specified, the name will be + generated from the item. + meta: Any meta information about this node. + """ + super().__init__( + name=name if name is not None else entity_name(item), + item=item, + config=config, + space=space, + fidelities=fidelities, + config_transform=config_transform, + meta=meta, + ) @dataclass(init=False, frozen=True, eq=True) @@ -601,68 +383,14 @@ class Sequential(Node[Item, Space]): name="my_pipeline" ) from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide - - space = pipeline.search_space("configspace") - from amltk._doc import doc_print; doc_print(print, space) # markdown-exec: hide - - configuration = space.sample_configuration() - from amltk._doc import doc_print; doc_print(print, configuration) # markdown-exec: hide - - configured_pipeline = pipeline.configure(configuration) - from amltk._doc import doc_print; doc_print(print, configured_pipeline) # markdown-exec: hide - - sklearn_pipeline = pipeline.build("sklearn") - print(sklearn_pipeline._repr_html_()) # markdown-exec: hide - ``` - - You may also just chain together nodes using an infix operator `>>` if you prefer: - - ```python exec="true" source="material-block" html="true" - from amltk.pipeline import Join, Component, Sequential - from sklearn.decomposition import PCA - from sklearn.ensemble import RandomForestClassifier - - pipeline = ( - Sequential(name="my_pipeline") - >> PCA(n_components=3) - >> Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) - ) - from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide - ``` - - Whenever some other node sees a list, i.e. `[comp1, comp2, comp3]`, this - will automatically be converted into a `Sequential`. 
- - ```python exec="true" source="material-block" html="true" - from amltk.pipeline import Choice - from sklearn.impute import SimpleImputer - from sklearn.preprocessing import StandardScaler - from sklearn.ensemble import RandomForestClassifier - from sklearn.neural_network import MLPClassifier - - pipeline_choice = Choice( - [SimpleImputer(), RandomForestClassifier()], - [StandardScaler(), MLPClassifier()], - name="pipeline_choice" - ) - from amltk._doc import doc_print; doc_print(print, pipeline_choice) # markdown-exec: hide ``` - Like all [`Node`][amltk.pipeline.node.Node]s, a `Sequential` accepts an explicit - [`name=`][amltk.pipeline.node.Node.name], - [`item=`][amltk.pipeline.node.Node.item], - [`config=`][amltk.pipeline.node.Node.config], - [`space=`][amltk.pipeline.node.Node.space], - [`fidelities=`][amltk.pipeline.node.Node.fidelities], - [`config_transform=`][amltk.pipeline.node.Node.config_transform] and - [`meta=`][amltk.pipeline.node.Node.meta]. - See Also: * [`Node`][amltk.pipeline.node.Node] - """ # noqa: E501 + """ nodes: tuple[Node, ...] - """The nodes in series.""" + """The nodes ordered in series.""" RICH_OPTIONS: ClassVar[RichOptions] = RichOptions( panel_color="#7E6B8F", @@ -681,12 +409,26 @@ def __init__( config_transform: Callable[[Config, Any], Config] | None = None, meta: Mapping[str, Any] | None = None, ): - """See [`Node`][amltk.pipeline.node.Node] for details.""" - _nodes = tuple(as_node(n) for n in nodes) + """Initialize a sequential node. - # Perhaps we need to do a deeper check on this... - if not all_unique(_nodes, key=lambda node: node.name): - raise DuplicateNamesError(self) + Args: + nodes: The nodes that this node leads to. In the case of a `Sequential`, + the order here matters and it signifies that data should first + be passed through the first node, then the second, etc. + item: The item attached to this node (if any). + name: The name of the node. If not specified, the name will be + randomly generated. + config: The configuration for this node. + space: The search space for this node. This will be used when + [`search_space()`][amltk.pipeline.node.Node.search_space] is called. + fidelities: The fidelities for this node. + config_transform: A function that transforms the `config=` parameter + during [`configure(config)`][amltk.pipeline.node.Node.configure] + before return the new configured node. Useful for times where + you need to combine multiple parameters into one. + meta: Any meta information about this node. + """ + _nodes = tuple(as_node(n) for n in nodes) if name is None: name = f"Seq-{randuid(8)}" @@ -749,88 +491,47 @@ def walk( @dataclass(init=False, frozen=True, eq=True) -class Split(Node[Item, Space]): - """A [`Split`][amltk.pipeline.Split] of data in a pipeline. +class Choice(Node[Item, Space]): + """A [`Choice`][amltk.pipeline.Choice] between different subcomponents. - This indicates the different children in - [`.nodes`][amltk.pipeline.Node.nodes] should - act in parallel but on different subsets of data. + This indicates that a choice should be made between the different children in + [`.nodes`][amltk.pipeline.Node.nodes], usually done when you + [`configure()`][amltk.pipeline.node.Node.configure] with some `config` from + a [`search_space()`][amltk.pipeline.node.Node.search_space]. 
```python exec="true" source="material-block" html="true" - from amltk.pipeline import Component, Split - from sklearn.impute import SimpleImputer - from sklearn.preprocessing import OneHotEncoder - from sklearn.compose import make_column_selector - - categorical_pipeline = [ - SimpleImputer(strategy="constant", fill_value="missing"), - OneHotEncoder(drop="first"), - ] - numerical_pipeline = Component(SimpleImputer, space={"strategy": ["mean", "median"]}) - - preprocessor = Split( - { - "categories": categorical_pipeline, - "numerical": numerical_pipeline, - }, - config={ - # This is how you would configure the split for the sklearn builder in particular - "categories": make_column_selector(dtype_include="category"), - "numerical": make_column_selector(dtype_exclude="category"), - }, - name="my_split" - ) - from amltk._doc import doc_print; doc_print(print, preprocessor) # markdown-exec: hide - - space = preprocessor.search_space("configspace") - from amltk._doc import doc_print; doc_print(print, space) # markdown-exec: hide - - configuration = space.sample_configuration() - from amltk._doc import doc_print; doc_print(print, configuration) # markdown-exec: hide + from amltk.pipeline import Choice, Component + from sklearn.ensemble import RandomForestClassifier + from sklearn.neural_network import MLPClassifier - configured_preprocessor = preprocessor.configure(configuration) - from amltk._doc import doc_print; doc_print(print, configured_preprocessor) # markdown-exec: hide + rf = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) + mlp = Component(MLPClassifier, space={"activation": ["logistic", "relu", "tanh"]}) - built_preprocessor = configured_preprocessor.build("sklearn") - print(built_preprocessor._repr_html_()) # markdown-exec: hide + estimator_choice = Choice(rf, mlp, name="estimator") + from amltk._doc import doc_print; doc_print(print, estimator_choice) # markdown-exec: hide ``` - The split is a slight oddity when compared to the other kinds of components in that - it allows a `dict` as it's first argument, where the keys are the names of the - different paths through which data will go and the values are the actual nodes that - will receive the data. - - If nodes are passed in as they are for all other components, usually the name of the - first node will be important for any builder trying to make sense of how - to use the `Split` + !!! warning "Order of nodes" + The given nodes of a choice are always ordered according + to their name, so indexing `choice.nodes` may not be reliable + if modifying the choice dynamically. - Like all [`Node`][amltk.pipeline.node.Node]s, a `Split` accepts an explicit - [`name=`][amltk.pipeline.node.Node.name], - [`item=`][amltk.pipeline.node.Node.item], - [`config=`][amltk.pipeline.node.Node.config], - [`space=`][amltk.pipeline.node.Node.space], - [`fidelities=`][amltk.pipeline.node.Node.fidelities], - [`config_transform=`][amltk.pipeline.node.Node.config_transform] and - [`meta=`][amltk.pipeline.node.Node.meta]. + Please use `choice["name"]` to access the nodes instead. See Also: * [`Node`][amltk.pipeline.node.Node] """ # noqa: E501 nodes: tuple[Node, ...] 
- """The nodes that this node leads to.""" - - RICH_OPTIONS: ClassVar[RichOptions] = RichOptions( - panel_color="#777DA7", - node_orientation="horizontal", - ) + """The choice of possible nodes that this choice could take.""" + RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="#FF4500") _NODES_INIT: ClassVar = "args" def __init__( self, - *nodes: Node | NodeLike | dict[str, Node | NodeLike], + *nodes: Node | NodeLike, name: str | None = None, item: Item | Callable[[Item], Item] | None = None, config: Config | None = None, @@ -839,37 +540,29 @@ def __init__( config_transform: Callable[[Config, Any], Config] | None = None, meta: Mapping[str, Any] | None = None, ): - """See [`Node`][amltk.pipeline.node.Node] for details.""" - if any(isinstance(n, dict) for n in nodes): - if len(nodes) > 1: - raise ValueError( - "Can't handle multiple nodes with a dictionary as a node.\n" - f"{nodes=}", - ) - _node = nodes[0] - assert isinstance(_node, dict) - - def _construct(key: str, value: Node | NodeLike) -> Node: - match value: - case list(): - return Sequential(*value, name=key) - case set() | tuple(): - return as_node(value, name=key) - case _: - return Sequential(value, name=key) + """Initialize a choice node. - _nodes = tuple(_construct(key, value) for key, value in _node.items()) - else: - _nodes = tuple(as_node(n) for n in nodes) - - if not all_unique(_nodes, key=lambda node: node.name): - raise ValueError( - f"Can't handle nodes they do not all contain unique names, {nodes=}." - "\nAll nodes must have a unique name. Please provide a `name=` to them", - ) + Args: + nodes: The nodes that should be chosen between for this node. + item: The item attached to this node (if any). + name: The name of the node. If not specified, the name will be + randomly generated. + config: The configuration for this node. + space: The search space for this node. This will be used when + [`search_space()`][amltk.pipeline.node.Node.search_space] is called. + fidelities: The fidelities for this node. + config_transform: A function that transforms the `config=` parameter + during [`configure(config)`][amltk.pipeline.node.Node.configure] + before return the new configured node. Useful for times where + you need to combine multiple parameters into one. + meta: Any meta information about this node. + """ + _nodes: tuple[Node, ...] = tuple( + sorted((as_node(n) for n in nodes), key=lambda n: n.name), + ) if name is None: - name = f"Split-{randuid(8)}" + name = f"Choice-{randuid(8)}" super().__init__( *_nodes, @@ -882,190 +575,265 @@ def _construct(key: str, value: Node | NodeLike) -> Node: meta=meta, ) + @override + def __or__(self, other: Node | NodeLike) -> Choice: + other_node = as_node(other) + if any(other_node.name == node.name for node in self.nodes): + raise ValueError( + f"Can't handle node with name '{other_node.name} as" + f" there is already a node called '{other_node.name}' in {self.name}", + ) -@dataclass(init=False, frozen=True, eq=True) -class Component(Node[Item, Space]): - """A [`Component`][amltk.pipeline.Component] of the pipeline with - a possible item and **no children**. - - This is the basic building block of most pipelines, it accepts - as it's [`item=`][amltk.pipeline.node.Node.item] some function that will be - called with [`build_item()`][amltk.pipeline.components.Component.build_item] to - build that one part of the pipeline. 
- - When [`build_item()`][amltk.pipeline.Component.build_item] is called, - The [`.config`][amltk.pipeline.node.Node.config] on this node will be passed - to the function to build the item. - - A common pattern is to use a [`Component`][amltk.pipeline.Component] to - wrap a constructor, specifying the [`space=`][amltk.pipeline.node.Node.space] - and [`config=`][amltk.pipeline.node.Node.config] to be used when building the - item. - - ```python exec="true" source="material-block" html="true" - from amltk.pipeline import Component - from sklearn.ensemble import RandomForestClassifier - - rf = Component( - RandomForestClassifier, - config={"max_depth": 3}, - space={"n_estimators": (10, 100)} - ) - from amltk._doc import doc_print; doc_print(print, rf) # markdown-exec: hide + nodes = tuple( + sorted( + [as_node(n) for n in self.nodes] + [other_node], + key=lambda n: n.name, + ), + ) + return self.mutate(name=self.name, nodes=nodes) - config = {"n_estimators": 50} # Sample from some space or something - configured_rf = rf.configure(config) + def chosen(self) -> Node: + """The chosen branch. - estimator = configured_rf.build_item() - from amltk._doc import doc_print; doc_print(print, estimator) # markdown-exec: hide - ``` + Returns: + The chosen branch + """ + match self.config: + case {"__choice__": choice}: + chosen = first_true( + self.nodes, + pred=lambda node: node.name == choice, + default=None, + ) + if chosen is None: + raise NodeNotFoundError(choice, self.name) - Whenever some other node sees a function/constructor, i.e. `RandomForestClassifier`, - this will automatically be converted into a `Component`. + return chosen + case _: + raise NoChoiceMadeError(self.name) - ```python exec="true" source="material-block" html="true" - from amltk.pipeline import Sequential - from sklearn.ensemble import RandomForestClassifier + @override + def configure( + self, + config: Config, + *, + prefixed_name: bool | None = None, + transform_context: Any | None = None, + params: Mapping[str, Any] | None = None, + ) -> Self: + """Configure this node and anything following it with the given config. - pipeline = Sequential(RandomForestClassifier, name="my_pipeline") - from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide - ``` + !!! note "Configuring a choice" - The default `.name` of a component is the name of the class/function that it will - use. You can explicitly set the `name=` if you want to when constructing the - component. + For a Choice, if the config has a `__choice__` key, then only the node + chosen will be configured. The others will not be configured at all and + their config will be discarded. - Like all [`Node`][amltk.pipeline.node.Node]s, a `Component` accepts an explicit - [`name=`][amltk.pipeline.node.Node.name], - [`item=`][amltk.pipeline.node.Node.item], - [`config=`][amltk.pipeline.node.Node.config], - [`space=`][amltk.pipeline.node.Node.space], - [`fidelities=`][amltk.pipeline.node.Node.fidelities], - [`config_transform=`][amltk.pipeline.node.Node.config_transform] and - [`meta=`][amltk.pipeline.node.Node.meta]. + Args: + config: The configuration to apply + prefixed_name: Whether items in the config are prefixed by the names + of the nodes. + * If `None`, the default, then `prefixed_name` will be assumed to + be `True` if this node has a next node or if the config has + keys that begin with this nodes name. + * If `True`, then the config will be searched for items prefixed + by the name of the node (and subsequent chained nodes). 
+ * If `False`, then the config will be searched for items without + the prefix, i.e. the config keys are exactly those matching + this nodes search space. + transform_context: Any context to give to `config_transform=` of individual + nodes. + params: The params to match any requests when configuring this node. + These will match against any ParamRequests in the config and will + be used to fill in any missing values. - See Also: - * [`Node`][amltk.pipeline.node.Node] - """ + Returns: + The configured node + """ + # Get the config for this node + match prefixed_name: + case True: + config = mapping_select(config, f"{self.name}:") + case False: + pass + case None if any(k.startswith(f"{self.name}:") for k in config): + config = mapping_select(config, f"{self.name}:") + case None: + pass - item: Callable[..., Item] - """A node which constructs an item in the pipeline.""" + _kwargs: dict[str, Any] = {} - nodes: tuple[()] - """A component has no children.""" + # Configure all the branches if exists + # This part is what differs for a Choice + if len(self.nodes) > 0: + choice_made = config.get("__choice__", None) + if choice_made is not None: + matching_child = first_true( + self.nodes, + pred=lambda node: node.name == choice_made, + default=None, + ) + if matching_child is None: + raise ValueError( + f"Can not find matching child for choice {self.name} with child" + f" {choice_made}." + "\nPlease check the config and ensure that the choice is one of" + f" {[n.name for n in self.nodes]}." + f"\nThe config recieved at this choice node was {config=}.", + ) - RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="#E6AF2E") + # We still iterate over all of them just to ensure correct ordering + nodes = tuple( + node.copy() + if node.name != choice_made + else matching_child.configure( + config, + prefixed_name=True, + transform_context=transform_context, + params=params, + ) + for node in self.nodes + ) + _kwargs["nodes"] = nodes + else: + nodes = tuple( + node.configure( + config, + prefixed_name=True, + transform_context=transform_context, + params=params, + ) + for node in self.nodes + ) + _kwargs["nodes"] = nodes - _NODES_INIT: ClassVar = None + this_config = { + hp: v + for hp, v in config.items() + if ( + ":" not in hp + and not any(hp.startswith(f"{node.name}") for node in self.nodes) + ) + } + if self.config is not None: + this_config = {**self.config, **this_config} - def __init__( - self, - item: Callable[..., Item], - *, - name: str | None = None, - config: Config | None = None, - space: Space | None = None, - fidelities: Mapping[str, Any] | None = None, - config_transform: Callable[[Config, Any], Config] | None = None, - meta: Mapping[str, Any] | None = None, - ): - """See [`Node`][amltk.pipeline.node.Node] for details.""" - super().__init__( - name=name if name is not None else entity_name(item), - item=item, - config=config, - space=space, - fidelities=fidelities, - config_transform=config_transform, - meta=meta, - ) + this_config = dict(self._fufill_param_requests(this_config, params=params)) - def build_item(self, **kwargs: Any) -> Item: - """Build the item attached to this component. 
+ if self.config_transform is not None: + this_config = dict(self.config_transform(this_config, transform_context)) - Args: - **kwargs: Any additional arguments to pass to the item + if len(this_config) > 0: + _kwargs["config"] = dict(this_config) - Returns: - Item - The built item - """ - config = self.config or {} - try: - return self.item(**{**config, **kwargs}) - except TypeError as e: - new_msg = f"Failed to build `{self.item=}` with `{self.config=}`.\n" - if any(kwargs): - new_msg += f"Extra {kwargs=} were also provided.\n" - new_msg += ( - "If the item failed to initialize, a common reason can be forgetting" - " to call `configure()` on the `Component` or the pipeline it is in or" - " not calling `build()`/`build_item()` on the **returned** value of" - " `configure()`.\n" - "Reasons may also include not having fully specified the `config`" - " initially, it having not being configured fully from `configure()`" - " or from misspecfying parameters in the `space`." - ) - raise ComponentBuildError(new_msg) from e + return self.mutate(**_kwargs) @dataclass(init=False, frozen=True, eq=True) -class Searchable(Node[None, Space]): # type: ignore - """A [`Searchable`][amltk.pipeline.Searchable] - node of the pipeline which just represents a search space, no item attached. +class Split(Node[Item, Space]): + """A [`Split`][amltk.pipeline.Split] of data in a pipeline. - While not usually applicable to pipelines you want to build, this component - is useful for creating a search space, especially if the real pipeline you - want to optimize can not be built directly. For example, if you are optimize - a script, you may wish to use a `Searchable` to represent the search space - of that script. + This indicates the different children in + [`.nodes`][amltk.pipeline.Node.nodes] should + act in parallel but on different subsets of data. ```python exec="true" source="material-block" html="true" - from amltk.pipeline import Searchable + from amltk.pipeline import Component, Split + from sklearn.impute import SimpleImputer + from sklearn.preprocessing import OneHotEncoder + from sklearn.compose import make_column_selector - script_space = Searchable({"mode": ["orange", "blue", "red"], "n": (10, 100)}) - from amltk._doc import doc_print; doc_print(print, script_space) # markdown-exec: hide - ``` + categorical_pipeline = [ + SimpleImputer(strategy="constant", fill_value="missing"), + OneHotEncoder(drop="first"), + ] + numerical_pipeline = Component(SimpleImputer, space={"strategy": ["mean", "median"]}) - A `Searchable` explicitly does not allow for `item=` to be set, nor can it have - any children. A `Searchable` accepts an explicit - [`name=`][amltk.pipeline.node.Node.name], - [`config=`][amltk.pipeline.node.Node.config], - [`space=`][amltk.pipeline.node.Node.space], - [`fidelities=`][amltk.pipeline.node.Node.fidelities], - [`config_transform=`][amltk.pipeline.node.Node.config_transform] and - [`meta=`][amltk.pipeline.node.Node.meta]. 
+ preprocessor = Split( + { + "categories": categorical_pipeline, + "numerical": numerical_pipeline, + }, + config={ + "categories": make_column_selector(dtype_include="category"), + "numerical": make_column_selector(dtype_exclude="category"), + }, + name="my_split" + ) + from amltk._doc import doc_print; doc_print(print, preprocessor) # markdown-exec: hide + ``` See Also: * [`Node`][amltk.pipeline.node.Node] """ # noqa: E501 - item: None = None - """A searchable has no item.""" - - nodes: tuple[()] = () - """A component has no children.""" - - RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="light_steel_blue") + RICH_OPTIONS: ClassVar[RichOptions] = RichOptions( + panel_color="#777DA7", + node_orientation="horizontal", + ) - _NODES_INIT: ClassVar = None + _NODES_INIT: ClassVar = "args" def __init__( self, - space: Space | None = None, - *, + *nodes: Node | NodeLike | dict[str, Node | NodeLike], name: str | None = None, + item: Item | Callable[[Item], Item] | None = None, config: Config | None = None, + space: Space | None = None, fidelities: Mapping[str, Any] | None = None, config_transform: Callable[[Config, Any], Config] | None = None, meta: Mapping[str, Any] | None = None, ): - """See [`Node`][amltk.pipeline.node.Node] for details.""" + """Initialize a split node. + + Args: + nodes: The nodes that this node leads to. You may also provide + a dictionary where the keys are the names of the nodes and + the values are the nodes or list of nodes themselves. + item: The item attached to this node. The object created by `item` + should be capable of figuring out how to deal with its child nodes. + name: The name of the node. If not specified, the name will be + generated from the item. + config: The configuration for this split. + space: The search space for this node. This will be used when + [`search_space()`][amltk.pipeline.node.Node.search_space] is called. + fidelities: The fidelities for this node. + config_transform: A function that transforms the `config=` parameter + during [`configure(config)`][amltk.pipeline.node.Node.configure] + before return the new configured node. Useful for times where + you need to combine multiple parameters into one. + meta: Any meta information about this node. + """ + if any(isinstance(n, dict) for n in nodes): + if len(nodes) > 1: + raise ValueError( + "Can't handle multiple nodes with a dictionary as a node.\n" + f"{nodes=}", + ) + _node = nodes[0] + assert isinstance(_node, dict) + + def _construct(key: str, value: Node | NodeLike) -> Node: + match value: + case list(): + return Sequential(*value, name=key) + case set() | tuple(): + return as_node(value, name=key) + case _: + return Sequential(value, name=key) + + _nodes = tuple(_construct(key, value) for key, value in _node.items()) + else: + _nodes = tuple(as_node(n) for n in nodes) + if name is None: - name = f"Searchable-{randuid(8)}" + name = f"Split-{randuid(8)}" super().__init__( + *_nodes, name=name, + item=item, config=config, space=space, fidelities=fidelities, @@ -1075,94 +843,72 @@ def __init__( @dataclass(init=False, frozen=True, eq=True) -class Fixed(Node[Item, None]): # type: ignore - """A [`Fixed`][amltk.pipeline.Fixed] part of the pipeline that - represents something that can not be configured and used directly as is. - - It consists of an [`.item`][amltk.pipeline.node.Node.item] that is fixed, - non-configurable and non-searchable. It also has no children. +class Join(Node[Item, Space]): + """[`Join`][amltk.pipeline.Join] together different parts of the pipeline. 
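The end-to-end steps that were dropped from the old module example (parse a space, sample, configure, build) still apply to the new `Split` docstring example above. A hedged sketch of that flow, assuming ConfigSpace and scikit-learn are installed:

```python
from amltk.pipeline import Component, Split
from sklearn.compose import make_column_selector
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder

preprocessor = Split(
    {
        "categories": [
            SimpleImputer(strategy="constant", fill_value="missing"),
            OneHotEncoder(drop="first"),
        ],
        "numerical": Component(SimpleImputer, space={"strategy": ["mean", "median"]}),
    },
    # The config keys tell the sklearn builder which columns go down which path
    config={
        "categories": make_column_selector(dtype_include="category"),
        "numerical": make_column_selector(dtype_exclude="category"),
    },
    name="my_split",
)

space = preprocessor.search_space("configspace")  # requires ConfigSpace
config = space.sample_configuration()
built = preprocessor.configure(config).build("sklearn")
print(built)
```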
- This is useful for representing parts of the pipeline that are fixed, for example - if you have a pipeline that is a `Sequential` of nodes, but you want to - fix the first component to be a `PCA` with `n_components=3`, you can use a `Fixed` - to represent that. + This indicates the different children in + [`.nodes`][amltk.pipeline.Node.nodes] should act in tandem with one + another, for example, concatenating the outputs of the various members of the + `Join`. ```python exec="true" source="material-block" html="true" - from amltk.pipeline import Component, Fixed, Sequential - from sklearn.ensemble import RandomForestClassifier + from amltk.pipeline import Join, Component from sklearn.decomposition import PCA + from sklearn.feature_selection import SelectKBest - rf = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) - pca = Fixed(PCA(n_components=3)) - - pipeline = Sequential(pca, rf, name="my_pipeline") - from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide - ``` - - Whenever some other node sees an instance of something, i.e. something that can't be - called, this will automatically be converted into a `Fixed`. - - ```python exec="true" source="material-block" html="true" - from amltk.pipeline import Sequential - from sklearn.ensemble import RandomForestClassifier - from sklearn.decomposition import PCA + pca = Component(PCA, space={"n_components": (1, 3)}) + kbest = Component(SelectKBest, space={"k": (1, 3)}) - pipeline = Sequential( - PCA(n_components=3), - RandomForestClassifier(n_estimators=50), - name="my_pipeline", - ) - from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide + join = Join(pca, kbest, name="my_feature_union") + from amltk._doc import doc_print; doc_print(print, join) # markdown-exec: hide ``` - The default `.name` of a component is the class name of the item that it will - use. You can explicitly set the `name=` if you want to when constructing the - component. - - A `Fixed` accepts only an explicit [`name=`][amltk.pipeline.node.Node.name], - [`item=`][amltk.pipeline.node.Node.item], - [`meta=`][amltk.pipeline.node.Node.meta]. - See Also: * [`Node`][amltk.pipeline.node.Node] """ - item: Item = field() - """The fixed item that this node represents.""" - - space: None = None - """A frozen node has no search space.""" - - fidelities: None = None - """A frozen node has no search space.""" - - config: None = None - """A frozen node has no config.""" - - config_transform: None = None - """A frozen node has no config so no transform.""" - - nodes: tuple[()] = () - """A component has no children.""" - - RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="#56351E") + nodes: tuple[Node, ...] + """The nodes that should be joined together in parallel.""" - _NODES_INIT: ClassVar = None + RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="#7E6B8F") + _NODES_INIT: ClassVar = "args" def __init__( self, - item: Item, - *, + *nodes: Node | NodeLike, name: str | None = None, - config: None = None, - space: None = None, - fidelities: None = None, - config_transform: None = None, + item: Item | Callable[[Item], Item] | None = None, + config: Config | None = None, + space: Space | None = None, + fidelities: Mapping[str, Any] | None = None, + config_transform: Callable[[Config, Any], Config] | None = None, meta: Mapping[str, Any] | None = None, ): - """See [`Node`][amltk.pipeline.node.Node] for details.""" + """Initialize a join node. 
+
+        Args:
+            nodes: The nodes that should be joined together in parallel.
+            item: The item attached to this node (if any).
+            name: The name of the node. If not specified, the name will be
+                randomly generated.
+            config: The configuration for this node.
+            space: The search space for this node. This will be used when
+                [`search_space()`][amltk.pipeline.node.Node.search_space] is called.
+            fidelities: The fidelities for this node.
+            config_transform: A function that transforms the `config=` parameter
+                during [`configure(config)`][amltk.pipeline.node.Node.configure]
+                before returning the new configured node. Useful for times where
+                you need to combine multiple parameters into one.
+            meta: Any meta information about this node.
+        """
+        _nodes = tuple(as_node(n) for n in nodes)
+
+        if name is None:
+            name = f"Join-{randuid(8)}"
+
         super().__init__(
-            name=name if name is not None else entity_name(item),
+            *_nodes,
+            name=name,
             item=item,
             config=config,
             space=space,
@@ -1170,3 +916,15 @@ def __init__(
             config_transform=config_transform,
             meta=meta,
         )
+
+    @override
+    def __and__(self, other: Node | NodeLike) -> Join:
+        other_node = as_node(other)
+        if any(other_node.name == node.name for node in self.nodes):
+            raise ValueError(
+                f"Can't handle node with name '{other_node.name}' as"
+                f" there is already a node called '{other_node.name}' in {self.name}",
+            )
+
+        nodes = (*tuple(as_node(n) for n in self.nodes), other_node)
+        return self.mutate(name=self.name, nodes=nodes)
diff --git a/src/amltk/pipeline/node.py b/src/amltk/pipeline/node.py
index ec1ce55e..02e3c35b 100644
--- a/src/amltk/pipeline/node.py
+++ b/src/amltk/pipeline/node.py
@@ -36,11 +36,16 @@
 means two nodes are considered equal if they look the same and they
 are connected in to nodes that also look the same.
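A small sketch of the new `__and__` support on `Join`, showing how `&` appends another member; the `StandardScaler` component is purely illustrative.

```python
from amltk.pipeline import Component, Join
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest
from sklearn.preprocessing import StandardScaler

pca = Component(PCA, space={"n_components": (1, 3)})
kbest = Component(SelectKBest, space={"k": (1, 3)})

# Members of a Join act in tandem, e.g. their outputs are concatenated
# when built for sklearn
join = Join(pca, kbest, name="my_feature_union")

# `&` returns a new Join with the extra member appended;
# a duplicate name would raise a ValueError
wider_join = join & Component(StandardScaler)
print([node.name for node in wider_join.nodes])
```

The same duplicate-name guard applies here as for `Choice.__or__` above.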
""" # noqa: E501 +# ruff: noqa: PLR0913 from __future__ import annotations import inspect -from collections.abc import Callable, Iterator, Mapping, Sequence +import warnings +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass, field +from datetime import datetime +from functools import partial +from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -50,19 +55,29 @@ Literal, NamedTuple, ParamSpec, + Protocol, TypeAlias, TypeVar, overload, ) from typing_extensions import override -from more_itertools import first_true +from more_itertools import all_unique, first_true from sklearn.pipeline import Pipeline as SklearnPipeline -from amltk._functional import classname, mapping_select, prefix_keys +from amltk._functional import classname, funcname, mapping_select, prefix_keys from amltk._richutil import RichRenderable -from amltk.exceptions import RequestNotMetError -from amltk.types import Config, Item, Space +from amltk.exceptions import ( + AutomaticThreadPoolCTLWarning, + DuplicateNamesError, + RequestNotMetError, +) +from amltk.optimization.history import History +from amltk.optimization.optimizer import Optimizer +from amltk.scheduling import Task +from amltk.scheduling.plugins import Plugin +from amltk.store import PathBucket +from amltk.types import Config, Item, Seed, Space if TYPE_CHECKING: from typing_extensions import Self @@ -71,8 +86,11 @@ from rich.console import RenderableType from rich.panel import Panel + from amltk.optimization.metric import Metric + from amltk.optimization.trial import Trial from amltk.pipeline.components import Choice, Join, Sequential from amltk.pipeline.parsers.optuna import OptunaSearchSpace + from amltk.scheduling import Scheduler NodeLike: TypeAlias = ( set["Node" | "NodeLike"] @@ -84,13 +102,94 @@ SklearnPipelineT = TypeVar("SklearnPipelineT", bound=SklearnPipeline) + +class OnBeginCallbackSignature(Protocol): + """A calllback to further define control flow from + [`pipeline.optimize()`][amltk.pipeline.node.Node.optimize]. + + In one of these callbacks, you can register to specific `@events` of the + [`Scheduler`][amltk.scheduling.Scheduler] or [`Task`][amltk.scheduling.Task]. + + ```python + pipeline = ... + + # The callback will get the task, scheduler and the history in which results + # will be stored + def my_callback(task: Task[..., Trial.Report], scheduler: Scheduler, history: History) -> None: + + # You can do early stopping based on a target metric + @task.on_result + def stop_if_target_reached(_: Future, report: Trial.Report) -> None: + score = report.values["accuracy"] + if score >= 0.95: + scheduler.stop(stop_msg="Target reached!")) + + # You could also perform early stopping based on iterations + n = 0 + last_score = 0.0 + + @task.on_result + def stop_if_no_improvement_for_n_runs(_: Future, report: Trial.Report) -> None: + score = report.values["accuracy"] + if score > last_score: + n = 0 + last_score = score + elif n >= 5: + scheduler.stop() + else: + n += 1 + + # Really whatever you'd like + @task.on_result + def print_if_choice_made(_: Future, report: Trial.Report) -> None: + if report.config["estimator:__choice__"] == "random_forest": + print("yay") + + # Every callback will be called here in the main process so it's + # best not to do anything too heavy here. 
+ # However you can also submit new tasks or jobs to the scheduler too + @task.on_result(every=30) # Do a cleanup sweep every 30 trials + def keep_on_ten_best_models_on_disk(_: Future, report: Trial.Report) -> None: + sorted_reports = history.sortby("accuracy") + reports_to_cleanup = sorted_reports[10:] + scheduler.submit(some_cleanup_function, reporteds_to_cleanup) + + history = pipeline.optimize( + ..., + on_begin=my_callback, + ) + ``` + """ # noqa: E501 + + def __call__( + self, + task: Task[[Trial, Node], Trial.Report], + scheduler: Scheduler, + history: History, + ) -> None: + """Signature for the callback. + + Args: + task: The task that will be run + scheduler: The scheduler that will be running the optimization + history: The history that will be used to collect the results + """ + ... + + T = TypeVar("T") ParserOutput = TypeVar("ParserOutput") BuilderOutput = TypeVar("BuilderOutput") P = ParamSpec("P") -_NotSet = object() +class _NotSetType: + @override + def __repr__(self) -> str: + return "" + + +_NotSet = _NotSetType() class RichOptions(NamedTuple): @@ -104,6 +203,8 @@ class RichOptions(NamedTuple): class ParamRequest(Generic[T]): """A parameter request for a node. This is most useful for things like seeds.""" + _has_default: bool + key: str """The key to request under.""" @@ -118,7 +219,10 @@ class ParamRequest(Generic[T]): @property def has_default(self) -> bool: """Whether this request has a default value.""" - return self.default is not _NotSet + # NOTE(eddiebergman): We decide to calculate this on + # initialization as when sent to new processes, these object + # ids may not match + return self._has_default def request(key: str, default: T | object = _NotSet) -> ParamRequest[T]: @@ -131,7 +235,7 @@ def request(key: str, default: T | object = _NotSet) -> ParamRequest[T]: config once [`configure`][amltk.pipeline.Node.configure] is called and nothing has been provided. """ - return ParamRequest(key=key, default=default) + return ParamRequest(key=key, default=default, _has_default=default is not _NotSet) @dataclass(init=False, frozen=True, eq=True) @@ -160,6 +264,7 @@ class Node(RichRenderable, Generic[Item, Space]): fidelities: Mapping[str, Any] | None = field(hash=False) """The fidelities for this node""" + config_transform: Callable[[Config, Any], Config] | None = field(hash=False) """A function that transforms the configuration of this node""" @@ -173,7 +278,7 @@ class Node(RichRenderable, Generic[Item, Space]): panel_color="default", node_orientation="horizontal", ) - """Options for rich printing""" + """How to display this node in rich.""" def __init__( self, @@ -186,7 +291,18 @@ def __init__( config_transform: Callable[[Config, Any], Config] | None = None, meta: Mapping[str, Any] | None = None, ): - """Initialize a choice.""" + """Initialize a choice. + + Args: + nodes: The nodes that this node leads to + name: The name of the node + item: The item attached to this node + config: The configuration for this node + space: The search space for this node + fidelities: The fidelities for this node + config_transform: A function that transforms the configuration of this node + meta: Any meta information about this node + """ super().__init__() object.__setattr__(self, "name", name) object.__setattr__(self, "item", item) @@ -197,8 +313,20 @@ def __init__( object.__setattr__(self, "meta", meta) object.__setattr__(self, "nodes", nodes) + if not all_unique(node.name for node in self.nodes): + raise DuplicateNamesError( + f"Duplicate node names in {self}. 
" "All nodes must have unique names.", + ) + + for child in self.nodes: + if child.name == self.name: + raise DuplicateNamesError( + f"Cannot have a child node with the same name as its parent. " + f"{self.name=} {child.name=}", + ) + def __getitem__(self, key: str) -> Node: - """Get the node with the given name.""" + """Get the first from [`.nodes`][amltk.pipeline.node.Node.nodes] with `key`.""" found = first_true( self.nodes, None, @@ -206,7 +334,7 @@ def __getitem__(self, key: str) -> Node: ) if found is None: raise KeyError( - f"Could not find node with name {key} in '{self.name}'." + f"Could not find node with name `{key}` in '{self.name}'." f" Available nodes are: {', '.join(node.name for node in self.nodes)}", ) @@ -366,15 +494,38 @@ def linearized_fidelity(self, value: float) -> dict[str, int | float | Any]: return prefix_keys(d, f"{self.name}:") - def iter(self) -> Iterator[Node]: - """Iterate the the nodes, including this node. + def iter(self, *, skip_unchosen: bool = False) -> Iterator[Node]: + """Recursively iterate through the nodes starting from this node. + + This method traverses the nodes in a depth-first manner, including + the current node and its children nodes. + + Args: + skip_unchosen (bool): Flag to skip unchosen nodes in Choice nodes. Yields: - The nodes connected to this node + Iterator[Node]: Nodes connected to this node. """ + # Import Choice node to avoid circular imports + from amltk.pipeline.components import Choice + + # Yield the current node yield self + + # Iterate through the child nodes for node in self.nodes: - yield from node.iter() + if skip_unchosen and isinstance(node, Choice): + # If the node is a Choice and skipping unchosen nodes is enabled + chosen_node = node.chosen() + if chosen_node is None: + raise RuntimeError( + f"No Node chosen in Choice node {node.name}. " + f"Did you call configure?", + ) + yield from chosen_node.iter(skip_unchosen=skip_unchosen) + else: + # Recursively iterate through the child nodes + yield from node.iter(skip_unchosen=skip_unchosen) def mutate(self, **kwargs: Any) -> Self: """Mutate the node with the given keyword arguments. @@ -776,3 +927,634 @@ def _fufill_param_requests( continue return new_config + + def register_optimization_loop( # noqa: C901, PLR0915, PLR0912 + self, + target: ( + Task[[Trial, Node], Trial.Report] | Callable[[Trial, Node], Trial.Report] + ), + metric: Metric | Sequence[Metric], + *, + optimizer: ( + type[Optimizer] | Optimizer.CreateSignature | Optimizer | None + ) = None, + seed: Seed | None = None, + max_trials: int | None = None, + n_workers: int = 1, + working_dir: str | Path | PathBucket | None = None, + scheduler: Scheduler | None = None, + history: History | None = None, + on_begin: OnBeginCallbackSignature | None = None, + on_trial_exception: Literal["raise", "end", "continue"] = "raise", + # Plugin creating arguments + plugins: Plugin | Iterable[Plugin] | None = None, + process_memory_limit: int | tuple[int, str] | None = None, + process_walltime_limit: int | tuple[float, str] | None = None, + process_cputime_limit: int | tuple[float, str] | None = None, + threadpool_limit_ctl: bool | int | None = None, + ) -> tuple[Scheduler, Task[[Trial, Node], Trial.Report], History]: + """Setup a pipeline to be optimized in a loop. + + Args: + target: + The function against which to optimize. 
+ + * If `target` is a function, then it must take in a + [`Trial`][amltk.optimization.trial.Trial] as the first argument + and a [`Node`][amltk.pipeline.node.Node] second argument, returning a + [`Trial.Report`][amltk.optimization.trial.Trial.Report]. Please refer to + the [optimization guide](../../../guides/optimization.md) for more. + + * If `target` is a [`Task`][amltk.scheduling.task.Task], then + this will be used instead, updating the plugins with any additional + plugins specified. + metric: + The metric(s) that will be passed to `optimizer=`. These metrics + should align with what is being computed in `target=`. + optimizer: + The optimizer to use. If `None`, then AMLTK will go through a list + of known optimizers and use the first one it can find which was installed. + + Alternatively, this can be a class inheriting from + [`Optimizer`][amltk.optimization.optimizer.Optimizer] or else + a signature match [`Optimizer.CreateSignature`][amltk.optimization.Optimizer.CreateSignature]. + + ??? tip "`Optimizer.CreateSignature`" + + ::: amltk.optimization.Optimizer.CreateSignature + + Lastly, you can also pass in your own already instantiated optimizer if you prefer, however + you should make sure to set it up correctly with the given metrics and search space. + It is recommened to just pass in the class if you are unsure how to do this properly. + seed: + A [`seed`][amltk.types.Seed] for the optimizer to use. + n_workers: + The numer of workers to use to evaluate this pipeline. + If no `scheduler=` is provided, then one will be created for + you as [`Scheduler.with_processes(n_workers)`][amltk.scheduling.Scheduler.with_processes]. + If you provide your own `scheduler=` then this will limit the maximum + amount of concurrent trials for this pipeline that will be evaluating + at once. + working_dir: + A working directory to use for the optimizer and the trials. + Any items you store in trials will be located in this directory, + where the [`trial.name`][amltk.optimization.Trial.name] will be + used as a subfolder where any contents stored with + [`trial.store()`][amltk.optimization.trial.Trial.store] will be put there. + Please see the [optimization guide](../../../guides/optimization.md) + for more on trial storage. + scheduler: + The specific [`Scheduler`][amltk.scheduling.Scheduler] to use. + If `None`, then one will be created for you with + [`Scheduler.with_processes(n_workers)`][amltk.scheduling.Scheduler.with_processes] + history: + A [`History`][amltk.optimization.history.History] to store the + [`Trial.Report`][amltk.optimization.Trial.Report]s in. You + may pass in your own if you wish for this method to store + it there instead of creating its own. + on_begin: + A callback that will be called before the scheduler is run. This + can be used to hook into the life-cycle of the optimization and + perform custom routines. Please see the + [scheduling guide](../../../guides/scheduling.md) for more. + + ??? tip "on_begin signature" + + ::: amltk.pipeline.node.OnBeginCallbackSignature + + on_trial_exception: + What to do when a trial returns a fail report from + [`trial.fail()`][amltk.optimization.trial.Trial.fail] or + [`trial.crashed()`][amltk.optimization.trial.Trial.crashed] + that contains an exception. + + Please see the [optimization guide](../../../guides/optimization.md) + for more. In all cases, the exception will be attached to the + [`Trial.Report`][amltk.optimization.Trial.Report] object under + [`report.exception`][amltk.optimization.Trial.Report.exception]. 
+ + * If `#!python "raise"`, then the exception will be raised + immediatly and the optimization process will halt. The default + and good for initial development. + * If `#!python "end"`, then the exception will be caught and + the optimization process will end gracefully. + * If `#!python "continue"`, the exception will be ignored and + the optimization procedure will continue. + + max_trials: + The maximum number of trials to run. If `None`, then the + optimization will continue for as long as the scheduler is + running. You'll likely want to configure this. + process_memory_limit: + If specified, the [`Task`][amltk.scheduling.task.Task] will + use the + [`PynisherPlugin`][amltk.scheduling.plugins.pynisher.PynisherPlugin] + to limit the memory the process can use. Please + refer to the + [plugins `pynisher` reference](../../../reference/scheduling/plugins.md#pynisher) + for more as there are platform limitations and additional + dependancies required. + process_walltime_limit: + If specified, the [`Task`][amltk.scheduling.task.Task] will + use the + [`PynisherPlugin`][amltk.scheduling.plugins.pynisher.PynisherPlugin] + to limit the wall time the process can use. Please + refer to the + [plugins `pynisher` reference](../../../reference/scheduling/plugins.md#pynisher) + for more as there are platform limitations and additional + dependancies required. + process_cputime_limit: + If specified, the [`Task`][amltk.scheduling.task.Task] will + use the + [`PynisherPlugin`][amltk.scheduling.plugins.pynisher.PynisherPlugin] + to limit the cputime the process can use. Please + refer to the + [plugins `pynisher` reference](../../../reference/scheduling/plugins.md#pynisher) + for more as there are platform limitations and additional + dependancies required. + threadpool_limit_ctl: + If specified, the [`Task`][amltk.scheduling.task.Task] will + use the + [`ThreadPoolCTLPlugin`][amltk.scheduling.plugins.threadpoolctl.ThreadPoolCTLPlugin] + to limit the number of threads used by compliant libraries. + **Notably**, this includes scikit-learn, for which running multiple + in parallel can be problematic if not adjusted accordingly. + + The default behavior (when `None`) is to auto-detect whether this + is applicable. This is done by checking if `sklearn` is installed + and if the first node in the pipeline has a `BaseEstimator` item. + Please set this to `True`/`False` depending on your preference. + plugins: + Additional plugins to attach to the eventual + [`Task`][amltk.scheduling.task.Task] that will be executed by + the [`Scheduler`][amltk.scheduling.Scheduler]. Please + refer to the + [plugins reference](../../../reference/scheduling/plugins.md) for more. + + Returns: + A tuple of the [`Scheduler`][amltk.scheduling.Scheduler], the + [`Task`][amltk.scheduling.task.Task] and the + [`History`][amltk.optimization.history.History] that reports will be put into. + """ # noqa: E501 + match history: + case None: + history = History() + case History(): + pass + case _: + raise ValueError(f"Invalid history {history}. Must be a History") + + _plugins: tuple[Plugin, ...] + match plugins: + case None: + _plugins = () + case Plugin(): + _plugins = (plugins,) + case Iterable(): + _plugins = tuple(plugins) + case _: + raise ValueError( + f"Invalid plugins {plugins}. 
Must be a Plugin or an Iterable of" + " Plugins", + ) + + if any( + limit is not None + for limit in ( + process_memory_limit, + process_walltime_limit, + process_cputime_limit, + ) + ): + try: + from amltk.scheduling.plugins.pynisher import PynisherPlugin + except ImportError as e: + raise ImportError( + "You must install `pynisher` to use `trial_*_limit`" + " You can do so with `pip install amltk[pynisher]`" + " or `pip install pynisher` directly", + ) from e + # TODO: I'm hesitant to add even more arguments to the `optimize` + # signature, specifically for `mp_context`. + plugin = PynisherPlugin( + memory_limit=process_memory_limit, + walltime_limit=process_walltime_limit, + cputime_limit=process_cputime_limit, + ) + _plugins = (*_plugins, plugin) + + # If threadpool_limit_ctl None, we should default to inspecting if it's + # an sklearn pipeline. This is because sklearn pipelines + # run in parallel will over-subscribe the CPU and cause + # the system to slow down. + # We use a heuristic to check this by checking if the item at the head + # of this node is a subclass of sklearn.base.BaseEstimator + match threadpool_limit_ctl: + case None: + from amltk._util import threadpoolctl_heuristic + + threadpool_limit_ctl = False + if threadpoolctl_heuristic(self.item): + threadpool_limit_ctl = 1 + warnings.warn( + "Detected an sklearn pipeline. Setting `threadpool_limit_ctl`" + " to True. This will limit the number of threads spawned by" + " sklearn to the number of cores on the machine. This is" + " because sklearn pipelines run in parallel will over-subscribe" + " the CPU and cause the system to slow down." + "\nPlease set `threadpool_limit_ctl=False` if you do not want" + " this behaviour and set it to `True` to silence this warning.", + AutomaticThreadPoolCTLWarning, + stacklevel=2, + ) + case True: + threadpool_limit_ctl = 1 + case False: + pass + case int(): + pass + case _: + raise ValueError( + f"Invalid threadpool_limit_ctl {threadpool_limit_ctl}." + " Must be a bool or an int", + ) + + if threadpool_limit_ctl is not False: + from amltk.scheduling.plugins.threadpoolctl import ThreadPoolCTLPlugin + + _plugins = (*_plugins, ThreadPoolCTLPlugin(threadpool_limit_ctl)) + + match max_trials: + case None: + pass + case int() if max_trials > 0: + from amltk.scheduling.plugins import Limiter + + _plugins = (*_plugins, Limiter(max_calls=max_trials)) + case _: + raise ValueError(f"{max_trials=} must be a positive int") + + from amltk.scheduling.scheduler import Scheduler + + match scheduler: + case None: + scheduler = Scheduler.with_processes(n_workers) + case Scheduler(): + pass + case _: + raise ValueError(f"Invalid scheduler {scheduler}. Must be a Scheduler") + + match target: + case Task(): # type: ignore # NOTE not sure why pyright complains here + for _p in _plugins: + target.attach_plugin(_p) + task = target + case _ if callable(target): + task = scheduler.task(target, plugins=_plugins) + case _: + raise ValueError(f"Invalid {target=}. Must be a function or Task.") + + if isinstance(optimizer, Optimizer): + _optimizer = optimizer + else: + # NOTE: I'm not particularly fond of this hack but I assume most people + # when prototyping don't care for the actual underlying optimizer and + # so we should just *pick one*. 
+ create_optimizer: Optimizer.CreateSignature + match optimizer: + case None: + first_opt_class = next( + Optimizer._get_known_importable_optimizer_classes(), + None, + ) + if first_opt_class is None: + raise ValueError( + "No optimizer was given and no known importable optimizers " + " were found. Please consider giving one explicitly or" + " installing one of the following packages:\n" + "\n - optuna" + "\n - smac" + "\n - neural-pipeline-search", + ) + + create_optimizer = first_opt_class.create + opt_name = classname(first_opt_class) + case type(): + if not issubclass(optimizer, Optimizer): + raise ValueError( + f"Invalid optimizer {optimizer}. Must be a subclass of" + " Optimizer or a function that returns an Optimizer", + ) + create_optimizer = optimizer.create + opt_name = classname(optimizer) + case _: + assert not isinstance(optimizer, type) + create_optimizer = optimizer + opt_name = funcname(optimizer) + + match working_dir: + case None: + now = datetime.utcnow().isoformat() + + working_dir = PathBucket(f"{opt_name}-{self.name}-{now}") + case str() | Path(): + working_dir = PathBucket(working_dir) + case PathBucket(): + pass + case _: + raise ValueError( + f"Invalid working_dir {working_dir}." + " Must be a str, Path or PathBucket", + ) + + _optimizer = create_optimizer( + space=self, + metrics=metric, + bucket=working_dir, + seed=seed, + ) + assert _optimizer is not None + + if on_begin is not None: + hook = partial(on_begin, task, scheduler, history) + scheduler.on_start(hook) + + @scheduler.on_start + def launch_initial_trials() -> None: + trials = _optimizer.ask(n=n_workers) + for trial in trials: + task.submit(trial, self) + + from amltk.optimization.trial import Trial + + @task.on_result + def tell_optimizer(_: Any, report: Trial.Report) -> None: + _optimizer.tell(report) + + @task.on_result + def add_report_to_history(_: Any, report: Trial.Report) -> None: + history.add(report) + match report.status: + case Trial.Status.SUCCESS: + return + case Trial.Status.FAIL | Trial.Status.CRASHED | Trial.Status.UNKNOWN: + match on_trial_exception: + case "raise": + if report.exception is None: + raise RuntimeError( + f"Trial finished with status {report.status} but" + " no exception was attached!", + ) + raise report.exception + case "end": + scheduler.stop( + stop_msg=f"Trial finished with status {report.status}", + exception=report.exception, + ) + case "continue": + pass + case _: + raise ValueError(f"Invalid status {report.status}") + + @task.on_result + def run_next_trial(*_: Any) -> None: + if scheduler.running(): + trial = _optimizer.ask() + task.submit(trial, self) + + return scheduler, task, history + + def optimize( + self, + target: ( + Callable[[Trial, Node], Trial.Report] | Task[[Trial, Node], Trial.Report] + ), + metric: Metric | Sequence[Metric], + *, + optimizer: ( + type[Optimizer] | Optimizer.CreateSignature | Optimizer | None + ) = None, + seed: Seed | None = None, + max_trials: int | None = None, + n_workers: int = 1, + timeout: float | None = None, + working_dir: str | Path | PathBucket | None = None, + scheduler: Scheduler | None = None, + history: History | None = None, + on_begin: OnBeginCallbackSignature | None = None, + on_trial_exception: Literal["raise", "end", "continue"] = "raise", + # Plugin creating arguments + plugins: Plugin | Iterable[Plugin] | None = None, + process_memory_limit: int | tuple[int, str] | None = None, + process_walltime_limit: int | tuple[float, str] | None = None, + process_cputime_limit: int | tuple[float, str] | None = None, + 
threadpool_limit_ctl: bool | int | None = None, + # `scheduler.run()` arguments + display: bool | Literal["auto"] = "auto", + wait: bool = True, + on_scheduler_exception: Literal["raise", "end", "continue"] = "raise", + ) -> History: + """Optimize a pipeline on a given target function or evaluation protocol. + + Args: + target: + The function against which to optimize. + + * If `target` is a function, then it must take in a + [`Trial`][amltk.optimization.trial.Trial] as the first argument + and a [`Node`][amltk.pipeline.node.Node] second argument, returning a + [`Trial.Report`][amltk.optimization.trial.Trial.Report]. Please refer to + the [optimization guide](../../../guides/optimization.md) for more. + + * If `target` is a [`Task`][amltk.scheduling.task.Task], then + this will be used instead, updating the plugins with any additional + plugins specified. + metric: + The metric(s) that will be passed to `optimizer=`. These metrics + should align with what is being computed in `target=`. + optimizer: + The optimizer to use. If `None`, then AMLTK will go through a list + of known optimizers and use the first one it can find which was installed. + + Alternatively, this can be a class inheriting from + [`Optimizer`][amltk.optimization.optimizer.Optimizer] or else + a signature match [`Optimizer.CreateSignature`][amltk.optimization.Optimizer.CreateSignature] + + ??? tip "`Optimizer.CreateSignature`" + + ::: amltk.optimization.Optimizer.CreateSignature + + Lastly, you can also pass in your own already instantiated optimizer if you prefer, however + you should make sure to set it up correctly with the given metrics and search space. + It is recommened to just pass in the class if you are unsure how to do this properly. + seed: + A [`seed`][amltk.types.Seed] for the optimizer to use. + n_workers: + The numer of workers to use to evaluate this pipeline. + If no `scheduler=` is provided, then one will be created for + you as [`Scheduler.with_processes(n_workers)`][amltk.scheduling.Scheduler.with_processes]. + If you provide your own `scheduler=` then this will limit the maximum + amount of concurrent trials for this pipeline that will be evaluating + at once. + timeout: + How long to run the scheduler for. This parameter only takes + effect if `setup_only=False` which is the default. Otherwise, + it will be ignored. + display: + Whether to display the scheduler during running. By default + it is `"auto"` which means to enable the display if running + in a juptyer notebook or colab. Otherwise, it will be + `False`. + + This may work poorly if including print statements or logging. + wait: + Whether to wait for the scheduler to finish all pending jobs + if was stopped for any reason, e.g. a `timeout=` or + [`scheduler.stop()`][amltk.scheduling.Scheduler.stop] was called. + on_scheduler_exception: + What to do if an exception occured, either in the submitted task, + the callback, or any other unknown source during the loop. + + * If `#!python "raise"`, then the exception will be raised + immediatly and the optimization process will halt. The default + behavior and good for initial development. + * If `#!python "end"`, then the exception will be caught and + the optimization process will end gracefully. + * If `#!python "continue"`, the exception will be ignored and + the optimization procedure will continue. + working_dir: + A working directory to use for the optimizer and the trials. 
+ Any items you store in trials will be located in this directory, + where the [`trial.name`][amltk.optimization.Trial.name] will be + used as a subfolder where any contents stored with + [`trial.store()`][amltk.optimization.trial.Trial.store] will be put there. + Please see the [optimization guide](../../../guides/optimization.md) + for more on trial storage. + scheduler: + The specific [`Scheduler`][amltk.scheduling.Scheduler] to use. + If `None`, then one will be created for you with + [`Scheduler.with_processes(n_workers)`][amltk.scheduling.Scheduler.with_processes] + history: + A [`History`][amltk.optimization.history.History] to store the + [`Trial.Report`][amltk.optimization.Trial.Report]s in. You + may pass in your own if you wish for this method to store + it there instead of creating its own. + on_begin: + A callback that will be called before the scheduler is run. This + can be used to hook into the life-cycle of the optimization and + perform custom routines. Please see the + [scheduling guide](../../../guides/scheduling.md) for more. + + ??? tip "on_begin signature" + + ::: amltk.pipeline.node.OnBeginCallbackSignature + + on_trial_exception: + What to do when a trial returns a fail report from + [`trial.fail()`][amltk.optimization.trial.Trial.fail] or + [`trial.crashed()`][amltk.optimization.trial.Trial.crashed] + that contains an exception. + + Please see the [optimization guide](../../../guides/optimization.md) + for more. In all cases, the exception will be attached to the + [`Trial.Report`][amltk.optimization.Trial.Report] object under + [`report.exception`][amltk.optimization.Trial.Report.exception]. + + * If `#!python "raise"`, then the exception will be raised + immediatly and the optimization process will halt. The default + and good for initial development. + * If `#!python "end"`, then the exception will be caught and + the optimization process will end gracefully. + * If `#!python "continue"`, the exception will be ignored and + the optimization procedure will continue. + + max_trials: + The maximum number of trials to run. If `None`, then the + optimization will continue for as long as the scheduler is + running. You'll likely want to configure this. + + process_memory_limit: + If specified, the [`Task`][amltk.scheduling.task.Task] will + use the + [`PynisherPlugin`][amltk.scheduling.plugins.pynisher.PynisherPlugin] + to limit the memory the process can use. Please + refer to the + [plugins `pynisher` reference](../../../reference/scheduling/plugins.md#pynisher) + for more as there are platform limitations and additional + dependancies required. + process_walltime_limit: + If specified, the [`Task`][amltk.scheduling.task.Task] will + use the + [`PynisherPlugin`][amltk.scheduling.plugins.pynisher.PynisherPlugin] + to limit the wall time the process can use. Please + refer to the + [plugins `pynisher` reference](../../../reference/scheduling/plugins.md#pynisher) + for more as there are platform limitations and additional + dependancies required. + process_cputime_limit: + If specified, the [`Task`][amltk.scheduling.task.Task] will + use the + [`PynisherPlugin`][amltk.scheduling.plugins.pynisher.PynisherPlugin] + to limit the cputime the process can use. Please + refer to the + [plugins `pynisher` reference](../../../reference/scheduling/plugins.md#pynisher) + for more as there are platform limitations and additional + dependancies required. 
+ threadpool_limit_ctl: + If specified, the [`Task`][amltk.scheduling.task.Task] will + use the + [`ThreadPoolCTLPlugin`][amltk.scheduling.plugins.threadpoolctl.ThreadPoolCTLPlugin] + to limit the number of threads used by compliant libraries. + **Notably**, this includes scikit-learn, for which running multiple + in parallel can be problematic if not adjusted accordingly. + + The default behavior (when `None`) is to auto-detect whether this + is applicable. This is done by checking if `sklearn` is installed + and if the first node in the pipeline has a `BaseEstimator` item. + Please set this to `True`/`False` depending on your preference. + plugins: + Additional plugins to attach to the eventual + [`Task`][amltk.scheduling.task.Task] that will be executed by + the [`Scheduler`][amltk.scheduling.Scheduler]. Please + refer to the + [plugins reference](../../../reference/scheduling/plugins.md) for more. + """ # noqa: E501 + if timeout is None and max_trials is None: + raise ValueError( + "You must one or both of `timeout` or `max_trials` to" + " limit the optimization process.", + ) + + match history: + case None: + history = History() + case History(): + pass + case _: + raise ValueError(f"Invalid history {history}. Must be a History") + + scheduler, _, _ = self.register_optimization_loop( + target=target, + metric=metric, + optimizer=optimizer, + seed=seed, + max_trials=max_trials, + n_workers=n_workers, + working_dir=working_dir, + scheduler=scheduler, + history=history, + on_begin=on_begin, + on_trial_exception=on_trial_exception, + plugins=plugins, + process_memory_limit=process_memory_limit, + process_walltime_limit=process_walltime_limit, + process_cputime_limit=process_cputime_limit, + threadpool_limit_ctl=threadpool_limit_ctl, + ) + scheduler.run( + wait=wait, + timeout=timeout, + on_exception=on_scheduler_exception, + display=display, + ) + return history diff --git a/src/amltk/pipeline/parsers/optuna.py b/src/amltk/pipeline/parsers/optuna.py index f6387b55..ce31e4a1 100644 --- a/src/amltk/pipeline/parsers/optuna.py +++ b/src/amltk/pipeline/parsers/optuna.py @@ -91,9 +91,11 @@ from __future__ import annotations from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any import numpy as np +import optuna from optuna.distributions import ( BaseDistribution, CategoricalChoiceType, @@ -103,17 +105,95 @@ ) from amltk._functional import prefix_keys +from amltk.pipeline.components import Choice if TYPE_CHECKING: - from typing import TypeAlias - from amltk.pipeline import Node - OptunaSearchSpace: TypeAlias = dict[str, BaseDistribution] - PAIR = 2 +@dataclass +class OptunaSearchSpace: + """A class to represent an Optuna search space. + + Wraps a dictionary of hyperparameters and their Optuna distributions. 
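Since `optimize()` only wires together what `register_optimization_loop()` sets up, a compact end-to-end sketch may help. It assumes scikit-learn plus an installed optimizer backend, and leans on the wider AMLTK API (`Metric(...)`, `trial.config`, `trial.success()`, `trial.fail()`, `history.df()`) that is referenced but not defined in this hunk.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from amltk.optimization import Metric, Trial
from amltk.pipeline import Component, Node, Sequential

pipeline = Sequential(
    Component(RandomForestClassifier, space={"n_estimators": (10, 100)}),
    name="my_pipeline",
)

# Assumed Metric signature from the wider AMLTK API
accuracy = Metric("accuracy", minimize=False, bounds=(0.0, 1.0))


def evaluate(trial: Trial, node: Node) -> Trial.Report:
    """The `(Trial, Node) -> Trial.Report` signature that `optimize()` expects."""
    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    model = node.configure(trial.config).build("sklearn")
    try:
        model.fit(X_train, y_train)
        return trial.success(accuracy=model.score(X_test, y_test))
    except Exception as e:  # noqa: BLE001
        return trial.fail(e)


if __name__ == "__main__":
    history = pipeline.optimize(
        evaluate,
        metric=accuracy,
        max_trials=10,  # one of `timeout=` or `max_trials=` is required
        n_workers=1,
        on_trial_exception="continue",
    )
    print(history.df())
```

For finer control, `register_optimization_loop()` returns the scheduler, task and history directly, so custom `@task.on_result` callbacks can be attached before calling `scheduler.run()` yourself.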
+ """ + + distributions: dict[str, BaseDistribution] = field(default_factory=dict) + + def __repr__(self) -> str: + return f"OptunaSearchSpace({self.distributions})" + + def __str__(self) -> str: + return str(self.distributions) + + @classmethod + def parse(cls, *args: Any, **kwargs: Any) -> OptunaSearchSpace: + """Parse a Node into an Optuna search space.""" + return parser(*args, **kwargs) + + def sample_configuration(self) -> dict[str, Any]: + """Sample a configuration from the search space using a default Optuna Study.""" + study = optuna.create_study() + trial = self.get_trial(study) + return trial.params + + def get_trial(self, study: optuna.Study) -> optuna.Trial: + """Get a trial from a given Optuna Study using this search space.""" + optuna_trial: optuna.Trial + if any("__choice__" in k for k in self.distributions): + optuna_trial = study.ask() + # do all __choice__ suggestions with suggest_categorical + workspace = self.distributions.copy() + filter_patterns = [] + for name, distribution in workspace.items(): + if "__choice__" in name and isinstance( + distribution, + CategoricalDistribution, + ): + possible_choices = distribution.choices + choice_made = optuna_trial.suggest_categorical( + name, + choices=possible_choices, + ) + for c in possible_choices: + if c != choice_made: + # deletable options have the name of the unwanted choices + filter_patterns.append(f":{c}:") + # filter all parameters for the unwanted choices + filtered_workspace = { + k: v + for k, v in workspace.items() + if ( + ("__choice__" not in k) + and ( + not any( + filter_pattern in k for filter_pattern in filter_patterns + ) + ) + ) + } + # do all remaining suggestions with the correct suggest function + for name, distribution in filtered_workspace.items(): + match distribution: + case CategoricalDistribution(choices=choices): + optuna_trial.suggest_categorical(name, choices=choices) + case IntDistribution( + low=low, + high=high, + log=log, + ): + optuna_trial.suggest_int(name, low=low, high=high, log=log) + case FloatDistribution(low=low, high=high): + optuna_trial.suggest_float(name, low=low, high=high) + case _: + raise ValueError(f"Unknown distribution: {distribution}") + else: + optuna_trial = study.ask(self.distributions) + return optuna_trial + + def _convert_hp_to_optuna_distribution( name: str, hp: tuple | Sequence | CategoricalChoiceType | BaseDistribution, @@ -149,7 +229,7 @@ def _convert_hp_to_optuna_distribution( raise ValueError(f"Could not parse {name} as a valid Optuna distribution.\n{hp=}") -def _parse_space(node: Node) -> OptunaSearchSpace: +def _parse_space(node: Node) -> dict[str, BaseDistribution]: match node.space: case None: space = {} @@ -196,13 +276,21 @@ def parser( delim: The delimiter to use for the names of the hyperparameters. 
""" - if conditionals: - raise NotImplementedError("Conditionals are not yet supported with Optuna.") - space = prefix_keys(_parse_space(node), prefix=f"{node.name}{delim}") - for child in node.nodes: - subspace = parser(child, flat=flat, conditionals=conditionals, delim=delim) + children = node.nodes + + if isinstance(node, Choice) and any(children): + name = f"{node.name}{delim}__choice__" + space[name] = CategoricalDistribution([child.name for child in children]) + + for child in children: + subspace = parser( + child, + flat=flat, + conditionals=conditionals, + delim=delim, + ).distributions if not flat: subspace = prefix_keys(subspace, prefix=f"{node.name}{delim}") @@ -214,4 +302,4 @@ def parser( ) space[name] = hp - return space + return OptunaSearchSpace(distributions=space) diff --git a/src/amltk/profiling/profiler.py b/src/amltk/profiling/profiler.py index 2b98ede7..c2cf0a77 100644 --- a/src/amltk/profiling/profiler.py +++ b/src/amltk/profiling/profiler.py @@ -1,80 +1,4 @@ -"""Whether for debugging, building an AutoML system or for optimization -purposes, we provide a powerful [`Profiler`][amltk.profiling.Profiler], -which can generate a [`Profile`][amltk.profiling.Profile] of different sections -of code. This is particularly useful with [`Trial`][amltk.optimization.Trial]s, -so much so that we attach one to every `Trial` made as -[`trial.profiler`][amltk.optimization.Trial.profiler]. - -When done profiling, you can export all generated profiles as a dataframe using -[`profiler.df()`][amltk.profiling.Profiler.df]. - -```python exec="true" result="python" source="material-block" -from amltk.profiling import Profiler -import numpy as np - -profiler = Profiler() - -with profiler("loading-data"): - X = np.random.rand(1000, 1000) - -with profiler("training-model"): - model = np.linalg.inv(X) - -with profiler("predicting"): - y = model @ X - -print(profiler.df()) -``` - -You'll find these profiles as keys in the [`Profiler`][amltk.profiling.Profiler], -e.g. `#! python profiler["loading-data"]`. - -This will measure both the time it took within the block but also -the memory consumed before and after the block finishes, allowing -you to get an estimate of the memory consumed. - - -??? tip "Memory, vms vs rms" - - While not entirely accurate, this should be enough for info - for most use cases. - - Given the main process uses 2GB of memory and the process - then spawns a new process in which you are profiling, as you - might do from a [`Task`][amltk.scheduling.Task]. In this new - process you use another 2GB on top of that, then: - - * The virtual memory size (**vms**) will show 4GB as the - new process will share the 2GB with the main process and - have it's own 2GB. - - * The resident set size (**rss**) will show 2GB as the - new process will only have 2GB of it's own memory. - - -If you need to profile some iterator, like a for loop, you can use -[`Profiler.each()`][amltk.profiling.Profiler.each] which will measure -the entire loop but also each individual iteration. This can be useful -for iterating batches of a deep-learning model, splits of a cross-validator -or really any loop with work you want to profile. 
- -```python exec="true" result="python" source="material-block" -from amltk.profiling import Profiler -import numpy as np - -profiler = Profiler() - -for i in profiler.each(range(3), name="for-loop"): - X = np.random.rand(1000, 1000) - -print(profiler.df()) -``` - -Lastly, to disable profiling without editing much code, -you can always use [`Profiler.disable()`][amltk.profiling.Profiler.disable] -and [`Profiler.enable()`][amltk.profiling.Profiler.enable] to toggle -profiling on and off. -""" +"""The profiler module provides classes for profiling code.""" from __future__ import annotations @@ -252,7 +176,7 @@ def each( itr: Iterable[T], *, name: str, - itr_name: Callable[[int, T], str] | None = None, + itr_name: str | Callable[[int, T], str] | None = None, ) -> Iterator[T]: """Profile each item in an iterable. @@ -267,11 +191,17 @@ def each( Yields: The the items """ - if itr_name is None: - itr_name = lambda i, _: str(i) + match itr_name: + case None: + _itr_name = lambda i, _: str(i) + case str(): + _itr_name = lambda i, _: f"{itr_name}_{i}" + case _: + _itr_name = itr_name + with self.measure(name=name): for i, item in enumerate(itr): - with self.measure(name=itr_name(i, item)): + with self.measure(name=_itr_name(i, item)): yield item @contextmanager diff --git a/src/amltk/pytorch/__init__.py b/src/amltk/pytorch/__init__.py new file mode 100644 index 00000000..0a1e6954 --- /dev/null +++ b/src/amltk/pytorch/__init__.py @@ -0,0 +1,11 @@ +from amltk.pytorch.builders import ( + MatchChosenDimensions, + MatchDimensions, + build_model_from_pipeline, +) + +__all__ = [ + "MatchDimensions", + "MatchChosenDimensions", + "build_model_from_pipeline", +] diff --git a/src/amltk/pytorch/builders.py b/src/amltk/pytorch/builders.py new file mode 100644 index 00000000..958bb1de --- /dev/null +++ b/src/amltk/pytorch/builders.py @@ -0,0 +1,166 @@ +"""This module contains functionality to construct a pytorch model from a pipeline. + +It also includes classes for handling dimension matching between layers. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +from torch import nn + +from amltk import Choice, Component, Node, Sequential +from amltk.exceptions import MatchChosenDimensionsError, MatchDimensionsError + + +@dataclass +class MatchDimensions: + """Handles matching dimensions between layers in a pipeline. + + This class helps ensure compatibility between layers with search spaces + during HPO optimization. It takes the layer name and parameter name + and stores them for later reference. + + Not intended to be used inside a Choice node. + """ + + layer_name: str + """The name of the layer.""" + + param: str + """The name of the parameter to match.""" + + def evaluate(self, pipeline: Node) -> int: + """Retrieves the corresponding configuration value from the pipeline. + + Args: + pipeline: The pipeline to search for the matching configuration. + + Returns: + The value of the matching configuration parameter. + """ + layer = pipeline[self.layer_name] + layer_config = layer.config + if layer_config is None: + raise MatchDimensionsError(self.layer_name, None) + value = layer_config.get(self.param) + if value is None: + raise MatchDimensionsError(self.layer_name, self.param) + return value + + +@dataclass +class MatchChosenDimensions: + """Handles matching dimensions for chosen nodes in a pipeline. + + This class helps ensure compatibility between layers with search spaces + during HPO optimization. 
It takes the choice name and the corresponding + dimensions for that choice and stores them for later reference. + + """ + + choice_name: str + """The name of the choice node.""" + + choices: Mapping[str, Any] + """The mapping of choice taken to the dimension to use.""" + + def evaluate(self, chosen_nodes: dict[str, str]) -> int: + """Retrieves the corresponding dimension for the chosen node. + + If the chosen node is not found in the choices dictionary, an error is raised. + If the dimensions provided are not valid, an error is not raised. + It is up to the user to ensure that the dimensions are valid. + + Args: + chosen_nodes: The chosen nodes. + + Returns: + The value of the matching dimension for a chosen node. + """ + chosen_node_name = chosen_nodes.get(self.choice_name, None) + + if chosen_node_name is None: + raise MatchChosenDimensionsError(self.choice_name, chosen_node_name=None) + + try: + return self.choices[chosen_node_name] + except KeyError as e: + raise MatchChosenDimensionsError(self.choice_name, chosen_node_name) from e + + @staticmethod + def collect_chosen_nodes_names(pipeline: Node) -> dict[str, str]: + """Collects the names of chosen nodes in the pipeline. + + Each pipeline has a unique set of chosen nodes, which we collect separately + to handle dimension matching between layers with search spaces. + + Args: + pipeline: The pipeline containing the model architecture. + + Returns: + The names of the chosen nodes in the pipeline. + """ + chosen_nodes_names = {} # Class variable to store chosen node names + + for node in pipeline.iter(): + if isinstance(node, Choice): + chosen_node = node.chosen() + if chosen_node: + chosen_nodes_names[node.name] = chosen_node.name + + return chosen_nodes_names + + +def build_model_from_pipeline(pipeline: Node, /) -> nn.Module: + """Builds a model from the provided pipeline. + + This function iterates through the pipeline nodes, constructing the model + layers dynamically based on the node types and configurations. It also + utilizes the `MatchDimensions` and `MatchChosenDimensions` objects to handle + dimension matching between layers with search spaces. + + Args: + pipeline: The pipeline containing the model architecture. + + Returns: + The constructed PyTorch model. + """ + model_layers = [] + + # Mapping of choice node names to what was chosen for that choice + chosen_nodes_names = MatchChosenDimensions.collect_chosen_nodes_names(pipeline) + + # NOTE: pipeline.iter() may not be sufficient as we relying on some implied ordering + # for this to work, i.e. we might not know when we're iterating through nodes of a + # Join or Split + for node in pipeline.iter(skip_unchosen=True): + match node: + case Component(config=config): + layer_config = dict(config) if config else {} + + for key, instance in layer_config.items(): + match instance: + case MatchDimensions(): + layer_config[key] = instance.evaluate(pipeline) + case MatchChosenDimensions(): + layer_config[key] = instance.evaluate(chosen_nodes_names) + case _: + # Just used the value directly + pass + + layer = node.build_item(**layer_config) + model_layers.append(layer) + # Check if node is a Fixed layer (e.g., Flatten, ReLU), + # Flatten layer or any other layer without config parameter + case Node(item=built_object) if built_object is not None: + model_layers.append(built_object) + case Sequential() | Choice(): + pass # Skip these as it will come up in iteration... 
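> Reviewer note, not part of the patch: the dimension-matching classes above are easier to picture with a concrete (plain PyTorch, no AMLTK) sketch of what resolving a `MatchDimensions`-style placeholder amounts to — a layer's `in_features` is looked up from another layer's sampled config before the `nn.Module` is built. The `configs` dict and the tuple placeholder are illustrative stand-ins, not the AMLTK API.

```python
import torch
from torch import nn

# Hypothetical per-layer configs, as they might look after HPO sampling.
configs = {
    "fc1": {"in_features": 784, "out_features": 128},
    # ("fc1", "out_features") plays the role of
    # MatchDimensions("fc1", param="out_features") in a real pipeline.
    "fc2": {"in_features": ("fc1", "out_features"), "out_features": 10},
}

def resolve(value, configs):
    """Resolve a (layer_name, param) placeholder against the other configs."""
    if isinstance(value, tuple):
        layer_name, param = value
        return configs[layer_name][param]
    return value

# Build each layer only after its placeholders have been resolved.
layers = []
for cfg in configs.values():
    resolved = {k: resolve(v, configs) for k, v in cfg.items()}
    layers.append(nn.Linear(**resolved))

model = nn.Sequential(*layers)
print(model(torch.randn(1, 784)).shape)  # torch.Size([1, 10])
```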
+ case _: + # TODO: Support other node types + raise NotImplementedError(f"Node type {type(node)} not supported yet.") + + return nn.Sequential(*model_layers) diff --git a/src/amltk/scheduling/events.py b/src/amltk/scheduling/events.py index 516d7616..13876e65 100644 --- a/src/amltk/scheduling/events.py +++ b/src/amltk/scheduling/events.py @@ -1,261 +1,26 @@ -"""One of the primary ways to respond to `@events` emitted -with by a [`Task`][amltk.scheduling.Task] -the [`Scheduler`][amltk.scheduling.Scheduler] -is through use of a **callback**. - -The reason for this is to enable an easier time for API's to utilize -multiprocessing and remote compute from the `Scheduler`, without having -to burden users with knowing the details of how to use multiprocessing. - -A callback subscribes to some event using a decorator but can also be done in -a functional style if preferred. The below example is based on the -event [`@scheduler.on_start`][amltk.scheduling.Scheduler.on_start] but -the same applies to all events. - -=== "Decorators" - - ```python exec="true" source="material-block" html="true" - from amltk.scheduling import Scheduler - - scheduler = Scheduler.with_processes(1) - - @scheduler.on_start - def print_hello() -> None: - print("hello") - - scheduler.run() - from amltk._doc import doc_print; doc_print(print, scheduler, fontsize="small") # markdown-exec: hide - ``` - -=== "Functional" - - ```python exec="true" source="material-block" html="true" - from amltk.scheduling import Scheduler - - scheduler = Scheduler.with_processes(1) - - def print_hello() -> None: - print("hello") - - scheduler.on_start(print_hello) - scheduler.run() - from amltk._doc import doc_print; doc_print(print, scheduler, fontsize="small") # markdown-exec: hide - ``` - -There are a number of ways to customize the behaviour of these callbacks, notably -to control how often they get called and when they get called. - -??? tip "Callback customization" - - - === "`on('event', repeat=...)`" - - This will cause the callback to be called `repeat` times successively. - This is most useful in combination with - [`@scheduler.on_start`][amltk.scheduling.Scheduler.on_start] to launch - a number of tasks at the start of the scheduler. - - ```python exec="true" source="material-block" html="true" hl_lines="11" - from amltk import Scheduler - - N_WORKERS = 2 - - def f(x: int) -> int: - return x * 2 - from amltk._doc import make_picklable; make_picklable(f) # markdown-exec: hide - - scheduler = Scheduler.with_processes(N_WORKERS) - task = scheduler.task(f) - - @scheduler.on_start(repeat=N_WORKERS) - def on_start(): - task.submit(1) - - scheduler.run() - from amltk._doc import doc_print; doc_print(print, scheduler, fontsize="small") # markdown-exec: hide - ``` - - === "`on('event', max_calls=...)`" - - Limit the number of times a callback can be called, after which, the callback - will be ignored. 
- - ```python exec="true" source="material-block" html="True" hl_lines="13" - from asyncio import Future - from amltk.scheduling import Scheduler - - scheduler = Scheduler.with_processes(2) - - def expensive_function(x: int) -> int: - return x ** 2 - from amltk._doc import make_picklable; make_picklable(expensive_function) # markdown-exec: hide - - @scheduler.on_start - def submit_calculations() -> None: - scheduler.submit(expensive_function, 2) - - @scheduler.on_future_result(max_calls=3) - def print_result(future, result) -> None: - scheduler.submit(expensive_function, 2) - - scheduler.run() - from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide - ``` - - === "`on('event', when=...)`" - - A callable which takes no arguments and returns a `bool`. The callback - will only be called when the `when` callable returns `True`. - - Below is a rather contrived example, but it shows how we can use the - `when` parameter to control when the callback is called. - - ```python exec="true" source="material-block" html="True" hl_lines="8 12" - import random - from amltk.scheduling import Scheduler - - LOCALE = random.choice(["English", "German"]) - - scheduler = Scheduler.with_processes(1) - - @scheduler.on_start(when=lambda: LOCALE == "English") - def print_hello() -> None: - print("hello") - - @scheduler.on_start(when=lambda: LOCALE == "German") - def print_guten_tag() -> None: - print("guten tag") - - scheduler.run() - from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide - ``` - - === "`on('event', every=...)`" - - Only call the callback every `every` times the event is emitted. This - includes the first time it's called. - - ```python exec="true" source="material-block" html="True" hl_lines="6" - from amltk.scheduling import Scheduler - - scheduler = Scheduler.with_processes(1) - - # Print "hello" only every 2 times the scheduler starts. - @scheduler.on_start(every=2) - def print_hello() -> None: - print("hello") - - # Run the scheduler 5 times - scheduler.run() - scheduler.run() - scheduler.run() - scheduler.run() - scheduler.run() - from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide - ``` - -### Emitter, Subscribers and Events -This part of the documentation is not necessary to understand or use for AMLTK. People -wishing to build tools upon AMLTK may still find this a useful component to add to their -arsenal. - -The core of making this functionality work is the [`Emitter`][amltk.scheduling.events.Emitter]. -Its purpose is to have `@events` that can be emitted and subscribed to. Classes like the -[`Scheduler`][amltk.scheduling.Scheduler] and [`Task`][amltk.scheduling.Task] carry -around with them an `Emitter` to enable all of this functionality. - -Creating an `Emitter` is rather straight-forward, but we must also create -[`Events`][amltk.scheduling.events.Event] that people can subscribe to. - -```python -from amltk.scheduling import Emitter, Event -emitter = Emitter("my-emitter") - -event: Event[int] = Event("my-event") # (1)! - -@emitter.on(event) -def my_callback(x: int) -> None: - print(f"Got {x}!") - -emitter.emit(event, 42) # (2)! -``` - -1. The typing `#!python Event[int]` is used to indicate that the event will be emitting - an integer. This is not necessary, but it is useful for type-checking and - documentation. -2. The `#!python emitter.emit(event, 42)` is used to emit the event. 
This will call - all the callbacks registered for the event, i.e. `#!python my_callback()`. - -!!! warning "Independent Events" - - Given a single `Emitter` and a single instance of an `Event`, there is no way to - have different `@events` for callbacks. There are two options, both used extensively - in AMLTK. - - The first is to have different `Events` quite naturally, i.e. you distinguish - between different things that can happen. However, you often want to have different - objects emit the same `Event` but have different callbacks for each object. - - This makes most sense in the context of a `Task` the `Event` instances are shared as - class variables in the `Task` class, however a user likely want's to subscribe to - the `Event` for a specific instance of the `Task`. - - This is where the second option comes in, in which each object carries around its - own `Emitter` instance. This is how a user can subscribe to the same kind of `Event` - but individually for each `Task`. - - -However, to shield users from this and to create named access points for users to -subscribe to, we can use the [`Subscriber`][amltk.scheduling.events.Subscriber] class, -conveniently created by the [`Emitter.subscriber()`][amltk.scheduling.events.Emitter.subscriber] -method. - -```python -from amltk.scheduling import Emitter, Event -emitter = Emitter("my-emitter") - -class GPT: - - event: Event[str] = Event("my-event") - - def __init__(self) -> None: - self.on_answer: Subscriber[str] = emitter.subscriber(self.event) - - def ask(self, question: str) -> None: - emitter.emit(self.event, "hello world!") - -gpt = GPT() - -@gpt.on_answer -def print_answer(answer: str) -> None: - print(answer) - -gpt.ask("What is the conical way for an AI to greet someone?") -``` - -Typically these event based systems make little sense in a synchronous context, however -with the [`Scheduler`][amltk.scheduling.Scheduler] and [`Task`][amltk.scheduling.Task] -classes, they are used to enable a simple way to use multiprocessing and remote compute. -""" # noqa: E501 +"""THe event system in AMLTK.""" from __future__ import annotations import logging import math import time from collections import Counter, defaultdict -from collections.abc import Callable, Iterable, Iterator, Mapping +from collections.abc import Callable, Iterable from dataclasses import dataclass, field from functools import partial -from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload +from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, overload from typing_extensions import ParamSpec, override -from uuid import uuid4 -from amltk._functional import callstring, funcname +from more_itertools import first_true + +from amltk._functional import funcname from amltk._richutil.renderers.function import Function +from amltk.exceptions import EventNotKnownError +from amltk.randomness import randuid if TYPE_CHECKING: + from rich.console import RenderableType from rich.text import Text - from rich.tree import Tree T = TypeVar("T") P = ParamSpec("P") @@ -273,12 +38,12 @@ def execute( cls, events: Iterable[ tuple[ - Iterable[Handler[P]], + Iterable[Handler[P, R]], tuple[Any, ...] 
| None, dict[str, Any] | None, ] ], - ) -> None: + ) -> list[tuple[Handler[P, R], R | None]]: """Call all events in the scheduler.""" all_handlers = [] for handlers, args, kwargs in events: @@ -287,12 +52,14 @@ def execute( ] sorted_handlers = sorted(all_handlers, key=lambda item: item[0].registered_at) - for handler, args, kwargs in sorted_handlers: - handler(*args, **kwargs) + return [ + (handler, handler(*args, **kwargs)) + for handler, args, kwargs in sorted_handlers + ] @dataclass(frozen=True) -class Event(Generic[P]): +class Event(Generic[P, R]): """An event that can be emitted. Attributes: @@ -321,7 +88,7 @@ def __rich__(self) -> Text: @dataclass -class Subscriber(Generic[P]): +class Subscriber(Generic[P, R]): """An object that can be used to easily subscribe to a certain event. ```python @@ -353,17 +120,46 @@ def callback(a: int, b: str) -> None: """ emitter: Emitter - event: Event[P] - when: Callable[[], bool] | None = None - max_calls: int | None = None - repeat: int = 1 - every: int = 1 + event: Event[P, R] @property def event_counts(self) -> int: """The number of times this event has been emitted.""" return self.emitter.event_counts[self.event] + def register( + self, + callback: Callable[P, R], + *, + when: Callable[[], bool] | None = None, + max_calls: int | None = None, + repeat: int = 1, + every: int = 1, + hidden: bool = False, + ) -> None: + """Register a callback for this subscriber. + + Args: + callback: The callback to register. + when: A predicate that must be satisfied for the callback to be called. + every: The callback will be called every `every` times the event is emitted. + repeat: The callback will be called `repeat` times successively. + max_calls: The maximum number of times the callback can be called. + hidden: Whether to hide the callback in visual output. + This is mainly used to facilitate Plugins who + act upon events but don't want to be seen, primarily + as they are just book-keeping callbacks. + """ + self.emitter.register( + event=self.event, + callback=callback, + when=when, + max_calls=max_calls, + repeat=repeat, + every=every, + hidden=hidden, + ) + @overload def __call__( self, @@ -373,36 +169,41 @@ def __call__( max_calls: int | None = ..., repeat: int = ..., every: int = ..., - ) -> partial[Callable[P, Any]]: + hidden: bool = ..., + ) -> Callable[[Callable[P, R]], None]: ... @overload def __call__( self, - callback: Callable[P, Any], + callback: Callable[P, R], *, when: Callable[[], bool] | None = ..., max_calls: int | None = ..., repeat: int = ..., every: int = ..., hidden: bool = ..., - ) -> Callable[P, Any]: + ) -> None: ... def __call__( self, - callback: Callable[P, Any] | None = None, + callback: Callable[P, R] | None = None, *, when: Callable[[], bool] | None = None, max_calls: int | None = None, repeat: int = 1, every: int = 1, hidden: bool = False, - ) -> Callable[P, Any] | partial[Callable[P, Any]]: - """Subscribe to the event associated with this object. + ) -> Callable[[Callable[P, R]], None] | None: + """A decorator to register a callback for this subscriber. Args: - callback: The callback to register. + callback: The callback to register. If `None`, then this + acts as a decorator, as you would normally use it. Prefer + to leave this as `None` and use + [`register()`][amltk.scheduling.events.Subscriber.register] if + you have a direct reference to the function and are not decorating it. when: A predicate that must be satisfied for the callback to be called. 
every: The callback will be called every `every` times the event is emitted. repeat: The callback will be called `repeat` times successively. @@ -411,22 +212,17 @@ def __call__( This is mainly used to facilitate Plugins who act upon events but don't want to be seen, primarily as they are just book-keeping callbacks. - - Returns: - The callback if it was provided, otherwise it acts - as a decorator. """ if callback is None: return partial( - self.__call__, + self.register, when=when, max_calls=max_calls, repeat=repeat, every=every, - ) # type: ignore - - self.emitter.on( - self.event, + hidden=hidden, + ) + self.register( callback, when=when, max_calls=max_calls, @@ -434,22 +230,26 @@ def __call__( every=every, hidden=hidden, ) - return callback + return None - def emit(self, *args: P.args, **kwargs: P.kwargs) -> None: + def emit( + self, + *args: P.args, + **kwargs: P.kwargs, + ) -> list[tuple[Handler[P, R], R | None]]: """Emit this subscribers event.""" - self.emitter.emit(self.event, *args, **kwargs) + return self.emitter.emit(self.event, *args, **kwargs) @dataclass -class Handler(Generic[P]): +class Handler(Generic[P, R]): """A handler for an event. This is a simple class that holds a callback and any predicate that must be satisfied for it to be triggered. """ - callback: Callable[P, Any] + callback: Callable[P, R] when: Callable[[], bool] | None = None every: int = 1 n_calls_to_handler: int = 0 @@ -459,26 +259,39 @@ class Handler(Generic[P]): registered_at: int = field(default_factory=time.time_ns) hidden: bool = False - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R | None: """Call the callback if the predicate is satisfied. If the predicate is not satisfied, then `None` is returned. """ self.n_calls_to_handler += 1 if self.every > 1 and self.n_calls_to_handler % self.every != 0: - return + return None if self.when is not None and not self.when(): - return + return None max_calls = self.max_calls if self.max_calls is not None else math.inf - for _ in range(self.repeat): + + if self.repeat == 1: if self.n_calls_to_callback >= max_calls: - return + return None - logger.debug(f"Calling: {callstring(self.callback)}") - self.callback(*args, **kwargs) self.n_calls_to_callback += 1 + return self.callback(*args, **kwargs) + + if self.n_calls_to_callback >= max_calls: + return None + + responses = iter(self.callback(*args, **kwargs) for _ in range(self.repeat)) + self.n_calls_to_callback += 1 + first_response = next(responses) + if first_response is not None: + raise ValueError("A callback with a response cannot have `repeat` > 1.") + + # Otherwise just exhaust the iterator + list(responses) + return None def __rich__(self) -> Text: from rich.text import Text @@ -495,7 +308,7 @@ def __rich__(self) -> Text: ) -class Emitter(Mapping[Event, list[Handler]]): +class Emitter: """An event emitter. This class is used to emit events and register callbacks for those events. @@ -505,6 +318,9 @@ class Emitter(Mapping[Event, list[Handler]]): to directly subscribe to their [`Events`][amltk.scheduling.events.Event]. """ + HandlerResponses: TypeAlias = Iterable[tuple[Handler[P, R], R | None]] + """The stream of responses from handlers when an event is triggered.""" + name: str | None """The name of the emitter.""" @@ -522,14 +338,19 @@ def __init__(self, name: str | None = None) -> None: will be used. 
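> Reviewer note, not part of the patch: since the callback-registration examples in the old module docstring were dropped above, here is a small sketch of the reworked API pieced together from the signatures in this diff — `register()` attaches a callback directly, `emit()` now returns the `(handler, result)` pairs, and events can also be resolved by name via `as_event()` (added further down). The import path `from amltk.scheduling import Emitter, Event` is assumed to be unchanged from the old docstring.

```python
from amltk.scheduling import Emitter, Event

emitter = Emitter("my-emitter")
answered = Event("answered")
on_answer = emitter.subscriber(answered)

def double(x: int) -> int:
    return x * 2

# Direct registration instead of the decorator form; limited to two calls.
on_answer.register(double, max_calls=2)

# emit() now hands back what each handler returned.
responses = on_answer.emit(21)
print([result for _, result in responses])  # [42]

# Already-known events can be looked up by name, resolved through as_event().
assert emitter.as_event("answered") is answered
```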
""" super().__init__() - self.unique_ref = f"{name}-{uuid4()}" + self.unique_ref = f"{name}-{randuid()}" self.emitted_events: set[Event] = set() self.name = name self.handlers = defaultdict(list) self.event_counts = Counter() - def emit(self, event: Event[P], *args: Any, **kwargs: Any) -> None: + def emit( + self, + event: Event[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> list[tuple[Handler[P, R], R | None]]: """Emit an event. This will call all the handlers for the event. @@ -547,64 +368,61 @@ def emit(self, event: Event[P], *args: Any, **kwargs: Any) -> None: logger.debug(f"{self.name}: Emitting {event}") self.event_counts[event] += 1 - for handler in self.handlers[event]: - handler(*args, **kwargs) + return [(handler, handler(*args, **kwargs)) for handler in self.handlers[event]] - def emit_many( - self, - events: dict[Event, tuple[tuple[Any, ...] | None, dict[str, Any] | None]], - ) -> None: - """Emit multiple events. + @property + def events(self) -> list[Event]: + """Return a list of the events.""" + return list(self.handlers.keys()) - This is useful for cases where you don't want to favour one callback - over another, and so uses the time a callback was registered to call - the callback instead. + def subscriber(self, event: Event[P, R]) -> Subscriber[P, R]: + """Create a subscriber for an event. Args: - events: A mapping of event keys to a tuple of positional - arguments and keyword arguments to pass to the handlers. + event: The event to register the callback for. """ - for event in events: - self.event_counts[event] += 1 - - items = [ - (handlers, args, kwargs) - for event, (args, kwargs) in events.items() - if (handlers := self.get(event)) is not None - ] - - header = f"{self.name}: Emitting many events" - logger.debug(header) - logger.debug(",".join(str(event) for event in events)) - RegisteredTimeCallOrderStrategy.execute(items) - - @override - def __getitem__(self, event: Event) -> list[Handler]: - return self.handlers[event] + if event not in self.handlers: + self.handlers[event] = [] - @override - def __iter__(self) -> Iterator[Event]: - return iter(self.handlers) + return Subscriber(self, event) - @override - def __len__(self) -> int: - return len(self.handlers) + @overload + def on( + self, + event: str, + *, + when: Callable[[], bool] | None = ..., + every: int = ..., + repeat: int = ..., + max_calls: int | None = ..., + hidden: bool = ..., + ) -> Callable[[Callable[..., Any | None]], None]: + ... - @property - def events(self) -> list[Event]: - """Return a list of the events.""" - return list(self.handlers.keys()) + @overload + def on( + self, + event: Event[P, R], + *, + when: Callable[[], bool] | None = ..., + every: int = ..., + repeat: int = ..., + max_calls: int | None = ..., + hidden: bool = ..., + ) -> Callable[[Callable[P, R | None]], None]: + ... - def subscriber( + def on( self, - event: Event[P], + event: Event[P, R] | str, *, when: Callable[[], bool] | None = None, every: int = 1, repeat: int = 1, max_calls: int | None = None, - ) -> Subscriber[P]: - """Create a subscriber for an event. + hidden: bool = False, + ) -> Callable[[Callable[P, R | None]], None]: + """Register a callback for an event as a decorator. Args: event: The event to register the callback for. @@ -612,23 +430,24 @@ def subscriber( every: The callback will be called every `every` times the event is emitted. repeat: The callback will be called `repeat` times successively. max_calls: The maximum number of times the callback can be called. 
+ hidden: Whether to hide the callback in visual output. + This is mainly used to facilitate Plugins who + act upon events but don't want to be seen, primarily + as they are just book-keeping callbacks. """ - if event not in self.handlers: - self.handlers[event] = [] - - return Subscriber( - self, - event, # type: ignore + return partial( + self.subscriber(event=self.as_event(event)), # type: ignore when=when, every=every, repeat=repeat, max_calls=max_calls, + hidden=hidden, ) - def on( + def register( self, - event: Event[P], - callback: Callable, + event: Event[P, R] | str, + callback: Callable[P, R], *, when: Callable[[], bool] | None = None, every: int = 1, @@ -650,6 +469,8 @@ def on( act upon events but don't want to be seen, primarily as they are just book-keeping callbacks. """ + event = self.as_event(event) + if repeat <= 0: raise ValueError(f"{repeat=} must be a positive integer.") @@ -659,9 +480,6 @@ def on( # Make sure it shows up in the event counts, setting it to 0 if it # doesn't exist self.event_counts.setdefault(event, 0) - - # This hackery is just to get down to a flat list of events that need - # to be set up self.handlers[event].append( Handler( callback, @@ -695,7 +513,7 @@ def add_event(self, *event: Event) -> None: if e not in self.handlers: self.handlers[e] = [] - def __rich__(self) -> Tree: + def __rich__(self) -> RenderableType: from rich.tree import Tree tree = Tree(self.name or "", hide_root=self.name is None) @@ -718,3 +536,27 @@ def __rich__(self) -> Tree: event_tree.add(handler) return tree + + @overload + def as_event(self, key: str) -> Event: + ... + + @overload + def as_event(self, key: Event[P, R]) -> Event[P, R]: + ... + + def as_event(self, key: str | Event) -> Event: + """Return the event associated with the key.""" + match key: + case Event(): + return key + case str(): + match = first_true(self.events, None, lambda e: e.name == key) + if match is None: + raise EventNotKnownError( + f"{key=} is not a valid event for {self.name}." + f"\nKnown events are: {[e.name for e in self.events]}", + ) + return match + case _: + raise TypeError(f"{key=} must be a string or an Event.") diff --git a/src/amltk/scheduling/executors/dask_jobqueue.py b/src/amltk/scheduling/executors/dask_jobqueue.py index 9652c422..ff2524e4 100644 --- a/src/amltk/scheduling/executors/dask_jobqueue.py +++ b/src/amltk/scheduling/executors/dask_jobqueue.py @@ -99,7 +99,8 @@ def __init__( self.cluster.scale(n_workers) self.n_workers = n_workers - self.executor: ClientExecutor = self.cluster.get_client().get_executor() + self._client = self.cluster.get_client() + self.executor: ClientExecutor = self._client.get_executor() @override def __enter__(self) -> Self: @@ -116,6 +117,7 @@ def __enter__(self) -> Self: @override def __exit__(self, *args: Any, **kwargs: Any) -> None: self.executor.__exit__(*args, **kwargs) + self._client.close() @override def submit( diff --git a/src/amltk/scheduling/plugins/comm.py b/src/amltk/scheduling/plugins/comm.py index f1b5b748..b1ad73a3 100644 --- a/src/amltk/scheduling/plugins/comm.py +++ b/src/amltk/scheduling/plugins/comm.py @@ -8,13 +8,14 @@ ??? tip "Usage" To setup a `Task` to work with a `Comm`, the `Task` **must accept a `comm` as - it's first argument**. + a keyword argument**. This is to prevent it conflicting with any args passed + through during the call to `submit()`. 
```python exec="true" source="material-block" result="python" hl_lines="4-7 10 17-19 21-23" from amltk.scheduling import Scheduler from amltk.scheduling.plugins import Comm - def powers_of_two(comm: Comm, start: int, n: int) -> None: + def powers_of_two(start: int, n: int, *, comm: Comm) -> None: with comm.open(): for i in range(n): comm.send(start ** (i+1)) @@ -208,14 +209,15 @@ class Comm: id: The id of the comm. """ - MESSAGE: Event[Comm.Msg] = Event("comm-message") + MESSAGE: Event[[Comm.Msg], Any] = Event("comm-message") """A Task has sent a message to the main process. ```python exec="true" source="material-block" html="true" hl_lines="6 11-13" from amltk.scheduling import Scheduler from amltk.scheduling.plugins import Comm - def fn(comm: Comm, x: int) -> int: + def fn(x: int, comm: Comm | None = None) -> int: + assert comm is not None with comm.open(): comm.send(x + 1) @@ -229,14 +231,15 @@ def callback(msg: Comm.Msg): ``` """ - REQUEST: Event[Comm.Msg] = Event("comm-request") + REQUEST: Event[[Comm.Msg], Any] = Event("comm-request") """A Task has sent a request. ```python exec="true" source="material-block" html="true" hl_lines="6 16-18" from amltk.scheduling import Scheduler from amltk.scheduling.plugins import Comm - def greeter(comm: Comm, greeting: str) -> None: + def greeter(greeting: str, comm: Comm | None = None) -> None: + assert comm is not None with comm.open(): name = comm.request() comm.send(f"{greeting} {name}!") @@ -262,7 +265,7 @@ def on_msg(msg: Comm.Msg): ``` """ # noqa: E501 - OPEN: Event[Comm.Msg] = Event("comm-open") + OPEN: Event[[Comm.Msg], Any] = Event("comm-open") """The task has signalled it's open. ```python exec="true" source="material-block" html="true" hl_lines="5 15-17" @@ -290,7 +293,7 @@ def callback(msg: Comm.Msg): ``` """ - CLOSE: Event[Comm.Msg] = Event("comm-close") + CLOSE: Event[[Comm.Msg], Any] = Event("comm-close") """The task has signalled it's close. ```python exec="true" source="material-block" html="true" hl_lines="7 17-19" @@ -312,7 +315,7 @@ def on_start(): task.submit() @task.on("comm-close") - def on_close(msg: Comm.msg): + def on_close(msg: Comm.Msg): print(f"Worker close with {msg}") scheduler.run() @@ -487,11 +490,13 @@ class Plugin(TaskPlugin): def __init__( self, + parameter_name: str = "comm", create_comms: Callable[[], tuple[Comm, Comm]] | None = None, ) -> None: """Initialize the plugin. Args: + parameter_name: The name of the parameter to inject the comm into. create_comms: A function that creates a pair of communication channels. Defaults to `Comm.create`. """ @@ -499,6 +504,7 @@ def __init__( if create_comms is None: create_comms = Comm.create + self.parameter_name = parameter_name self.create_comms = create_comms self.comms: dict[CommID, tuple[Comm, Comm]] = {} self.communication_tasks: list[asyncio.Task] = [] @@ -517,7 +523,7 @@ def attach_task(self, task: Task) -> None: task: The task the plugin is being attached to. """ self.task = task - task.emitter.add_event(Comm.MESSAGE, Comm.REQUEST, Comm.OPEN, Comm.CLOSE) + task.add_event(Comm.MESSAGE, Comm.REQUEST, Comm.OPEN, Comm.CLOSE) task.on_submitted(self._begin_listening, hidden=True) @override @@ -542,28 +548,28 @@ def pre_submit( not be submitted. """ host_comm, worker_comm = self.create_comms() + if self.parameter_name in kwargs: + raise ValueError( + f"Parameter {self.parameter_name} already exists in kwargs!", + ) + + kwargs[self.parameter_name] = worker_comm # We don't necessarily know if the future will be submitted. 
If so, # we will use this index later to retrieve the host_comm self.comms[worker_comm.id] = (host_comm, worker_comm) # Make sure to include the Comm - return fn, (worker_comm, *args), kwargs - - @override - def copy(self) -> Self: - """Return a copy of the plugin. - - Please see [`Plugin.copy()`][amltk.scheduling.Plugin.copy]. - """ - return self.__class__(create_comms=self.create_comms) - - def _begin_listening(self, f: asyncio.Future, *args: Any, **_: Any) -> Any: - match args: - case (worker_comm, *_) if isinstance(worker_comm, Comm): - worker_comm = args[0] - case _: - raise ValueError(f"Expected first arg to be a Comm, got {args[0]}") + return fn, args, kwargs + + def _begin_listening(self, f: asyncio.Future, *args: Any, **kwargs: Any) -> Any: + worker_comm = kwargs.get(self.parameter_name) + if worker_comm is None: + raise ValueError( + f"Expected Comm in `{self.parameter_name}` but it didn't exist." + f" This is likely a bug in the plugin." + f"\nargs: {args} kwargs: {kwargs}", + ) host_comm, worker_comm = self.comms[worker_comm.id] @@ -582,7 +588,11 @@ def _deregister_comm_coroutine(self, coroutine: asyncio.Task) -> None: else: logger.warning(f"Communication coroutine {coroutine} not found!") - if (exception := coroutine.exception()) is not None: + if coroutine.cancelled(): + logger.debug( + f"Coroutine {coroutine} was cancelled. Not treated as an error.", + ) + elif (exception := coroutine.exception()) is not None: raise exception async def _communicate( @@ -627,7 +637,7 @@ async def _communicate( future=future, task=self.task, ) - self.task.emitter.emit(event, msg) + self.task.emit(event, msg) except EOFError: # This means the connection dropped to the worker, however this is not diff --git a/src/amltk/scheduling/plugins/emissions_tracker_plugin.py b/src/amltk/scheduling/plugins/emissions_tracker_plugin.py index 82601ffc..2e2617a4 100644 --- a/src/amltk/scheduling/plugins/emissions_tracker_plugin.py +++ b/src/amltk/scheduling/plugins/emissions_tracker_plugin.py @@ -8,7 +8,7 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar -from typing_extensions import ParamSpec, Self +from typing_extensions import ParamSpec from codecarbon import EmissionsTracker @@ -128,10 +128,6 @@ def pre_submit( ) return wrapped_f, args, kwargs - def copy(self) -> Self: - """Return a copy of the plugin.""" - return self.__class__(*self.codecarbon_args, **self.codecarbon_kwargs) - def __rich__(self) -> Panel: """Return a rich panel.""" from rich.panel import Panel diff --git a/src/amltk/scheduling/plugins/limiter.py b/src/amltk/scheduling/plugins/limiter.py index c3abcd52..4c7a3085 100644 --- a/src/amltk/scheduling/plugins/limiter.py +++ b/src/amltk/scheduling/plugins/limiter.py @@ -8,7 +8,7 @@ ??? 
tip "Usage" ```python exec="true" source="material-block" html="true" - from amltk.scheduling import Scheduler + from amltk.scheduling import Scheduler, Task from amltk.scheduling.plugins import Limiter def fn(x: int) -> int: @@ -42,15 +42,15 @@ def callback(task: Task, *args, **kwargs): from collections.abc import Callable, Iterable from typing import TYPE_CHECKING, Any, ClassVar, TypeVar -from typing_extensions import ParamSpec, Self, override +from typing_extensions import ParamSpec, override from amltk.scheduling.events import Event from amltk.scheduling.plugins.plugin import Plugin +from amltk.scheduling.task import Task if TYPE_CHECKING: from rich.panel import Panel - from amltk.scheduling.task import Task P = ParamSpec("P") R = TypeVar("R") @@ -70,14 +70,14 @@ class Limiter(Plugin): name: ClassVar = "limiter" """The name of the plugin.""" - CALL_LIMIT_REACHED: Event[...] = Event("call-limit-reached") + CALL_LIMIT_REACHED: Event[..., Any] = Event("call-limit-reached") """The event emitted when the task has reached its call limit. Will call any subscribers with the task as the first argument, followed by the arguments and keyword arguments that were passed to the task. ```python exec="true" source="material-block" html="true" - from amltk.scheduling import Scheduler + from amltk.scheduling import Scheduler, Task from amltk.scheduling.plugins import Limiter def fn(x: int) -> int: @@ -94,14 +94,14 @@ def callback(task: Task, *args, **kwargs): ``` """ - CONCURRENT_LIMIT_REACHED: Event[...] = Event("concurrent-limit-reached") + CONCURRENT_LIMIT_REACHED: Event[..., Any] = Event("concurrent-limit-reached") """The event emitted when the task has reached its concurrent call limit. Will call any subscribers with the task as the first argument, followed by the arguments and keyword arguments that were passed to the task. ```python exec="true" source="material-block" html="true" - from amltk.scheduling import Scheduler + from amltk.scheduling import Scheduler, Task from amltk.scheduling.plugins import Limiter def fn(x: int) -> int: @@ -118,7 +118,9 @@ def callback(task: Task, *args, **kwargs): ``` """ - DISABLED_DUE_TO_RUNNING_TASK: Event[...] = Event("disabled-due-to-running-task") + DISABLED_DUE_TO_RUNNING_TASK: Event[..., Any] = Event( + "disabled-due-to-running-task", + ) """The event emitter when the task was not submitted due to some other running task. @@ -126,7 +128,7 @@ def callback(task: Task, *args, **kwargs): the arguments and keyword arguments that were passed to the task. 
```python exec="true" source="material-block" html="true" - from amltk.scheduling import Scheduler + from amltk.scheduling import Scheduler, Task from amltk.scheduling.plugins import Limiter def fn(x: int) -> int: @@ -162,12 +164,13 @@ def __init__( """ super().__init__() - if not_while_running is None: - not_while_running = [] - elif isinstance(not_while_running, Iterable): - not_while_running = list(not_while_running) - else: - not_while_running = [not_while_running] + match not_while_running: + case None: + not_while_running = [] + case Task(): + not_while_running = [not_while_running] + case _: + not_while_running = list(not_while_running) self.max_calls = max_calls self.max_concurrent = max_concurrent @@ -194,7 +197,7 @@ def attach_task(self, task: Task) -> None: " has sufficient use case.", ) - task.emitter.add_event( + task.add_event( self.CALL_LIMIT_REACHED, self.CONCURRENT_LIMIT_REACHED, self.DISABLED_DUE_TO_RUNNING_TASK, @@ -223,21 +226,16 @@ def pre_submit( assert self.task is not None if self.max_calls is not None and self._calls >= self.max_calls: - self.task.emitter.emit(self.CALL_LIMIT_REACHED, self.task, *args, **kwargs) + self.task.emit(self.CALL_LIMIT_REACHED, self.task, *args, **kwargs) return None if self.max_concurrent is not None and self.n_running >= self.max_concurrent: - self.task.emitter.emit( - self.CONCURRENT_LIMIT_REACHED, - self.task, - *args, - **kwargs, - ) + self.task.emit(self.CONCURRENT_LIMIT_REACHED, self.task, *args, **kwargs) return None for other_task in self.not_while_running: if other_task.running(): - self.task.emitter.emit( + self.task.emit( self.DISABLED_DUE_TO_RUNNING_TASK, other_task, self.task, @@ -248,14 +246,6 @@ def pre_submit( return fn, args, kwargs - @override - def copy(self) -> Self: - """Return a copy of the plugin.""" - return self.__class__( - max_calls=self.max_calls, - max_concurrent=self.max_concurrent, - ) - def _increment_call_count(self, *_: Any, **__: Any) -> None: self._calls += 1 diff --git a/src/amltk/scheduling/plugins/plugin.py b/src/amltk/scheduling/plugins/plugin.py index 18d121b7..1a9209e2 100644 --- a/src/amltk/scheduling/plugins/plugin.py +++ b/src/amltk/scheduling/plugins/plugin.py @@ -41,7 +41,7 @@ def attach_task(self, task) -> None: self.task = task # Register an event with the task, this lets the task know valid events # people can subscribe to and helps it show up in visuals - task.emitter.add_event(self.PRINTED) + task.add_event(self.PRINTED) task.on_submitted(self._print_submitted, hidden=True) # You can hide this callback from visuals def pre_submit(self, fn, *args, **kwargs) -> tuple[Callable, tuple, dict]: @@ -51,11 +51,7 @@ def pre_submit(self, fn, *args, **kwargs) -> tuple[Callable, tuple, dict]: def _print_submitted(self, future, *args, **kwargs) -> None: msg = f"Task was submitted {self.task} {args} {kwargs}" - self.task.emitter.emit(self.PRINTED, msg) # Emit the event with a msg - - def copy(self) -> Printer: - # Plugins need to be able to copy themselves as if fresh - return self.__class__(self.greeting) + self.task.emit(self.PRINTED, msg) # Emit the event with a msg def __rich__(self): # Custome how the plugin is displayed in rich (Optional) @@ -102,11 +98,11 @@ def callback(msg: str): from __future__ import annotations import logging -from abc import ABC, abstractmethod +from abc import ABC from collections.abc import Callable from itertools import chain from typing import TYPE_CHECKING, ClassVar, TypeVar -from typing_extensions import ParamSpec, Self, override +from typing_extensions import 
ParamSpec, override from amltk._richutil.renderable import RichRenderable from amltk.scheduling.events import Event @@ -180,17 +176,6 @@ def events(self) -> list[Event]: ) return [attr for attr in inherited_attrs if isinstance(attr, Event)] - @abstractmethod - def copy(self) -> Self: - """Return a copy of the plugin. - - This method is used to create a copy of the plugin when a task is - copied. This is useful if the plugin stores a reference to the task - it is attached to, as the copy will need to store a reference to the - copy of the task. - """ - ... - @override def __rich__(self) -> Panel: from rich.panel import Panel diff --git a/src/amltk/scheduling/plugins/pynisher.py b/src/amltk/scheduling/plugins/pynisher.py index 65cd43ef..2b5fb908 100644 --- a/src/amltk/scheduling/plugins/pynisher.py +++ b/src/amltk/scheduling/plugins/pynisher.py @@ -86,8 +86,8 @@ def callback(exception): from collections.abc import Callable from dataclasses import dataclass from multiprocessing.context import BaseContext -from typing import TYPE_CHECKING, ClassVar, Generic, Literal, TypeAlias, TypeVar -from typing_extensions import ParamSpec, Self, override +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar +from typing_extensions import ParamSpec, override import pynisher import pynisher.exceptions @@ -150,9 +150,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: return fn(*args, **kwargs) except pynisher.PynisherException as e: tb = traceback.format_exc() - trial.exception = e - trial.traceback = tb - return trial.fail() # type: ignore + return trial.fail(e, tb) # type: ignore else: return fn(*args, **kwargs) @@ -180,7 +178,7 @@ class PynisherPlugin(Plugin): name: ClassVar = "pynisher-plugin" """The name of the plugin.""" - TIMEOUT: Event[PynisherPlugin.TimeoutException] = Event("pynisher-timeout") + TIMEOUT: Event[[PynisherPlugin.TimeoutException], Any] = Event("pynisher-timeout") """A Task timed out, either due to the wall time or cpu time limit. Will call any subscribers with the exception as the argument. @@ -204,7 +202,10 @@ def callback(exception): ``` """ - MEMORY_LIMIT_REACHED: Event[pynisher.exceptions.MemoryLimitException] = Event( + MEMORY_LIMIT_REACHED: Event[ + [pynisher.exceptions.MemoryLimitException], + Any, + ] = Event( "pynisher-memory-limit", ) """A Task was submitted but reached it's memory limit. @@ -232,7 +233,10 @@ def callback(exception): ``` """ - CPU_TIME_LIMIT_REACHED: Event[pynisher.exceptions.CpuTimeoutException] = Event( + CPU_TIME_LIMIT_REACHED: Event[ + [pynisher.exceptions.CpuTimeoutException], + Any, + ] = Event( "pynisher-cputime-limit", ) """A Task was submitted but reached it's cpu time limit. @@ -263,7 +267,10 @@ def callback(exception): ``` """ - WALL_TIME_LIMIT_REACHED: Event[pynisher.exceptions.WallTimeoutException] = Event( + WALL_TIME_LIMIT_REACHED: Event[ + [pynisher.exceptions.WallTimeoutException], + Any, + ] = Event( "pynisher-walltime-limit", ) """A Task was submitted but reached it's wall time limit. @@ -369,7 +376,7 @@ def trial_evaluator_two(..., trial: Trial) -> int: ) # Will auto-detect - trial = Trial(...) + trial = Trial.create(...) task_one.submit(trial, ...) 
task_two.submit(..., trial=trial) @@ -378,6 +385,19 @@ def trial_evaluator_two(..., trial: Trial) -> int: ``` """ super().__init__() + + for limit, name in [ + (memory_limit, "memory"), + (cputime_limit, "cpu_time"), + (walltime_limit, "wall_time"), + ]: + if limit is not None and not self.supports(name): # type: ignore + raise RuntimeError( + f"Your platform does not support {name} limits." + " Please see pynisher documentation for more:" + "\nhttps://github.com/automl/pynisher#features", + ) + self.memory_limit = memory_limit self.cputime_limit = cputime_limit self.walltime_limit = walltime_limit @@ -409,7 +429,7 @@ def pre_submit( def attach_task(self, task: Task) -> None: """Attach the plugin to a task.""" self.task = task - task.emitter.add_event( + task.add_event( self.TIMEOUT, self.MEMORY_LIMIT_REACHED, self.CPU_TIME_LIMIT_REACHED, @@ -419,18 +439,6 @@ def attach_task(self, task: Task) -> None: # Check the exception and emit pynisher specific ones too task.on_exception(self._check_to_emit_pynisher_exception, hidden=True) - @override - def copy(self) -> Self: - """Return a copy of the plugin. - - Please see [`Plugin.copy()`][amltk.Plugin.copy]. - """ - return self.__class__( - memory_limit=self.memory_limit, - cputime_limit=self.cputime_limit, - walltime_limit=self.walltime_limit, - ) - def _check_to_emit_pynisher_exception( self, _: asyncio.Future, @@ -438,13 +446,13 @@ def _check_to_emit_pynisher_exception( ) -> None: """Check if the exception is a pynisher exception and emit it.""" if isinstance(exception, pynisher.CpuTimeoutException): - self.task.emitter.emit(self.TIMEOUT, exception) - self.task.emitter.emit(self.CPU_TIME_LIMIT_REACHED, exception) + self.task.emit(self.TIMEOUT, exception) + self.task.emit(self.CPU_TIME_LIMIT_REACHED, exception) elif isinstance(exception, pynisher.WallTimeoutException): - self.task.emitter.emit(self.TIMEOUT) - self.task.emitter.emit(self.WALL_TIME_LIMIT_REACHED, exception) + self.task.emit(self.TIMEOUT, exception) + self.task.emit(self.WALL_TIME_LIMIT_REACHED, exception) elif isinstance(exception, pynisher.MemoryLimitException): - self.task.emitter.emit(self.MEMORY_LIMIT_REACHED, exception) + self.task.emit(self.MEMORY_LIMIT_REACHED, exception) @classmethod def supports(cls, kind: Literal["wall_time", "cpu_time", "memory"]) -> bool: diff --git a/src/amltk/scheduling/plugins/threadpoolctl.py b/src/amltk/scheduling/plugins/threadpoolctl.py index 655293b4..5740417e 100644 --- a/src/amltk/scheduling/plugins/threadpoolctl.py +++ b/src/amltk/scheduling/plugins/threadpoolctl.py @@ -41,7 +41,7 @@ def f() -> None: import logging from collections.abc import Callable from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar -from typing_extensions import ParamSpec, Self, override +from typing_extensions import ParamSpec, override from amltk.scheduling.plugins.plugin import Plugin @@ -137,14 +137,6 @@ def pre_submit( ) return fn, args, kwargs - @override - def copy(self) -> Self: - """Return a copy of the plugin. - - Please see [`Plugin.copy()`][amltk.Plugin.copy]. 
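> Reviewer note, not part of the patch: because the constructor above now raises `RuntimeError` for limits the platform cannot enforce, portable code can probe `PynisherPlugin.supports()` first. A hedged sketch, importing from the module path shown in this diff and using pynisher-style `(amount, unit)` limits:

```python
from amltk.scheduling.plugins.pynisher import PynisherPlugin

kwargs = {}
if PynisherPlugin.supports("memory"):
    kwargs["memory_limit"] = (1, "GB")
if PynisherPlugin.supports("wall_time"):
    kwargs["walltime_limit"] = (30, "s")

# Only limits the platform supports are passed, so construction won't raise.
plugin = PynisherPlugin(**kwargs)
print(plugin.memory_limit, plugin.walltime_limit)
```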
- """ - return self.__class__(max_threads=self.max_threads, user_api=self.user_api) - @override def __rich__(self) -> Panel: from rich.panel import Panel diff --git a/src/amltk/scheduling/plugins/wandb.py b/src/amltk/scheduling/plugins/wandb.py index 205faf59..d1ddb4aa 100644 --- a/src/amltk/scheduling/plugins/wandb.py +++ b/src/amltk/scheduling/plugins/wandb.py @@ -22,7 +22,7 @@ TypeAlias, TypeVar, ) -from typing_extensions import ParamSpec, Self, override +from typing_extensions import ParamSpec, override import numpy as np import wandb @@ -232,11 +232,6 @@ def pre_submit( fn = WandbLiveRunWrap(self.params, fn, modify=self.modify) # type: ignore return fn, args, kwargs - @override - def copy(self) -> Self: - """Copy the plugin.""" - return self.__class__(modify=self.modify, params=replace(self.params)) - def _check_explicit_reinit_arg_with_executor( self, scheduler: Scheduler, diff --git a/src/amltk/scheduling/plugins/warning_filter.py b/src/amltk/scheduling/plugins/warning_filter.py index c998c735..4d950cf0 100644 --- a/src/amltk/scheduling/plugins/warning_filter.py +++ b/src/amltk/scheduling/plugins/warning_filter.py @@ -28,7 +28,7 @@ def f() -> None: import warnings from collections.abc import Callable from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar -from typing_extensions import ParamSpec, Self, override +from typing_extensions import ParamSpec, override from amltk.scheduling.plugins.plugin import Plugin @@ -110,11 +110,6 @@ def pre_submit( wrapped_f = _IgnoreWarningWrapper(fn, *self.warning_args, **self.warning_kwargs) return wrapped_f, args, kwargs - @override - def copy(self) -> Self: - """Return a copy of the plugin.""" - return self.__class__(*self.warning_args, **self.warning_kwargs) - @override def __rich__(self) -> Panel: from rich.panel import Panel diff --git a/src/amltk/scheduling/queue_monitor.py b/src/amltk/scheduling/queue_monitor.py index f7522b7a..d8b09df9 100644 --- a/src/amltk/scheduling/queue_monitor.py +++ b/src/amltk/scheduling/queue_monitor.py @@ -1,68 +1,4 @@ -"""A [`QueueMonitor`][amltk.scheduling.queue_monitor.QueueMonitor] is a -monitor for the scheduler queue. - -This module contains a monitor for the scheduler queue. The monitor tracks the -queue state at every event emitted by the scheduler. The data can be converted -to a pandas DataFrame or plotted as a stacked barchart. - -!!! note "Monitoring Frequency" - - To prevent repeated polling, we sample the scheduler queue at every scheduler event. - This is because the queue is only modified upon one of these events. This means we - don't need to poll the queue at a fixed interval. However, if you need more fine - grained updates, you can add extra events/timings at which the monitor should - [`update()`][amltk.scheduling.queue_monitor.QueueMonitor.update]. - -!!! warning "Performance impact" - - If your tasks and callbacks are very fast (~sub 10ms), then the monitor has a - non-nelgible impact however for most use cases, this should not be a problem. - As anything, you should profile how much work the scheduler can get done, - with and without the monitor, to see if it is a problem for your use case. - -In the below example, we have a very fast running function that runs on repeat, -sometimes too fast for the scheduler to keep up, letting some futures buildup needing -to be processed. 
- -```python exec="true" source="material-block" result="python" session="queue-monitor" -import time -import matplotlib.pyplot as plt -from amltk.scheduling import Scheduler -from amltk.scheduling.queue_monitor import QueueMonitor - -def fast_function(x: int) -> int: - return x + 1 -from amltk._doc import make_picklable; make_picklable(fast_function) # markdown-exec: hide - -N_WORKERS = 2 -scheduler = Scheduler.with_processes(N_WORKERS) -monitor = QueueMonitor(scheduler) -task = scheduler.task(fast_function) - -@scheduler.on_start(repeat=N_WORKERS) -def start(): - task.submit(1) - -@task.on_result -def result(_, x: int): - if scheduler.running(): - task.submit(x) - -scheduler.run(timeout=1) -df = monitor.df() -print(df) -``` - -We can also [`plot()`][amltk.scheduling.queue_monitor.QueueMonitor.plot] the data as a -stacked barchart with a set interval. - -```python exec="true" source="material-block" html="true" session="queue-monitor" -fig, ax = plt.subplots() -monitor.plot(interval=(50, "ms")) -from io import StringIO; fig.tight_layout(); buffer = StringIO(); plt.savefig(buffer, format="svg"); print(buffer.getvalue()) # markdown-exec: hide -``` - -""" # noqa: E501 +"""The queue monitoring.""" from __future__ import annotations import time diff --git a/src/amltk/scheduling/scheduler.py b/src/amltk/scheduling/scheduler.py index 02bda827..f684c4d9 100644 --- a/src/amltk/scheduling/scheduler.py +++ b/src/amltk/scheduling/scheduler.py @@ -1,253 +1,4 @@ -"""The [`Scheduler`][amltk.scheduling.Scheduler] uses -an [`Executor`][concurrent.futures.Executor], a builtin python native with -a `#!python submit(f, *args, **kwargs)` function to submit compute to -be compute else where, whether it be locally or remotely. - -The `Scheduler` is primarily used to dispatch compute to an `Executor` and -emit `@events`, which can trigger user callbacks. - -Typically you should not use the `Scheduler` directly for dispatching and -responding to computed functions, but rather use a [`Task`][amltk.scheduling.Task] - -??? note "Running in a Jupyter Notebook/Colab" - - If you are using a Jupyter Notebook, you likley need to use the following - at the top of your notebook: - - ```python - import nest_asyncio # Only necessary in Notebooks - nest_asyncio.apply() - - scheduler.run(...) - ``` - - This is due to the fact a notebook runs in an async context. If you do not - wish to use the above snippet, you can instead use: - - ```python - await scheduler.async_run(...) - ``` - -??? tip "Basic Usage" - - In this example, we create a scheduler that uses local processes as - workers. We then create a task that will run a function `fn` and submit it - to the scheduler. Lastly, a callback is registered to `@on_future_result` to print the - result when the compute is done. - - ```python exec="true" source="material-block" html="true" - from amltk.scheduling import Scheduler - - def fn(x: int) -> int: - return x + 1 - from amltk._doc import make_picklable; make_picklable(fn) # markdown-exec: hide - - scheduler = Scheduler.with_processes(1) - - @scheduler.on_start - def launch_the_compute(): - scheduler.submit(fn, 1) - - @scheduler.on_future_result - def callback(future, result): - print(f"Result: {result}") - - scheduler.run() - from amltk._doc import doc_print; doc_print(print, scheduler) # markdown-exec: hide - ``` - - The last line in the previous example called - [`scheduler.run()`][amltk.scheduling.Scheduler.run] is what starts the scheduler - running, in which it will first emit the `@on_start` event. 
This triggered the - callback `launch_the_compute()` which submitted the function `fn` with the - arguments `#!python 1`. - - The scheduler then ran the compute and waited for it to complete, emitting the - `@on_future_result` event when it was done successfully. This triggered the callback - `callback()` which printed the result. - - At this point, there is no more compute happening and no more events to respond to - so the scheduler will halt. - -??? example "`@events`" - - === "Scheduler Status Events" - - When the scheduler enters some important state, it will emit an event - to let you know. - - === "`@on_start`" - - ::: amltk.scheduling.Scheduler.on_start - - === "`@on_finishing`" - - ::: amltk.scheduling.Scheduler.on_finishing - - === "`@on_finished`" - - ::: amltk.scheduling.Scheduler.on_finished - - === "`@on_stop`" - - ::: amltk.scheduling.Scheduler.on_stop - - === "`@on_timeout`" - - ::: amltk.scheduling.Scheduler.on_timeout - - === "`@on_empty`" - - ::: amltk.scheduling.Scheduler.on_empty - - === "Submitted Compute Events" - - When any compute goes through the `Scheduler`, it will emit an event - to let you know. You should however prefer to use a - [`Task`][amltk.scheduling.Task] as it will emit specific events - for the task at hand, and not all compute. - - === "`@on_future_submitted`" - - ::: amltk.scheduling.Scheduler.on_future_submitted - - === "`@on_future_result`" - - ::: amltk.scheduling.Scheduler.on_future_result - - === "`@on_future_exception`" - - ::: amltk.scheduling.Scheduler.on_future_exception - - === "`@on_future_done`" - - ::: amltk.scheduling.Scheduler.on_future_done - - === "`@on_future_cancelled`" - - ::: amltk.scheduling.Scheduler.on_future_cancelled - - -??? tip "Common usages of `run()`" - - There are various ways to [`run()`][amltk.scheduling.Scheduler.run] the - scheduler, notably how long it should run with `timeout=` and also how - it should react to any exception that may have occurred within the `Scheduler` - itself or your callbacks. - - Please see the [`run()`][amltk.scheduling.Scheduler.run] API doc for more - details and features, however we show two common use cases of using the `timeout=` - parameter. - - You can render a live display using [`run(display=...)`][amltk.scheduling.Scheduler.run]. - This require [`rich`](https://github.com/Textualize/rich) to be installed. You - can install this with `#!bash pip install rich` or `#!bash pip install amltk[rich]`. - - - === "`run(timeout=...)`" - - You can tell the `Scheduler` to stop after a certain amount of time - with the `timeout=` argument to [`run()`][amltk.scheduling.Scheduler.run]. - - This will also trigger the `@on_timeout` event as seen in the `Scheduler` output. 
- - ```python exec="true" source="material-block" html="True" hl_lines="19" - import time - from amltk.scheduling import Scheduler - - scheduler = Scheduler.with_processes(1) - - def expensive_function() -> int: - time.sleep(0.1) - return 42 - from amltk._doc import make_picklable; make_picklable(expensive_function) # markdown-exec: hide - - @scheduler.on_start - def submit_calculations() -> None: - scheduler.submit(expensive_function) - - # The will endlessly loop the scheduler - @scheduler.on_future_done - def submit_again(future: Future) -> None: - if scheduler.running(): - scheduler.submit(expensive_function) - - scheduler.run(timeout=1) # End after 1 second - from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide - ``` - - === "`run(timeout=..., wait=False)`" - - By specifying that the `Scheduler` should not wait for ongoing tasks - to finish, the `Scheduler` will attempt to cancel and possibly terminate - any running tasks. - - ```python exec="true" source="material-block" html="True" - import time - from amltk.scheduling import Scheduler - - scheduler = Scheduler.with_processes(1) - - def expensive_function() -> None: - time.sleep(10) - - from amltk._doc import make_picklable; make_picklable(expensive_function) # markdown-exec: hide - - @scheduler.on_start - def submit_calculations() -> None: - scheduler.submit(expensive_function) - - scheduler.run(timeout=1, wait=False) # End after 1 second - from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide - ``` - - ??? info "Forcibly Terminating Workers" - - As an `Executor` does not provide an interface to forcibly - terminate workers, we provide `Scheduler(terminate=...)` as a custom - strategy for cleaning up a provided executor. It is not possible - to terminate running thread based workers, for example using - `ThreadPoolExecutor` and any Executor using threads to spawn - tasks will have to wait until all running tasks are finish - before python can close. - - It's likely `terminate` will trigger the `EXCEPTION` event for - any tasks that are running during the shutdown, **not*** - a cancelled event. This is because we use a - [`Future`][concurrent.futures.Future] - under the hood and these can not be cancelled once running. - However there is no guarantee of this and is up to how the - `Executor` handles this. - -??? example "Scheduling something to be run later" - - You can schedule some function to be run later using the - [`#!python scheduler.call_later()`][amltk.scheduling.Scheduler.call_later] method. - - !!! note - - This does not run the function in the background, it just schedules some - function to be called later, where you could perhaps then use submit to - scheduler a [`Task`][amltk.scheduling.Task] to run the function in the - background. 
- - ```python exec="true" source="material-block" result="python" - from amltk.scheduling import Scheduler - - scheduler = Scheduler.with_processes(1) - - def fn() -> int: - print("Ending now!") - scheduler.stop() - - @scheduler.on_start - def schedule_fn() -> None: - scheduler.call_later(1, fn) - - scheduler.run(end_on_empty=False) - ``` - -""" # noqa: E501 +"""The scheduler for AMLTK.""" from __future__ import annotations import asyncio @@ -256,6 +7,7 @@ def schedule_fn() -> None: from asyncio import Future from collections.abc import Callable, Iterable, Mapping from concurrent.futures import CancelledError, Executor, ProcessPoolExecutor +from contextlib import nullcontext from dataclasses import dataclass from enum import Enum, auto from functools import partial @@ -275,6 +27,7 @@ def schedule_fn() -> None: from amltk._asyncm import ContextEvent from amltk._functional import Flag, subclass_map from amltk._richutil.renderable import RichRenderable +from amltk._util import ignore_warnings, mutli_context from amltk.exceptions import SchedulerNotRunningError from amltk.scheduling.events import Emitter, Event, Subscriber from amltk.scheduling.executors import SequentialExecutor @@ -310,7 +63,7 @@ class Scheduler(RichRenderable): queue: dict[Future, tuple[Callable, tuple, dict]] """The queue of tasks running.""" - on_start: Subscriber[[]] + on_start: Subscriber[[], Any] """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when the scheduler starts. This is the first event emitted by the scheduler and one of the only ways to submit the initial compute to the scheduler. @@ -321,7 +74,7 @@ def my_callback(): ... ``` """ - on_future_submitted: Subscriber[Future] + on_future_submitted: Subscriber[[Future], Any] """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when some compute is submitted. @@ -331,7 +84,7 @@ def my_callback(future: Future): ... ``` """ - on_future_done: Subscriber[Future] + on_future_done: Subscriber[[Future], Any] """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when some compute is done, regardless of whether it was successful or not. @@ -341,7 +94,7 @@ def my_callback(future: Future): ... ``` """ - on_future_result: Subscriber[Future, Any] + on_future_result: Subscriber[[Future, Any], Any] """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when a future returned with a result, no exception raise. @@ -351,7 +104,7 @@ def my_callback(future: Future, result: Any): ... ``` """ - on_future_exception: Subscriber[Future, BaseException] + on_future_exception: Subscriber[[Future, BaseException], Any] """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when some compute raised an uncaught exception. @@ -361,7 +114,7 @@ def my_callback(future: Future, exception: BaseException): ... ``` """ - on_future_cancelled: Subscriber[Future] + on_future_cancelled: Subscriber[[Future], Any] """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when a future is cancelled. This usually occurs due to the underlying Scheduler, and is not something we do directly, other than when shutting down the scheduler. @@ -372,7 +125,7 @@ def my_callback(future: Future): ... ``` """ - on_finishing: Subscriber[[]] + on_finishing: Subscriber[[], Any] """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when the scheduler is finishing up. This occurs right before the scheduler shuts down the executor. @@ -383,7 +136,7 @@ def my_callback(): ... 
``` """ - on_finished: Subscriber[[]] + on_finished: Subscriber[[], Any] """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when the scheduler is finished, has shutdown the executor and possibly terminated any remaining compute. @@ -394,7 +147,7 @@ def my_callback(): ... ``` """ - on_stop: Subscriber[str, BaseException | None] + on_stop: Subscriber[[str, BaseException | None], Any] """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when the scheduler is has been stopped due to the [`stop()`][amltk.scheduling.Scheduler.stop] method being called. @@ -405,7 +158,7 @@ def my_callback(stop_msg: str, exception: BaseException | None): ... ``` """ - on_timeout: Subscriber[[]] + on_timeout: Subscriber[[], Any] """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when the scheduler reaches the timeout. @@ -415,7 +168,7 @@ def my_callback(): ... ``` """ - on_empty: Subscriber[[]] + on_empty: Subscriber[[], Any] """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when the queue is empty. This can be useful to re-fill the queue and prevent the scheduler from exiting. @@ -427,17 +180,19 @@ def my_callback(): ``` """ - STARTED: Event[[]] = Event("on_start") - FINISHING: Event[[]] = Event("on_finishing") - FINISHED: Event[[]] = Event("on_finished") - STOP: Event[str, BaseException | None] = Event("on_stop") - TIMEOUT: Event[[]] = Event("on_timeout") - EMPTY: Event[[]] = Event("on_empty") - FUTURE_SUBMITTED: Event[Future] = Event("on_future_submitted") - FUTURE_DONE: Event[Future] = Event("on_future_done") - FUTURE_CANCELLED: Event[Future] = Event("on_future_cancelled") - FUTURE_RESULT: Event[Future, Any] = Event("on_future_result") - FUTURE_EXCEPTION: Event[Future, BaseException] = Event("on_future_exception") + STARTED: Event[[], Any] = Event("on_start") + FINISHING: Event[[], Any] = Event("on_finishing") + FINISHED: Event[[], Any] = Event("on_finished") + STOP: Event[[str, BaseException | None], Any] = Event("on_stop") + TIMEOUT: Event[[], Any] = Event("on_timeout") + EMPTY: Event[[], Any] = Event("on_empty") + FUTURE_SUBMITTED: Event[[Future], Any] = Event("on_future_submitted") + FUTURE_DONE: Event[[Future], Any] = Event("on_future_done") + FUTURE_CANCELLED: Event[[Future], Any] = Event("on_future_cancelled") + FUTURE_RESULT: Event[[Future, Any], Any] = Event("on_future_result") + FUTURE_EXCEPTION: Event[[Future, BaseException], Any] = Event( + "on_future_exception", + ) def __init__( self, @@ -497,7 +252,7 @@ def __init__( self._queue_has_items_event = asyncio.Event() # This is triggered when run is called - self._running_event = asyncio.Event() + self._running_event = ContextEvent() # This is set once `run` is called. 
# Either contains the mapping from exception to what to do, @@ -1116,8 +871,8 @@ def _register_complete(self, future: Future) -> None: self.stop( stop_msg=( f"raising on exception '{type(exception)}'" - f" as {err_type} is 'raise' as specified from" - f"`on_exception={self._on_exc_method_map.value}" + f" as scheduler was run with" + f" `on_exception={self._on_exc_method_map.value}" ), exception=exception, ) @@ -1170,123 +925,120 @@ async def _stop_when_triggered(self, stop_event: ContextEvent) -> bool: logger.debug("Stop event triggered, stopping scheduler") return True - async def _run_scheduler( # noqa: C901, PLR0912, PLR0915 + async def _run_scheduler( # noqa: C901, PLR0915 self, *, timeout: float | None = None, end_on_empty: bool = True, wait: bool = True, ) -> ExitState.Code | BaseException: - self.executor.__enter__() - self._stop_event = ContextEvent() - - if self._live_output is not None: - self._live_output.__enter__() - - # If we are doing a live display, we have to disable - # warnings as they will screw up the display rendering - # However, we re-enable it after the scheduler has finished running - warning_catcher = warnings.catch_warnings() - warning_catcher.__enter__() - warnings.filterwarnings("ignore") - else: - warning_catcher = None - - # Declare we are running - self._running_event.set() - - # Start a Thread Timer as our timing mechanism. - # HACK: This is required because the SequentialExecutor mode - # will not allow the async loop to run, meaning we can't update - # any internal state. - if timeout is not None: - self._timeout_timer = Timer(timeout, lambda: None) - self._timeout_timer.start() - - self.on_start.emit() - - # Monitor for `stop` being triggered - stop_triggered = asyncio.create_task( - self._stop_when_triggered(self._stop_event), - ) - - # Monitor for the queue being empty - monitor_empty = asyncio.create_task(self._monitor_queue_empty()) - if end_on_empty: - self.on_empty(lambda: monitor_empty.cancel(), hidden=True) - - # The timeout criterion is satisfied by the `timeout` arg - await asyncio.wait( - [stop_triggered, monitor_empty], - timeout=timeout, - return_when=asyncio.FIRST_COMPLETED, - ) - - # Determine the reason for stopping - stop_reason: BaseException | ExitState.Code - if stop_triggered.done() and self._stop_event.is_set(): - stop_reason = ExitState.Code.STOPPED - - msg, exception = self._stop_event.context - _log = logger.exception if exception else logger.debug - _log(f"Stop Message: {msg}", exc_info=exception) - - self.on_stop.emit(str(msg), exception) - if self._on_exc_method_map and exception: - stop_reason = exception - else: - stop_reason = ExitState.Code.STOPPED - elif monitor_empty.done(): - logger.debug("Scheduler stopped due to being empty.") - stop_reason = ExitState.Code.EXHAUSTED - elif timeout is not None: - logger.debug(f"Scheduler stopping as {timeout=} reached.") - stop_reason = ExitState.Code.TIMEOUT - self.on_timeout.emit() - else: - logger.warning("Scheduler stopping for unknown reason!") - stop_reason = ExitState.Code.UNKNOWN - - # Stop all running async tasks, i.e. monitoring the queue to trigger an event - tasks = [monitor_empty, stop_triggered] - for task in tasks: - task.cancel() + with mutli_context( + self.executor, + self._live_output if self._live_output is not None else nullcontext(), + ignore_warnings() if self._live_output is not None else nullcontext(), + ): + self._stop_event = ContextEvent() + + # Start a Thread Timer as our timing mechanism. 
+ # HACK: This is required because the SequentialExecutor mode + # will not allow the async loop to run, meaning we can't update + # any internal state. For this reason, we use a Thread + if timeout is not None: + self._timeout_timer = Timer(timeout, lambda: None) + self._timeout_timer.start() + + with self._running_event: + self.on_start.emit() + + # Monitor for `stop` being triggered + stop_triggered = asyncio.create_task( + self._stop_when_triggered(self._stop_event), + ) - # Await all the cancelled tasks and read the exceptions - await asyncio.gather(*tasks, return_exceptions=True) + # Monitor for the queue being empty + monitor_empty = asyncio.create_task(self._monitor_queue_empty()) + if end_on_empty: + self.on_empty.register(lambda: monitor_empty.cancel(), hidden=True) - self.on_finishing.emit() - logger.debug("Scheduler is finished") - logger.debug(f"Shutting down scheduler executor with {wait=}") + # The timeout criterion is satisfied by the `timeout` arg + await asyncio.wait( + [stop_triggered, monitor_empty], + timeout=timeout, + return_when=asyncio.FIRST_COMPLETED, + ) - # The scheduler is now refusing jobs - self._running_event.clear() - logger.debug("Scheduler has shutdown and declared as no longer running") + # Determine the reason for stopping + stop_reason: BaseException | ExitState.Code + if stop_triggered.done() and self._stop_event.is_set(): + stop_reason = ExitState.Code.STOPPED + + msg, exception = self._stop_event.context + _log = logger.exception if exception else logger.debug + _log(f"Stop Message: {msg}", exc_info=exception) + + self.on_stop.emit(str(msg), exception) + if self._on_exc_method_map and exception: + stop_reason = exception + else: + stop_reason = ExitState.Code.STOPPED + elif monitor_empty.done(): + logger.debug("Scheduler stopped due to being empty.") + stop_reason = ExitState.Code.EXHAUSTED + elif timeout is not None: + logger.debug(f"Scheduler stopping as {timeout=} reached.") + stop_reason = ExitState.Code.TIMEOUT + self.on_timeout.emit() + else: + logger.warning("Scheduler stopping for unknown reason!") + stop_reason = ExitState.Code.UNKNOWN + + # Stop all running async tasks, i.e. 
monitoring the queue to trigger + # an event + tasks = [monitor_empty, stop_triggered] + for task in tasks: + task.cancel() + + # Await all the cancelled tasks and read the exceptions + await asyncio.gather(*tasks, return_exceptions=True) + + self.on_finishing.emit() + logger.debug("Scheduler is finished") + logger.debug(f"Shutting down scheduler executor with {wait=}") + + # Scheduler is now refusing jobs + logger.debug("Scheduler has shutdown and declared as no longer running") + + # This will try to end the tasks based on wait and self._terminate + if not wait: + if self._terminate is None: + logger.warning( + "Cancelling currently running tasks and then waiting " + f" as there is no termination strategy for {self.executor=}`.", + ) - # This will try to end the tasks based on wait and self._terminate - Scheduler._end_pending( - wait=wait, - futures=list(self.queue.keys()), - executor=self.executor, - termination_strategy=self._terminate, - ) + for future in self.queue: + if not future.done(): + logger.debug(f"Cancelling {future=}") + future.cancel() - self.on_finished.emit() - logger.debug(f"Scheduler finished with status {stop_reason}") + if self._terminate is not None: + logger.debug(f"Terminating workers with {termination_strategy=}") + self._terminate(self.executor) - # Clear all events - self._stop_event.clear() - self._queue_has_items_event.clear() + logger.debug("Waiting for executor to finish.") + self.executor.shutdown(wait=wait) + self.on_finished.emit() + logger.debug(f"Scheduler finished with status {stop_reason}") - if self._live_output is not None: - self._live_output.refresh() - self._live_output.stop() + # Clear all events + self._stop_event.clear() + self._queue_has_items_event.clear() - if self._timeout_timer is not None: - self._timeout_timer.cancel() + if self._live_output is not None: + self._live_output.refresh() - if warning_catcher is not None: - warning_catcher.__exit__() # type: ignore + if self._timeout_timer is not None: + self._timeout_timer.cancel() return stop_reason @@ -1302,7 +1054,7 @@ def run( ) = "raise", on_cancelled: Literal["raise", "end", "continue"] = "raise", asyncio_debug_mode: bool = False, - display: bool | Iterable[RenderableType] = False, + display: bool | Iterable[RenderableType] | Literal["auto"] = "auto", ) -> ExitState: """Run the scheduler. @@ -1378,6 +1130,11 @@ def run( Defaults to `False`. Please see [asyncio.run][] for more. display: Whether to display the scheduler live in the console. + * If `#!python "auto"`, will display the scheduler if in + a notebook or colab environemnt. Otherwise, it will not display + it. If left as "auto" and the display occurs, a warning will + be printed alongside it. + * If `#!python False`, will not display anything. * If `#!python True`, will display the scheduler and all its tasks. * If a `#!python list[RenderableType]` , will display the scheduler itself plus those renderables. @@ -1388,6 +1145,20 @@ def run( Raises: RuntimeError: If the scheduler is already running. """ + if display == "auto": + from amltk._richutil import is_jupyter + + display = is_jupyter() + if display is True: + warnings.warn( + "Detected that current running context is in a notebook!" + " When `display='auto'`, the default, the scheduler will" + " automatically be set to display. 
If you do not want this or" + " wish to disable this warning, please set `display=False`.", + UserWarning, + stacklevel=2, + ) + return asyncio.run( self.async_run( timeout=timeout, @@ -1669,41 +1440,6 @@ def call_later( loop = asyncio.get_running_loop() return loop.call_later(delay, _fn) - @staticmethod - def _end_pending( - *, - futures: list[Future], - executor: Executor, - wait: bool = True, - termination_strategy: Callable[[Executor], Any] | None = None, - ) -> None: - if wait: - logger.debug("Waiting for currently running tasks to finish.") - executor.shutdown(wait=wait) - elif termination_strategy is None: - logger.warning( - "Cancelling currently running tasks and then waiting " - f" as there is no termination strategy provided for {executor=}`.", - ) - # Just try to cancel the tasks. Will cancel pending tasks - # but executors like dask will even kill the job - for future in futures: - if not future.done(): - logger.debug(f"Cancelling {future=}") - future.cancel() - - # Here we wait, if we could cancel, then we wait for that - # to happen, otherwise we are just waiting as anticipated. - executor.shutdown(wait=wait) - else: - logger.debug(f"Terminating workers with {termination_strategy=}") - for future in futures: - if not future.done(): - logger.debug(f"Cancelling {future=}") - future.cancel() - termination_strategy(executor) - executor.shutdown(wait=wait) - def add_renderable(self, renderable: RenderableType) -> None: """Add a renderable object to the scheduler. diff --git a/src/amltk/scheduling/task.py b/src/amltk/scheduling/task.py index a6550666..2e8cc7c4 100644 --- a/src/amltk/scheduling/task.py +++ b/src/amltk/scheduling/task.py @@ -1,89 +1,15 @@ -"""A [`Task`][amltk.scheduling.task.Task] is a unit of work that can be scheduled by the -[`Scheduler`][amltk.scheduling.Scheduler]. - -It is defined by its `function=` to call. Whenever a `Task` -has its [`submit()`][amltk.scheduling.task.Task.submit] method called, -the function will be dispatched to run by a `Scheduler`. - -When a task has returned, either successfully, or with an exception, -it will emit `@events` to indicate so. You can subscribe to these events -with callbacks and act accordingly. - - -??? example "`@events`" - - Check out the `@events` reference - for more on how to customize these callbacks. You can also take a look - at the API of [`on()`][amltk.scheduling.task.Task.on] for more information. - - === "`@on_result`" - - ::: amltk.scheduling.task.Task.on_result - - === "`@on_exception`" - - ::: amltk.scheduling.task.Task.on_exception - - === "`@on_done`" - - ::: amltk.scheduling.task.Task.on_done - - === "`@on_submitted`" - - ::: amltk.scheduling.task.Task.on_submitted - - === "`@on_cancelled`" - - ::: amltk.scheduling.task.Task.on_cancelled - -??? tip "Usage" - - The usual way to create a task is with - [`Scheduler.task()`][amltk.scheduling.scheduler.Scheduler.task], - where you provide the `function=` to call. 
- - ```python exec="true" source="material-block" html="true" - from amltk import Scheduler - - def f(x: int) -> int: - return x * 2 - from amltk._doc import make_picklable; make_picklable(f) # markdown-exec: hide - - scheduler = Scheduler.with_processes(2) - task = scheduler.task(f) - - @scheduler.on_start - def on_start(): - task.submit(1) - - @task.on_result - def on_result(future: Future[int], result: int): - print(f"Task {future} returned {result}") - - scheduler.run() - from amltk._doc import doc_print; doc_print(print, scheduler) # markdown-exec: hide - ``` - - If you'd like to simply just call the original function, without submitting it to - the scheduler, you can always just call the task directly, i.e. `#!python task(1)`. - -You can also provide [`Plugins`][amltk.scheduling.plugins.Plugin] to the task, -to modify tasks, add functionality and add new events. -""" +"""The task module.""" from __future__ import annotations import logging from asyncio import Future from collections.abc import Callable, Iterable -from typing import TYPE_CHECKING, Any, Concatenate, Generic, TypeVar, overload +from typing import TYPE_CHECKING, Any, Concatenate, Generic, TypeVar from typing_extensions import ParamSpec, Self, override -from more_itertools import first_true - -from amltk._functional import callstring +from amltk._functional import callstring, funcname from amltk._richutil.renderable import RichRenderable -from amltk.exceptions import EventNotKnownError, SchedulerNotRunningError -from amltk.randomness import randuid +from amltk.exceptions import SchedulerNotRunningError from amltk.scheduling.events import Emitter, Event, Subscriber from amltk.scheduling.plugins.plugin import Plugin @@ -103,7 +29,7 @@ def on_result(future: Future[int], result: int): CallableT = TypeVar("CallableT", bound=Callable) -class Task(RichRenderable, Generic[P, R]): +class Task(Emitter, RichRenderable, Generic[P, R]): """The task class.""" unique_ref: str @@ -121,7 +47,7 @@ class Task(RichRenderable, Generic[P, R]): emitter: Emitter """The emitter for events of this task.""" - on_submitted: Subscriber[Concatenate[Future[R], P]] + on_submitted: Subscriber[Concatenate[Future[R], P], Any] """An event that is emitted when a future is submitted to the scheduler. It will pass the future as the first argument with args and kwargs following. @@ -133,7 +59,7 @@ def on_submitted(future: Future[R], *args, **kwargs): print(f"Future {future} was submitted with {args=} and {kwargs=}") ``` """ - on_done: Subscriber[Future[R]] + on_done: Subscriber[[Future[R]], Any] """Called when a task is done running with a result or exception. ```python @task.on_done @@ -141,7 +67,7 @@ def on_done(future: Future[R]): print(f"Future {future} is done") ``` """ - on_cancelled: Subscriber[Future[R]] + on_cancelled: Subscriber[[Future[R]], Any] """Called when a task is cancelled. ```python @task.on_cancelled @@ -149,7 +75,7 @@ def on_cancelled(future: Future[R]): print(f"Future {future} was cancelled") ``` """ - on_result: Subscriber[Future[R], R] + on_result: Subscriber[[Future[R], R], Any] """Called when a task has successfully returned a value. Comes with Future ```python @@ -158,7 +84,7 @@ def on_result(future: Future[R], result: R): print(f"Future {future} returned {result}") ``` """ - on_exception: Subscriber[Future[R], BaseException] + on_exception: Subscriber[[Future[R], BaseException], Any] """Called when a task failed to return anything but an exception. 
Comes with Future ```python @@ -168,12 +94,12 @@ def on_exception(future: Future[R], error: BaseException): ``` """ - SUBMITTED: Event[Concatenate[Future[R], P]] = Event("on_submitted") - DONE: Event[Future[R]] = Event("on_done") + SUBMITTED: Event[Concatenate[Future[R], P], Any] = Event("on-submitted") + DONE: Event[[Future[R]], Any] = Event("on-done") - CANCELLED: Event[Future[R]] = Event("on_cancelled") - RESULT: Event[Future[R], R] = Event("on_result") - EXCEPTION: Event[Future[R], BaseException] = Event("on_exception") + CANCELLED: Event[[Future[R]], Any] = Event("on-cancelled") + RESULT: Event[[Future[R], R], Any] = Event("on-result") + EXCEPTION: Event[[Future[R], BaseException], Any] = Event("on-exception") def __init__( self: Self, @@ -191,11 +117,7 @@ def __init__( plugins: The plugins to use for this task. init_plugins: Whether to initialize the plugins or not. """ - super().__init__() - self.unique_ref = randuid(8) - - self.emitter = Emitter() - self.event_counts = self.emitter.event_counts + super().__init__(name=f"Task-{funcname(function)}") self.plugins: list[Plugin] = ( [plugins] if isinstance(plugins, Plugin) else list(plugins) ) @@ -205,11 +127,11 @@ def __init__( self.queue: list[Future[R]] = [] # Set up subscription methods to events - self.on_submitted = self.emitter.subscriber(self.SUBMITTED) - self.on_done = self.emitter.subscriber(self.DONE) - self.on_result = self.emitter.subscriber(self.RESULT) - self.on_exception = self.emitter.subscriber(self.EXCEPTION) - self.on_cancelled = self.emitter.subscriber(self.CANCELLED) + self.on_submitted = self.subscriber(self.SUBMITTED) + self.on_done = self.subscriber(self.DONE) + self.on_result = self.subscriber(self.RESULT) + self.on_exception = self.subscriber(self.EXCEPTION) + self.on_cancelled = self.subscriber(self.CANCELLED) if init_plugins: for plugin in self.plugins: @@ -223,97 +145,6 @@ def futures(self) -> list[Future[R]]: """ return self.queue - @overload - def on( - self, - event: Event[P2], - callback: None = None, - *, - when: Callable[[], bool] | None = ..., - max_calls: int | None = ..., - repeat: int = ..., - every: int = ..., - ) -> Subscriber[P2]: - ... - - @overload - def on( - self, - event: str, - callback: None = None, - *, - when: Callable[[], bool] | None = ..., - max_calls: int | None = ..., - repeat: int = ..., - every: int = ..., - ) -> Subscriber[...]: - ... - - @overload - def on( - self, - event: str, - callback: Callable, - *, - when: Callable[[], bool] | None = ..., - max_calls: int | None = ..., - repeat: int = ..., - every: int = ..., - ) -> None: - ... - - def on( - self, - event: Event[P2] | str, - callback: Callable[P2, Any] | None = None, - *, - when: Callable[[], bool] | None = None, - max_calls: int | None = None, - repeat: int = 1, - every: int = 1, - hidden: bool = False, - ) -> Subscriber[P2] | Subscriber[...] | None: - """Subscribe to an event. - - Args: - event: The event to subscribe to. - callback: The callback to call when the event is emitted. - If not specified, what is returned can be used as a decorator. - when: A predicate to determine whether to call the callback. - max_calls: The maximum number of times to call the callback. - repeat: The number of times to repeat the subscription. - every: The number of times to wait between repeats. - hidden: Whether to hide the callback in visual output. - This is mainly used to facilitate Plugins who - act upon events but don't want to be seen, primarily - as they are just book-keeping callbacks. 
- - Returns: - The subscriber if no callback was provided, otherwise `None`. - """ - if isinstance(event, str): - _e = first_true(self.emitter.events, None, lambda e: e.name == event) - if _e is None: - raise EventNotKnownError( - f"{event=} is not a valid event." - f"\nKnown events are: {[e.name for e in self.emitter.events]}", - ) - else: - _e = event - - subscriber = self.emitter.subscriber( - _e, # type: ignore - when=when, - max_calls=max_calls, - repeat=repeat, - every=every, - ) - if callback is None: - return subscriber - - subscriber(callback, hidden=hidden) - return None - @property def n_running(self) -> int: """Get the number of futures for this task that are currently running.""" @@ -370,6 +201,7 @@ def submit(self, *args: P.args, **kwargs: P.kwargs) -> Future[R] | None: # original function name. msg = f"Submitted {callstring(self.function, *args, **kwargs)} from {self}." logger.debug(msg) + self.on_submitted.emit(future, *args, **kwargs) # Process the task once it's completed @@ -378,27 +210,6 @@ def submit(self, *args: P.args, **kwargs: P.kwargs) -> Future[R] | None: future.add_done_callback(self._process_future) return future - def copy(self, *, init_plugins: bool = True) -> Self: - """Create a copy of this task. - - Will use the same scheduler and function, but will have a different - event manager such that any events listend to on the old task will - **not** trigger with the copied task. - - Args: - init_plugins: Whether to initialize the copied plugins on the copied - task. Usually you will want to leave this as `True`. - - Returns: - A copy of this task. - """ - return self.__class__( - self.function, - self.scheduler, - plugins=tuple(p.copy() for p in self.plugins), - init_plugins=init_plugins, - ) - def _process_future(self, future: Future[R]) -> None: try: self.queue.remove(future) @@ -444,7 +255,7 @@ def _when_future_from_submission( @override def __repr__(self) -> str: - kwargs = {"unique_ref": self.unique_ref} + kwargs = {"unique_ref": self.unique_ref, "plugins": self.plugins} kwargs_str = ", ".join(f"{k}={v}" for k, v in kwargs.items()) return f"{self.__class__.__name__}({kwargs_str})" @@ -464,7 +275,6 @@ def __rich__(self) -> Panel: items.append(plugin) tree = Tree(label="", hide_root=True) - tree.add(self.emitter) items.append(tree) return Panel( diff --git a/src/amltk/scheduling/termination_strategies.py b/src/amltk/scheduling/termination_strategies.py index d93dcf92..5133496c 100644 --- a/src/amltk/scheduling/termination_strategies.py +++ b/src/amltk/scheduling/termination_strategies.py @@ -23,6 +23,29 @@ _Executor = TypeVar("_Executor", bound=Executor) +def polite_kill(process: psutil.Process, timeout: int | None = None) -> None: + """Politely kill a process. + + This works by first sending a SIGTERM to the process, and then if it + doesn't respond to that, sending a SIGKILL. + + On Windows, SIGTERM is not available, so `terminate()` will + send a `SIGKILL` directly. + + Args: + process: The process to kill. + timeout: The time to wait for the process after sending SIGTERM. + before resorting to SIGKILL. If None, wait indefinitely. + """ + with suppress(psutil.NoSuchProcess): + process.terminate() + process.wait(timeout=timeout) + + # Forcibly kill it if it's not responding to the SIGTERM + if process.is_running(): + process.kill() + + def _terminate_with_psutil(executor: ProcessPoolExecutor) -> None: """Terminate all processes in the given executor using psutil. 
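For reference, the SIGTERM-then-SIGKILL escalation that `polite_kill` implements can be exercised on its own roughly as in the sketch below; the throwaway `sleep` child and the 5 second timeout are illustrative only and assume a POSIX system:

```python
import subprocess

import psutil

# Spawn a throwaway child purely for demonstration.
child = subprocess.Popen(["sleep", "60"])
proc = psutil.Process(child.pid)

# Ask politely first: SIGTERM, then wait up to 5 seconds for it to exit.
proc.terminate()
try:
    proc.wait(timeout=5)
except psutil.TimeoutExpired:
    pass

# If the process ignored SIGTERM, escalate to SIGKILL.
if proc.is_running():
    proc.kill()
```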
@@ -35,19 +58,19 @@ def _terminate_with_psutil(executor: ProcessPoolExecutor) -> None: if not executor._processes: return - worker_processes = [psutil.Process(p.pid) for p in executor._processes.values()] - for worker_process in worker_processes: + for process in executor._processes.values(): try: - child_preocesses = worker_process.children(recursive=True) + worker_process = psutil.Process(process.pid) + # We reverse here to start from leaf processes first, giving parents + # time to cleanup after their terminated subprocesses. + child_processes = reversed(worker_process.children(recursive=True)) except psutil.NoSuchProcess: continue - for child_process in child_preocesses: - with suppress(psutil.NoSuchProcess): - child_process.terminate() + for child_process in child_processes: + polite_kill(child_process, timeout=5) - with suppress(psutil.NoSuchProcess): - worker_process.terminate() + polite_kill(worker_process, timeout=5) def termination_strategy(executor: _Executor) -> Callable[[_Executor], None] | None: diff --git a/src/amltk/sklearn/__init__.py b/src/amltk/sklearn/__init__.py index 24949eaf..2189af06 100644 --- a/src/amltk/sklearn/__init__.py +++ b/src/amltk/sklearn/__init__.py @@ -3,6 +3,7 @@ StoredPredictionClassifier, StoredPredictionRegressor, ) +from amltk.sklearn.evaluation import CVEvaluation from amltk.sklearn.voting import voting_with_preffited_estimators __all__ = [ @@ -11,4 +12,5 @@ "StoredPredictionRegressor", "StoredPredictionClassifier", "voting_with_preffited_estimators", + "CVEvaluation", ] diff --git a/src/amltk/sklearn/evaluation.py b/src/amltk/sklearn/evaluation.py new file mode 100644 index 00000000..8d9d3bd3 --- /dev/null +++ b/src/amltk/sklearn/evaluation.py @@ -0,0 +1,1940 @@ +"""This module contains the cross-validation evaluation protocol. + +This protocol will create a cross-validation task to be used in parallel and +optimization. It represents a typical cross-validation evaluation for sklearn, +handling some of the minor nuances of sklearn and it's interaction with +optimization and parallelization. + +Please see [`CVEvaluation`][amltk.sklearn.evaluation.CVEvaluation] for more +information on usage. 
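For orientation, a minimal sketch of the intended workflow is shown below. The
`n_splits=` argument, the `evaluator.fn` attribute and the exact `optimize(...)`
keywords are assumptions made purely for illustration here; please refer to the
[`CVEvaluation`][amltk.sklearn.evaluation.CVEvaluation] documentation for the
authoritative API.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

from amltk.optimization import Metric
from amltk.pipeline import Component
from amltk.sklearn import CVEvaluation

X, y = make_classification(n_samples=200, random_state=0)
pipeline = Component(RandomForestClassifier, space={"n_estimators": (10, 100)})

# Assumed constructor/attribute names (`n_splits=`, `.fn`) for illustration.
evaluator = CVEvaluation(X, y, n_splits=3)
history = pipeline.optimize(
    target=evaluator.fn,
    metric=Metric("accuracy", minimize=False, bounds=(0, 1)),
    max_trials=3,
)
```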
+""" +from __future__ import annotations + +import logging +import tempfile +import warnings +from asyncio import Future +from collections import defaultdict +from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sized +from contextlib import nullcontext +from dataclasses import dataclass +from datetime import datetime +from functools import partial +from numbers import Number +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Literal, + NamedTuple, + ParamSpec, + Protocol, + TypeAlias, + TypeVar, +) +from typing_extensions import override + +import numpy as np +import pandas as pd +from more_itertools import all_equal +from sklearn.base import BaseEstimator, clone +from sklearn.exceptions import UnsetMetadataPassedError +from sklearn.metrics import get_scorer +from sklearn.metrics._scorer import _MultimetricScorer, _Scorer +from sklearn.model_selection import ( + KFold, + ShuffleSplit, + StratifiedKFold, + StratifiedShuffleSplit, +) +from sklearn.model_selection._validation import _score +from sklearn.utils import Bunch +from sklearn.utils._metadata_requests import _routing_enabled +from sklearn.utils.metadata_routing import ( + MetadataRouter, + MethodMapping, + process_routing, +) +from sklearn.utils.metaestimators import _safe_split +from sklearn.utils.multiclass import type_of_target +from sklearn.utils.validation import _check_method_params + +import amltk.randomness +from amltk._functional import subclass_map +from amltk.exceptions import ( + AutomaticTaskTypeInferredWarning, + CVEarlyStoppedError, + ImplicitMetricConversionWarning, + MismatchedTaskTypeWarning, + TrialError, +) +from amltk.optimization import Trial +from amltk.profiling.profiler import Profiler +from amltk.scheduling import Plugin, Task +from amltk.scheduling.events import Emitter, Event +from amltk.scheduling.plugins.comm import Comm +from amltk.store import Stored +from amltk.store.paths.path_bucket import PathBucket + +if TYPE_CHECKING: + from sklearn.model_selection import ( + BaseCrossValidator, + BaseShuffleSplit, + ) + + from amltk.pipeline import Node + from amltk.randomness import Seed + +P = ParamSpec("P") +R = TypeVar("R") + + +logger = logging.getLogger(__name__) + +BaseEstimatorT = TypeVar("BaseEstimatorT", bound=BaseEstimator) + +TaskTypeName: TypeAlias = Literal[ + "binary", + "multiclass", + "multilabel-indicator", + "multiclass-multioutput", + "continuous", + "continuous-multioutput", +] +"""A type alias for the task type name as defined by sklearn.""" + +_valid_task_types: tuple[TaskTypeName, ...] = ( + "binary", + "multiclass", + "multilabel-indicator", + "multiclass-multioutput", + "continuous", + "continuous-multioutput", +) + +XLike: TypeAlias = pd.DataFrame | np.ndarray +"""A type alias for X input data type as defined by sklearn.""" + +YLike: TypeAlias = pd.Series | pd.DataFrame | np.ndarray +"""A type alias for y input data type as defined by sklearn.""" + +PostSplitSignature: TypeAlias = Callable[ + [Trial, int, "CVEvaluation.PostSplitInfo"], + "CVEvaluation.PostSplitInfo", +] +"""A type alias for the post split callback signature. + +Please see [`PostSplitInfo`][amltk.sklearn.evaluation.CVEvaluation.PostSplitInfo] +for more information on the information available to this callback. + +```python +def my_post_split( + trial: Trial, + split_number: int, + eval: CVEvalauation.PostSplitInfo +) -> CVEvaluation.PostSplitInfo: + ... 
+``` +""" + + +def resample_if_minority_class_too_few_for_n_splits( + X_train: pd.DataFrame, # noqa: N803 + y_train: pd.Series, + *, + n_splits: int, + seed: Seed | None = None, + _warning_if_occurs: str | None = None, +) -> tuple[pd.DataFrame, pd.DataFrame | pd.Series]: + """Rebalance the training data to allow stratification. + + If your data only contains something such as 3 labels for a single class, and you + wish to perform 5 fold cross-validation, you will need to rebalance the data to + allow for stratification. This function will take the training data and labels and + and resample the data to allow for stratification. + + Args: + X_train: The training data. + y_train: The training labels. + n_splits: The number of splits to perform. + seed: Used for deciding which instances to resample. + + Returns: + The rebalanced training data and labels. + """ + if y_train.ndim != 1: + raise NotImplementedError( + "Rebalancing for multi-output classification is not yet supported.", + ) + + # If we are in binary/multilclass setting and there is not enough instances + # with a given label to perform stratified sampling with `n_splits`, we first + # find these labels, take the first N instances which have these labels and allows + # us to reach `n_splits` instances for each label. + indices_to_resample = None + label_counts = y_train.value_counts() + under_represented_labels = label_counts[label_counts < n_splits] # type: ignore + + collected_indices = [] + if any(under_represented_labels): + if _warning_if_occurs is not None: + warnings.warn(_warning_if_occurs, UserWarning, stacklevel=2) + under_rep_instances = y_train[y_train.isin(under_represented_labels.index)] # type: ignore + + grouped_by_label = under_rep_instances.to_frame("label").groupby( # type: ignore + "label", + observed=True, # Handles categoricals + ) + for _label, instances_with_label in grouped_by_label: + n_to_take = n_splits - len(instances_with_label) + + need_to_sample_repeatedly = n_to_take > len(instances_with_label) + resampled_instances = instances_with_label.sample( + n=n_to_take, + random_state=seed, # type: ignore + # It could be that we have to repeat sample if there are not enough + # instances to hit `n_splits` for a given label. + replace=need_to_sample_repeatedly, + ) + collected_indices.append(np.asarray(resampled_instances.index)) + + indices_to_resample = np.concatenate(collected_indices) + + if indices_to_resample is not None: + # Give the new samples a new index to not overlap with the original data. + new_start_idx = X_train.index.max() + 1 # type: ignore + new_end_idx = new_start_idx + len(indices_to_resample) + new_idx = pd.RangeIndex(start=new_start_idx, stop=new_end_idx) + resampled_X = X_train.loc[indices_to_resample].set_index(new_idx) + resampled_y = y_train.loc[indices_to_resample].set_axis(new_idx) + X_train = pd.concat([X_train, resampled_X]) + y_train = pd.concat([y_train, resampled_y]) + + return X_train, y_train + + +def _check_valid_scores( + scores: Mapping[str, float] | Number, + split: str, +) -> Mapping[str, float]: + assert isinstance(scores, Mapping) + for k, v in scores.items(): + # Can return list or np.bool_ + # We do not want a list to pass (i.e. if [x] shouldn't pass if check) + # We can't use `np.bool_` is `True` as `np.bool_(True) is not True`. + # Hence we have to use equality checking + # God I feel like I'm doing javascript + if np.isfinite(v) != True: # noqa: E712 + raise ValueError( + f"Scorer {k} returned {v} for {split} split. 
The scorer should" + " should return a finite float", + ) + + return scores + + +def _route_params( + splitter: BaseShuffleSplit | BaseCrossValidator, + estimator: BaseEstimator, + _scorer: _Scorer | _MultimetricScorer, + **params: Any, +) -> Bunch: + if _routing_enabled(): + # NOTE: This is basically copied out of sklearns 1.4 cross_validate + + # For estimators, a MetadataRouter is created in get_metadata_routing + # methods. For these router methods, we create the router to use + # `process_routing` on it. + router = ( + MetadataRouter(owner="cross_validate") + .add( + splitter=splitter, + method_mapping=MethodMapping().add(caller="fit", callee="split"), + ) + .add( + estimator=estimator, + # TODO(SLEP6): also pass metadata to the predict method for + # scoring? + # ^ Taken from cross_validate source code in sklearn + method_mapping=MethodMapping().add(caller="fit", callee="fit"), + ) + .add( + scorer=_scorer, + method_mapping=MethodMapping().add(caller="fit", callee="score"), + ) + ) + try: + return process_routing(router, "fit", **params) # type: ignore + except UnsetMetadataPassedError as e: + # The default exception would mention `fit` since in the above + # `process_routing` code, we pass `fit` as the caller. However, + # the user is not calling `fit` directly, so we change the message + # to make it more suitable for this case. + raise UnsetMetadataPassedError( + message=( + f"{sorted(e.unrequested_params.keys())} are passed to cross" + " validation but are not explicitly requested or unrequested. See" + " the Metadata Routing User guide" + " for more" + " information." + ), + unrequested_params=e.unrequested_params, + routed_params=e.routed_params, + ) from e + else: + routed_params = Bunch() + routed_params.splitter = Bunch(split={"groups": None}) + routed_params.estimator = Bunch(fit=params) + routed_params.scorer = Bunch(score={}) + return routed_params + + +def _default_holdout( + task_type: TaskTypeName, + holdout_size: float, + *, + random_state: Seed | None = None, +) -> ShuffleSplit | StratifiedShuffleSplit: + if not (0 < holdout_size < 1): + raise ValueError(f"`{holdout_size=}` must be in (0, 1)") + + rs = amltk.randomness.as_int(random_state) + match task_type: + case "binary" | "multiclass": + return StratifiedShuffleSplit(1, random_state=rs, test_size=holdout_size) + case "multilabel-indicator" | "multiclass-multioutput": + return ShuffleSplit(1, random_state=rs, test_size=holdout_size) + case "continuous" | "continuous-multioutput": + return ShuffleSplit(1, random_state=rs, test_size=holdout_size) + case _: + raise ValueError(f"Don't know how to handle {task_type=}") + + +def _default_cv_resampler( + task_type: TaskTypeName, + n_splits: int, + *, + random_state: Seed | None = None, +) -> StratifiedKFold | KFold: + if n_splits < 1: + raise ValueError(f"Must have at least one split but got {n_splits=}") + + rs = amltk.randomness.as_int(random_state) + + match task_type: + case "binary" | "multiclass": + return StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=rs) + case "multilabel-indicator" | "multiclass-multioutput": + # NOTE: They don't natively support multilabel-indicator for stratified + return KFold(n_splits=n_splits, shuffle=True, random_state=rs) + case "continuous" | "continuous-multioutput": + return KFold(n_splits=n_splits, shuffle=True, random_state=rs) + case _: + raise ValueError(f"Don't know how to handle {task_type=} with {n_splits=}") + + +def identify_task_type( # noqa: PLR0911 + y: YLike, + *, + task_hint: Literal["classification", 
"regression", "auto"] = "auto", +) -> TaskTypeName: + """Identify the task type from the target data.""" + inferred_type: TaskTypeName = type_of_target(y) + if task_hint == "auto": + warnings.warn( + f"`{task_hint=}` was not provided. The task type was inferred from" + f" the target data to be '{inferred_type}'." + " To silence this warning, please provide `task_hint`.", + AutomaticTaskTypeInferredWarning, + stacklevel=2, + ) + return inferred_type + + match task_hint, inferred_type: + # First two cases are everything is fine + case ( + "classification", + "binary" + | "multiclass" + | "multilabel-indicator" + | "multiclass-multioutput", + ): + return inferred_type + case ("regression", "continuous" | "continuous-multioutput"): + return inferred_type + # Hinted to be regression but we got a single column classification task + case ("regression", "binary" | "multiclass"): + warnings.warn( + f"`{task_hint=}` but `{inferred_type=}`." + " Set to `continuous` as there is only one target column.", + MismatchedTaskTypeWarning, + stacklevel=2, + ) + return "continuous" + # Hinted to be regression but we got multi-column classification task + case ("regression", "multilabel-indicator" | "multiclass-multioutput"): + warnings.warn( + f"`{task_hint=}` but `{inferred_type=}`." + " Set to `continuous-multiouput` as there are more than 1 target" + " columns.", + MismatchedTaskTypeWarning, + stacklevel=2, + ) + return "continuous" + # Hinted to be classification but we got a single column regression task + case ("classification", "continuous"): + match len(np.unique(y)): + case 1: + raise ValueError( + "The target data has only one unique value. This is" + f" not a valid classification task.\n{y=}", + ) + case 2: + warnings.warn( + f"`{task_hint=}` but `{inferred_type=}`." + " Set to `binary` as only 2 unique values." + " To silence this, provide a specific task type to" + f"`task_hint=` from {_valid_task_types}.", + MismatchedTaskTypeWarning, + stacklevel=2, + ) + return "binary" + case _: + warnings.warn( + f"`{task_hint=}` but `{inferred_type=}`." + " Set to `multiclass` as >2 unique values." + " To silence this, provide a specific task type to" + f"`task_hint=` from {_valid_task_types}.", + MismatchedTaskTypeWarning, + stacklevel=2, + ) + return "multiclass" + # Hinted to be classification but we got multi-column regression task + case ("classification", "continuous-multioutput"): + # NOTE: this is a matrix wide .unique, I'm not sure how things + # work with multiclass-multioutput and whether it should be + # done by 2 unique per column + uniques_per_col = [np.unique(col) for col in y.T] + binary_columns = all(len(col) <= 2 for col in uniques_per_col) # noqa: PLR2004 + if binary_columns: + warnings.warn( + f"`{task_hint=}` but `{inferred_type=}`." + " Set to `multilabel-indicator` as <=2 unique values per column." + " To silence this, provide a specific task type to" + f"`task_hint=` from {_valid_task_types}.", + MismatchedTaskTypeWarning, + stacklevel=2, + ) + return "multilabel-indicator" + else: # noqa: RET505 + warnings.warn( + f"`{task_hint=}` but `{inferred_type=}`." + " Set to `multiclass-multioutput` as at least one column has" + " >2 unique values." + " To silence this, provide a specific task type to" + f"`task_hint=` from {_valid_task_types}.", + MismatchedTaskTypeWarning, + stacklevel=2, + ) + return "multiclass-multioutput" + case _: + raise RuntimeError( + f"Unreachable, please report this bug. 
{task_hint=}, {inferred_type=}", + ) + + +def _fit( + estimator: BaseEstimatorT, + X: XLike, # noqa: N803 + y: YLike, + i_train: np.ndarray, + *, + profiler: Profiler, + fit_params: Mapping[str, Any], + train_score: bool, + scorers: dict[str, _Scorer] | _MultimetricScorer, + scorer_params: Mapping[str, Any], +) -> tuple[BaseEstimatorT, dict[str, Any], Mapping[str, float] | None, dict[str, Any]]: + _fit_params = _check_method_params(X, params=fit_params, indices=i_train) + X_train, y_train = _safe_split(estimator, X, y, indices=i_train) + + with profiler("fit"): + if y_train is None: + estimator.fit(X_train, **_fit_params) # type: ignore + else: + estimator.fit(X_train, y_train, **_fit_params) # type: ignore + + train_scores = None + train_scorer_params: dict[str, Any] = {} + if train_score is True: + scorer_params_train = _check_method_params(X, scorer_params, indices=i_train) + + with profiler("train_score"): + train_scores = _score( + estimator=estimator, + X_test=X_train, + y_test=y_train, + scorer=scorers, + score_params=scorer_params_train, + error_score="raise", + ) + train_scores = _check_valid_scores(train_scores, split="train") + + return estimator, _fit_params, train_scores, train_scorer_params + + +def _val_score( + fitted_estimator: BaseEstimator, + X: XLike, # noqa: N803 + y: YLike, + i_train: np.ndarray, + i_val: np.ndarray, + *, + profiler: Profiler, + scorers: dict[str, _Scorer] | _MultimetricScorer, + scorer_params: Mapping[str, Any], +) -> tuple[Mapping[str, float], dict[str, Any]]: + scorer_params_val = _check_method_params(X, scorer_params, indices=i_val) + X_val, y_val = _safe_split( + fitted_estimator, + X, + y, + indices=i_val, + train_indices=i_train, + ) + with profiler("score"): + val_scores = _score( + estimator=fitted_estimator, + X_test=X_val, + y_test=y_val, + scorer=scorers, + score_params=scorer_params_val, + error_score="raise", + ) + val_scores = _check_valid_scores(val_scores, split="val") + + return val_scores, scorer_params_val + + +def _evaluate_split( # noqa: PLR0913 + estimator: BaseEstimator, + X: XLike, # noqa: N803 + y: YLike, + *, + X_test: XLike | None = None, # noqa: N803 + y_test: YLike | None = None, + i_train: np.ndarray, + i_val: np.ndarray, + profiler: Profiler, + scorers: _MultimetricScorer, + fit_params: Mapping[str, Any], + scorer_params: Mapping[str, Any], + test_scorer_params: Mapping[str, Any], + train_score: bool, +) -> CVEvaluation.PostSplitInfo: + # We shove all logic that requires indexing into X for train into `_fit`. + # This is because it's easy to create an accidental copy, i.e. with _safe_split. + # We want that memory to only exist inside that `_fit` function and to not persists + # throughout the body here. 
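    # For illustration, the inline alternative being avoided here would be:
    #
    #     X_train, y_train = _safe_split(estimator, X, y, indices=i_train)
    #     estimator.fit(X_train, y_train)
    #
    # where `X_train`/`y_train` (possibly copies) would stay referenced for the
    # rest of this function instead of being released once fitting is done.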
+ fitted_estimator, fitting_params, train_scores, train_scorer_params = _fit( + estimator=clone(estimator), # type: ignore + X=X, + y=y, + i_train=i_train, + profiler=profiler, + scorers=scorers, + fit_params=fit_params, + train_score=train_score, + scorer_params=scorer_params, + ) + + val_scores, val_scorer_params = _val_score( + fitted_estimator=fitted_estimator, + X=X, + y=y, + i_train=i_train, + i_val=i_val, + profiler=profiler, + scorers=scorers, + scorer_params=scorer_params, + ) + + test_scores = None + if X_test is not None and y_test is not None: + with profiler("test_score"): + test_scores = _score( + estimator=fitted_estimator, + X_test=X_test, + y_test=y_test, + scorer=scorers, + score_params=test_scorer_params, + error_score="raise", + ) + test_scores = _check_valid_scores(test_scores, split="test") + + return CVEvaluation.PostSplitInfo( + X=X, + y=y, + X_test=X_test, + y_test=y_test, + i_train=i_train, + i_val=i_val, + model=fitted_estimator, + train_scores=train_scores, + val_scores=val_scores, + test_scores=test_scores, + fitting_params=fitting_params, + train_scorer_params=train_scorer_params, + val_scorer_params=val_scorer_params, + test_scorer_params=test_scorer_params, + ) + + +def _iter_cross_validate( # noqa: PLR0913 + estimator: BaseEstimator, + X: XLike, # noqa: N803 + y: YLike, + splitter: BaseShuffleSplit | BaseCrossValidator, + scorers: Mapping[str, _Scorer], + *, + X_test: XLike | None = None, # noqa: N803 + y_test: YLike | None = None, + fit_params: Mapping[str, Any] | None = None, + scorer_params: Mapping[str, Any] | None = None, + splitter_params: Mapping[str, Any] | None = None, + test_scorer_params: Mapping[str, Any] | None = None, + profiler: Profiler | None = None, + train_score: bool = False, +) -> Iterator[CVEvaluation.PostSplitInfo]: + if (X_test is not None and y_test is None) or ( + y_test is not None and X_test is None + ): + raise ValueError( + "Both `X_test`, `y_test` must be provided together if one is provided.", + ) + + profiler = Profiler(disabled=True) if profiler is None else profiler + + fit_params = fit_params if fit_params is not None else {} + scorer_params = scorer_params if scorer_params is not None else {} + splitter_params = splitter_params if splitter_params is not None else {} + test_scorer_params = test_scorer_params if test_scorer_params is not None else {} + + for i_train, i_val in splitter.split(X, y, **splitter_params): + # Sadly this function needs the full X and y due to its internal checks + yield _evaluate_split( + estimator=estimator, + X=X, + y=y, + X_test=X_test, + y_test=y_test, + i_train=i_train, + i_val=i_val, + profiler=profiler, + scorers=_MultimetricScorer(scorers=scorers, raise_exc=True), + fit_params=fit_params, + scorer_params=scorer_params, + test_scorer_params=test_scorer_params, + train_score=train_score, + ) + + +def cross_validate_task( # noqa: D103, C901, PLR0915, PLR0913 + trial: Trial, + pipeline: Node, + *, + X: Stored[XLike], # noqa: N803 + y: Stored[YLike], + X_test: Stored[XLike] | None = None, # noqa: N803 + y_test: Stored[YLike] | None = None, + splitter: BaseShuffleSplit | BaseCrossValidator, + additional_scorers: Mapping[str, _Scorer] | None, + train_score: bool = False, + store_models: bool = True, + params: MutableMapping[str, Stored[Any] | Any] | None = None, + on_error: Literal["fail", "raise"] = "fail", + comm: Comm | None = None, + post_split: PostSplitSignature | None = None, + post_processing: ( + Callable[[Trial.Report, Node, CVEvaluation.CompleteEvalInfo], Trial.Report] + | None + ) = 
None, + post_processing_requires_models: bool = False, +) -> Trial.Report: + params = {} if params is None else params + # Make sure to load all the stored values + + configure_params = params.pop("configure", {}) + if not isinstance(configure_params, MutableMapping): + raise ValueError( + f"Expected `params['configure']` to be a dict but got {configure_params=}", + ) + + if "random_state" in configure_params: + raise ValueError( + "You should not provide `'random_state'` in `params['configure']`" + " as the seed is set by the optimizer.", + ) + random_state = amltk.randomness.as_randomstate(trial.seed) + configure_params["random_state"] = random_state + + build_params = params.pop("build", {"builder": "sklearn"}) # type: ignore + if not isinstance(build_params, MutableMapping): + raise ValueError( + f"Expected `params['build']` to be a dict but got {build_params=}", + ) + + transform_context = params.pop("transform_context", None) # type: ignore + + configured_pipeline = pipeline.configure( + trial.config, + transform_context=transform_context, + params=configure_params, + ) + estimator = configured_pipeline.build(**build_params) + + scorers: dict[str, _Scorer] = {} + for metric_name, metric in trial.metrics.items(): + match metric.fn: + case None: + try: + scorer = get_scorer(metric_name) + scorers[metric_name] = scorer + except ValueError as e: + raise ValueError( + f"Could not find scorer for {metric_name=} in sklearn." + " Please provide one with `Metric(fn=...)` or a valid" + " name that can be used with sklearn's `get_scorer`", + ) from e + case _Scorer(): # type: ignore + scorers[metric_name] = metric.fn + case _: + # We do a best effort here and try to convert the metric to + # an sklearn scorer. + warnings.warn( + f"Found a metric with a custom function for {metric_name=}." + " Attempting to convert it to an sklearn scorer. This may" + " fail. If it does, please first your function to an sklearn" + " scorer with `sklearn.metrics.make_scorer` and then pass" + " it to `Metric(fn=...)`", + ImplicitMetricConversionWarning, + stacklevel=2, + ) + # This may fail + scorers[metric_name] = metric.as_scorer() + + if additional_scorers is not None: + scorers.update(additional_scorers) + + _X = X.load() + _y = y.load() + _X_test = X_test.load() if X_test is not None else None + _y_test = y_test.load() if y_test is not None else None + + n_splits = splitter.get_n_splits() + if n_splits is None: + raise NotImplementedError("Needs to be handled") + + loaded_params: dict[str, Any] = { + k: v.load() if isinstance(v, Stored) else v for k, v in params.items() + } + + # Unfortunatly there's two things that can happen here. + # 1. The scorer requires does not require split specific param data (e.g. pos_label) + # * In this case, the param['pos_label'] can be used for train/val and test + # 2. The scorer requires required split specific param data (e.g. sample_weight) + # * In this case, we use the split indices to select the part of + # `params['sample_weight']` that is required for the repsective train/val split. + # * This means we can not use `params['sample_weight']` for the test split, as + # this would require some odd hack of concatenating them and having seperate + # test indices passed in by the user, a pretty dumb idea. + # + # The easy workaround is to have the user provide `test_{key}` for something + # like `params['test_sample_weight']`, which we can then use for the test split. 
+ # However this breaks the metadata routing, which introspects the objects and + # sees that yes, indeed something has requested `sample_weight` but nothing + # has requested `test_sample_weight`. Worse still, we would need to pass + # `params['test_sample_weight']` to the `sample_weight=` parameter of the scorer. + # + # Our workaround is to have users provide `test_{key}` for all the scorer params + # which we pop into a new dict with just `{key}`, where the `test_` prefix has been + # removed. The router will never see this dict. + # + # As important caveats: + # * We assume all keys prefixed with `test_` are scorer params. + # * Things like `pos_label` which are split agnostic need to be + # provided twice, once as `pos_label` and once as `test_pos_label`, such that + # the scorers in test receive the params. + test_scorer_params = { + k: v + for k in list(loaded_params) + if (v := loaded_params.pop(f"test_{k}", None)) is not None + } + + # We've now popped out all the test params, so we can safely call + # `_route_params` without it complaining that nothing has requested `test_{key}` + + # NOTE: This flow is adapted from sklearn 1.4's cross_validate + # This scorer is only created for routing purposes + multimetric_scorer = _MultimetricScorer(scorers=scorers, raise_exc=True) + routed_params = _route_params( + splitter=splitter, + estimator=estimator, + _scorer=multimetric_scorer, + **loaded_params, + ) + + fit_params = routed_params["estimator"]["fit"] + scorer_params = routed_params["scorer"]["score"] + splitter_params = routed_params["splitter"]["split"] + + cv_iter = _iter_cross_validate( + estimator=estimator, + X=_X, + y=_y, + X_test=_X_test, + y_test=_y_test, + splitter=splitter, + scorers=scorers, + profiler=trial.profiler, + train_score=train_score, + fit_params=fit_params, + scorer_params=scorer_params, + splitter_params=splitter_params, + test_scorer_params=test_scorer_params, + ) + + split_scores = CVEvaluation.SplitScores( + val=defaultdict(list), + train=defaultdict(list) if train_score else None, + test=defaultdict(list) if X_test is not None else None, + ) + models: list[BaseEstimator] | None = ( + None if not post_processing_requires_models else [] + ) + + with comm.open() if comm is not None else nullcontext(): + try: + # Open up comms if passed in, allowing for the cv early stopping mechanism + # to communicate back to the main process + # Main cv loop + for i, _split_eval in trial.profiler.each( + enumerate(cv_iter), + name="cv", + itr_name="split", + ): + if post_split is not None: + split_eval = post_split(trial, i, _split_eval) + else: + split_eval = _split_eval + + # Update the report + if store_models: + trial.store({f"model_{i}.pkl": split_eval.model}) + + trial.summary.update( + {f"split_{i}:val_{k}": v for k, v in split_eval.val_scores.items()}, + ) + for k, v in split_eval.val_scores.items(): + split_scores.val[k].append(v) + + if split_eval.train_scores is not None: + trial.summary.update( + { + f"split_{i}:train_{k}": v + for k, v in split_eval.train_scores.items() + }, + ) + for k, v in split_eval.train_scores.items(): + split_scores.train[k].append(v) # type: ignore + + if split_eval.test_scores is not None: + trial.summary.update( + { + f"split_{i}:test_{k}": v + for k, v in split_eval.test_scores.items() + }, + ) + for k, v in split_eval.test_scores.items(): + split_scores.test[k].append(v) # type: ignore + + if post_processing_requires_models: + assert models is not None + models.append(split_eval.model) + + # At this point, we wish to remove the 
split_eval object from memory + # if possible. This doesn't actually clean up memory but marks it + # as being viable for garbage collection. + del split_eval + + # If there was a comm passed, we are operating under cv early stopping + # mode, in which case we request information from the main process, + # should we continue or stop? + if comm is not None and i < n_splits: + match response := comm.request( + (trial, split_scores), + timeout=10, + ): + case True: + raise CVEarlyStoppedError("Early stopped!") + case False: + pass + case np.bool_(): + if bool(response) is True: + raise CVEarlyStoppedError("Early stopped!") + case Exception(): + raise response + case _: + raise RuntimeError( + f"Recieved {response=} which we can't handle." + " Please return `True`, `False` or an `Exception`" + f" and not a type {type(response)=}", + ) + + except Exception as e: # noqa: BLE001 + trial.dump_exception(e) + report = trial.fail(e) + if on_error == "raise": + raise TrialError(f"Trial failed: {report}") from e + + if post_processing is not None: + final_eval_info = CVEvaluation.CompleteEvalInfo( + X=_X, + y=_y, + X_test=_X_test, + y_test=_y_test, + splitter=splitter, + max_splits=n_splits, + scores=split_scores, + scorers=scorers, + models=models, + splitter_params=splitter_params, + fit_params=fit_params, + scorer_params=scorer_params, + test_scorer_params=test_scorer_params, + ) + report = post_processing(report, pipeline, final_eval_info) + + return report + else: + for mname, fold_scores in split_scores.val.items(): + trial.summary[f"val_mean_{mname}"] = float(np.mean(fold_scores)) + trial.summary[f"val_std_{mname}"] = float(np.std(fold_scores)) + + if split_scores.train is not None: + for mname, fold_scores in split_scores.train.items(): + trial.summary[f"train_mean_{mname}"] = float(np.mean(fold_scores)) + trial.summary[f"train_std_{mname}"] = float(np.std(fold_scores)) + + if split_scores.test is not None: + for mname, fold_scores in split_scores.test.items(): + trial.summary[f"test_mean_{mname}"] = float(np.mean(fold_scores)) + trial.summary[f"test_std_{mname}"] = float(np.std(fold_scores)) + + means_to_report = {k: trial.summary[f"val_mean_{k}"] for k in trial.metrics} + report = trial.success(**means_to_report) + + if post_processing is not None: + final_eval_info = CVEvaluation.CompleteEvalInfo( + X=_X, + y=_y, + X_test=_X_test, + y_test=_y_test, + splitter=splitter, + max_splits=n_splits, + scores=split_scores, + scorers=scorers, + models=models, + splitter_params=splitter_params, + fit_params=fit_params, + scorer_params=scorer_params, + test_scorer_params=test_scorer_params, + ) + report = post_processing(report, pipeline, final_eval_info) + + return report + + +class CVEvaluation(Emitter): + """Cross-validation evaluation protocol. + + This protocol will create a cross-validation task to be used in parallel and + optimization. It represents a typical cross-validation evaluation for sklearn. + + Aside from the init parameters, it expects: + * The pipeline you are optimizing can be made into a [sklearn.pipeline.Pipeline][] + calling [`.build("sklearn")`][amltk.pipeline.Node.build]. + * The seed for the trial will be passed as a param to + [`.configure()`][amltk.pipeline.Node.configure]. If you have a component + that accepts a `random_state` parameter, you can use a + [`request()`][amltk.pipeline.request] so that it will be seeded correctly. 
+ + ```python exec="true" source="material-block" result="python" + from amltk.sklearn import CVEvaluation + from amltk.pipeline import Component, request + from amltk.optimization import Metric + + from sklearn.ensemble import RandomForestClassifier + from sklearn.metrics import get_scorer + from sklearn.datasets import load_iris + from pathlib import Path + + pipeline = Component( + RandomForestClassifier, + config={"random_state": request("random_state")}, + space={"n_estimators": (10, 100), "criterion": ["gini", "entropy"]}, + ) + + working_dir = Path("./some-path") + X, y = load_iris(return_X_y=True) + evaluator = CVEvaluation( + X, + y, + n_splits=3, + splitter="cv", + additional_scorers={"roc_auc": get_scorer("roc_auc_ovr")}, + store_models=False, + train_score=True, + working_dir=working_dir, + ) + + history = pipeline.optimize( + target=evaluator.fn, + metric=Metric("accuracy", minimize=False, bounds=(0, 1)), + working_dir=working_dir, + max_trials=1, + ) + print(history.df()) + evaluator.bucket.rmdir() # Cleanup + ``` + + If you need to pass specific configuration items to your pipeline during + configuration, you can do so using a [`request()`][amltk.pipeline.request] + in the config of your pipeline. + + In the below example, we allow the pipeline to be configured with `"n_jobs"` + and pass it in to the `CVEvalautor` using the `params` argument. + + ```python exec="true" source="material-block" result="python" + from amltk.sklearn import CVEvaluation + from amltk.pipeline import Component, request + from amltk.optimization import Metric + + from sklearn.ensemble import RandomForestClassifier + from sklearn.metrics import get_scorer + from sklearn.datasets import load_iris + from pathlib import Path + + working_dir = Path("./some-path") + X, y = load_iris(return_X_y=True) + + pipeline = Component( + RandomForestClassifier, + config={ + "random_state": request("random_state"), + # Allow it to be configured with n_jobs + "n_jobs": request("n_jobs", default=None) + }, + space={"n_estimators": (10, 100), "criterion": ["gini", "entropy"]}, + ) + + evaluator = CVEvaluation( + X, + y, + working_dir=working_dir, + # Use the `configure` keyword in params to pass to the `n_jobs` + # Anything in the pipeline requesting `n_jobs` will get the value + params={"configure": {"n_jobs": 2}} + ) + history = pipeline.optimize( + target=evaluator.fn, + metric=Metric("accuracy"), + working_dir=working_dir, + max_trials=1, + ) + print(history.df()) + evaluator.bucket.rmdir() # Cleanup + ``` + + !!! tip "CV Early Stopping" + + To see more about early stopping, please see + [`CVEvaluation.cv_early_stopping_plugin()`][amltk.sklearn.evaluation.CVEvaluation.cv_early_stopping_plugin]. + + """ + + SPLIT_EVALUATED: Event[[Trial, SplitScores], bool | Exception] = Event( + "split-evaluated", + ) + """Event that is emitted when a split has been evaluated. + + Only emitted if the evaluator plugin is being used. + """ + + TMP_DIR_PREFIX: ClassVar[str] = "amltk-sklearn-cv-evaluation-data-" + """Prefix for temporary directory names. + + This is only used when `working_dir` is not specified. If not specified + you can control the tmp dir location by setting the `TMPDIR` + environment variable. By default this is `/tmp`. + + When using a temporary directory, it will be deleted by default, + controlled by the `delete_working_dir=` argument. 
+ """ + + _X_FILENAME: ClassVar[str] = "X.pkl" + """The name of the file to store the features in.""" + + _X_TEST_FILENAME: ClassVar[str] = "X_test.pkl" + """The name of the file to store the test features in.""" + + _Y_FILENAME: ClassVar[str] = "y.pkl" + """The name of the file to store the targets in.""" + + _Y_TEST_FILENAME: ClassVar[str] = "y_test.pkl" + """The name of the file to store the test targets in.""" + + PARAM_EXTENSION_MAPPING: ClassVar[dict[type[Sized], str]] = { + np.ndarray: "npy", + pd.DataFrame: "pdpickle", + pd.Series: "pdpickle", + } + """The mapping from types to extensions in + [`params`][amltk.sklearn.evaluation.CVEvaluation.params]. + + If the parameter is an instance of one of these types, and is larger than + [`LARGE_PARAM_HEURISTIC`][amltk.sklearn.evaluation.CVEvaluation.LARGE_PARAM_HEURISTIC], + then it will be stored to disk and loaded back up in the task. + + Please feel free to overwrite this class variable as needed. + """ + + LARGE_PARAM_HEURISTIC: ClassVar[int] = 100 + """Any item in `params=` which is greater will be stored to disk when sent to the + worker. + + When launching tasks, pickling and streaming large data to tasks can be expensive. + This parameter checks if the object is large and if so, stores it to disk and + gives it to the task as a [`Stored`][amltk.store.stored.Stored] object instead. + + Please feel free to overwrite this class variable as needed. + """ + + task_type: TaskTypeName + """The inferred task type.""" + + additional_scorers: Mapping[str, _Scorer] | None + """Additional scorers that will be used.""" + + bucket: PathBucket + """The bucket to use for storing data. + + For cleanup, you can call + [`bucket.rmdir()`][amltk.store.paths.path_bucket.PathBucket.rmdir]. + """ + + splitter: BaseShuffleSplit | BaseCrossValidator + """The splitter that will be used.""" + + params: Mapping[str, Any | Stored[Any]] + """Parameters to pass to the estimator, splitter or scorers. + + Please see https://scikit-learn.org/stable/metadata_routing.html for + more. + """ + + store_models: bool + """Whether models will be stored in the trial.""" + + train_score: bool + """Whether scores will be calculated on the training data as well.""" + + X_stored: Stored[XLike] + """The stored features. + + You can call [`.load()`][amltk.store.stored.Stored.load] to load the + data. + """ + + y_stored: Stored[YLike] + """The stored target. + + You can call [`.load()`][amltk.store.stored.Stored.load] to load the + data. + """ + + class PostSplitInfo(NamedTuple): + """Information about the evaluation of a split. + + Attributes: + X: The features to used for training. + y: The targets used for training. + X_test: The features used for testing if it was passed in. + y_test: The targets used for testing if it was passed in. + i_train: The train indices for this split. + i_val: The validation indices for this split. + model: The model that was trained in this split. + train_scores: The training scores for this split if requested. + val_scores: The validation scores for this split. + test_scores: The test scores for this split if requested. + fitting_params: Any additional fitting parameters that were used. + train_scorer_params: Any additional scorer parameters used for evaluating + scorers on training set. + val_scorer_params: Any additional scorer parameters used for evaluating + scorers on training set. + test_scorer_params: Any additional scorer parameters used for evaluating + scorers on training set. 
+ """ + + X: XLike + y: YLike + X_test: XLike | None + y_test: YLike | None + i_train: np.ndarray + i_val: np.ndarray + model: BaseEstimator + val_scores: Mapping[str, float] + train_scores: Mapping[str, float] | None + test_scores: Mapping[str, float] | None + fitting_params: Mapping[str, Any] + train_scorer_params: Mapping[str, Any] + val_scorer_params: Mapping[str, Any] + test_scorer_params: Mapping[str, Any] + + class SplitScores(NamedTuple): + """The scores for a split. + + Attributes: + val: The validation scores for all evaluated split. + train: The training scores for all evaluated splits if requested. + test: The test scores for all evaluated splits if requested. + """ + + val: Mapping[str, list[float]] + train: Mapping[str, list[float]] | None + test: Mapping[str, list[float]] | None + + @dataclass + class CompleteEvalInfo: + """Information about the final evaluation of a cross-validation task. + + This class contains information about the final evaluation of a cross-validation + that will be passed to the post-processing function. + """ + + X: XLike + """The features to used for training.""" + + y: YLike + """The targets used for training.""" + + X_test: XLike | None + """The features used for testing.""" + + y_test: YLike | None + """The targets used for testing.""" + + splitter: BaseShuffleSplit | BaseCrossValidator + """The splitter that was used.""" + + max_splits: int + """The maximum number of splits that were (or could have been) evaluated.""" + + scores: CVEvaluation.SplitScores + """The scores for the splits that were evaluated.""" + + scorers: dict[str, _Scorer] + """The scorers that were used.""" + + models: list[BaseEstimator] | None + """The models that were trained in each split. + + This will be `None` if `post_processing_requires_models=False`. + """ + + splitter_params: Mapping[str, Any] + """The parameters that were used for the splitter.""" + + fit_params: Mapping[str, Any] + """The parameters that were used for fitting the estimator. + + Please use + [`select_params()`][amltk.sklearn.evaluation.CVEvaluation.CompleteEvalInfo.select_params] + if you need to select the params specific to a split, i.e. for `sample_weights`. + """ + + scorer_params: Mapping[str, Any] + """The parameters that were used for scoring the estimator. + + Please use + [`select_params()`][amltk.sklearn.evaluation.CVEvaluation.CompleteEvalInfo.select_params] + if you need to select the params specific to a split, i.e. for `sample_weights`. + """ + + test_scorer_params: Mapping[str, Any] + """The parameters that were used for scoring the test data. + + Please use + [`select_params()`][amltk.sklearn.evaluation.CVEvaluation.CompleteEvalInfo.select_params] + if you need to select the params specific to a split, i.e. for `sample_weights`. + """ + + # TODO: We don't use `test_fit_params` in our evaluator but someone could + # potentially need it here. Fix if it becomes a problem... 
+ + def select_params( + self, + params: Mapping[str, Any], + indices: np.ndarray, + ) -> dict[str, Any]: + """Convinience method to select parameters for a specific split.""" + return _check_method_params(self.X, params, indices=indices) + + class _CVEarlyStoppingPlugin(Plugin): + name: ClassVar[str] = "cv-early-stopping-plugin" + + def __init__( + self, + evaluator: CVEvaluation, + *, + strategy: CVEarlyStoppingProtocol | None = None, + create_comms: Callable[[], tuple[Comm, Comm]] | None = None, + ) -> None: + super().__init__() + self.evaluator = evaluator + self.strategy = strategy + self.comm_plugin = Comm.Plugin( + create_comms=create_comms, + parameter_name="comm", + ) + + @override + def pre_submit( + self, + fn: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> tuple[Callable[P, R], tuple, dict] | None: + return self.comm_plugin.pre_submit(fn, *args, **kwargs) + + @override + def attach_task(self, task: Task) -> None: + """Attach the plugin to a task. + + This method is called when the plugin is attached to a task. This + is the place to subscribe to events on the task, create new subscribers + for people to use or even store a reference to the task for later use. + + Args: + task: The task the plugin is being attached to. + """ + self.task = task + self.comm_plugin.attach_task(task) + task.add_event(CVEvaluation.SPLIT_EVALUATED) + task.register(Comm.REQUEST, self._on_comm_request_ask_whether_to_continue) + if self.strategy is not None: + task.register(self.evaluator.SPLIT_EVALUATED, self.strategy.should_stop) + task.register(task.RESULT, self._call_strategy_update) + + def _call_strategy_update(self, _: Future, report: Trial.Report) -> None: + if self.strategy is not None: + self.strategy.update(report) + + def _on_comm_request_ask_whether_to_continue(self, msg: Comm.Msg) -> None: + if not (isinstance(msg.data, tuple) and len(msg.data) == 2): # noqa: PLR2004 + return + + trial, scores = msg.data + if not ( + isinstance(trial, Trial) + and isinstance(scores, CVEvaluation.SplitScores) + ): + return + + non_null_responses = [ + r + for _, r in self.task.emit( + self.evaluator.SPLIT_EVALUATED, + trial, + scores, + ) + if r is not None + ] + logger.debug( + f"Received responses for {self.evaluator.SPLIT_EVALUATED}:" + f" {non_null_responses}", + ) + match len(non_null_responses): + case 0: + msg.respond(response=False) + case 1: + msg.respond(response=non_null_responses[0]) + case _ if all_equal(non_null_responses): + msg.respond(response=non_null_responses[0]) + case _: + raise NotImplementedError( + "Multiple callbacks returned different values." + " Behaviour is undefined. Please aggregate behaviour" + " into one callback. 
Also please raise an issue to" + " discuss use cases and how to handle this.", + ) + + def __init__( # noqa: PLR0913, C901 + self, + X: XLike, # noqa: N803 + y: YLike, + *, + X_test: XLike | None = None, # noqa: N803 + y_test: YLike | None = None, + splitter: ( + Literal["holdout", "cv"] | BaseShuffleSplit | BaseCrossValidator + ) = "cv", + n_splits: int = 5, # sklearn default + holdout_size: float = 0.33, + train_score: bool = False, + store_models: bool = False, + rebalance_if_required_for_stratified_splitting: bool | None = None, + additional_scorers: Mapping[str, _Scorer] | None = None, + random_state: Seed | None = None, # Only used if cv is an int/float + params: Mapping[str, Any] | None = None, + task_hint: ( + TaskTypeName | Literal["classification", "regression", "auto"] + ) = "auto", + working_dir: str | Path | PathBucket | None = None, + on_error: Literal["raise", "fail"] = "fail", + post_split: PostSplitSignature | None = None, + post_processing: ( + Callable[[Trial.Report, Node, CVEvaluation.CompleteEvalInfo], Trial.Report] + | None + ) = None, + post_processing_requires_models: bool = False, + ) -> None: + """Initialize the evaluation protocol. + + Args: + X: The features to use for training. + y: The target to use for training. + X_test: The features to use for testing. If provided, all + scorers will be calculated on this data as well. + Must be provided with `y_test=`. + + !!! tip "Scorer params for test scoring" + + Due to nuances of sklearn's metadata routing, if you need to provide + parameters to the scorer for the test data, you can prefix these + with `#!python "test_"`. For example, if you need to provide + `pos_label` to the scorer for the test data, you must provide + `test_pos_label` in the `params` argument. + + y_test: The target to use for testing. If provided, all + scorers will be calculated on this data as well. + Must be provided with `X_test=`. + splitter: The cross-validation splitter to use. This can be either + `#!python "holdout"` or `#!python "cv"`. Please see the related + arguments below. If a scikit-learn cross-validator is provided, + this will be used directly. + n_splits: The number of cross-validation splits to use. + This argument will be ignored if `#!python splitter="holdout"` + or a custom splitter is provided for `splitter=`. + holdout_size: The size of the holdout set to use. This argument + will be ignored if `#!python splitter="cv"` or a custom splitter + is provided for `splitter=`. + train_score: Whether to score on the training data as well. This + will take extra time as predictions will be made on the + training data as well. + store_models: Whether to store the trained models in the trial. + rebalance_if_required_for_stratified_splitting: Whether the CVEvaluator + should rebalance the training data to allow for stratified splitting. + * If `True`, rebalancing will be done if required. That is when + the `splitter=` is `"cv"` or a `StratifiedKFold` and + there are fewer instances of a minority class than `n_splits=`. + * If `None`, rebalancing will be done if required it. Same + as `True` but raises a warning if it occurs. + * If `False`, rebalancing will never be done. + additional_scorers: Additional scorers to use. + random_state: The random state to use for the cross-validation + `splitter=`. If a custom splitter is provided, this will be + ignored. + params: Parameters to pass to the estimator, splitter or scorers. + See https://scikit-learn.org/stable/metadata_routing.html for + more information. 
+ + You may also additionally include the following as dictionaries: + + * `#!python "configure"`: Parameters to pass to the pipeline + for [`configure()`][amltk.pipeline.Node.configure]. Please see + the example in the class docstring for more information. + * `#!python "build"`: Parameters to pass to the pipeline for + [`build()`][amltk.pipeline.Node.build]. + + ```python + from imblearn.pipeline import Pipeline as ImbalancedPipeline + CVEvaluation( + ..., + params={ + "build": { + "builder": "sklearn", + "pipeline_type": ImbalancedPipeline + } + } + ) + ``` + + * `#!python "transform_context"`: The transform context to use + for [`configure()`][amltk.pipeline.Node.configure]. + + !!! tip "Scorer params for test scoring" + + Due to nuances of sklearn's metadata routing, if you need to provide + parameters to the scorer for the test data, you must prefix these + with `#!python "test_"`. For example, if you need to provide + `pos_label` to the scorer for the test data, you can provide + `test_pos_label` in the `params` argument. + + task_hint: A string indicating the task type matching those + used by sklearn's `type_of_target`. This can be either + `#!python "binary"`, `#!python "multiclass"`, + `#!python "multilabel-indicator"`, `#!python "continuous"`, + `#!python "continuous-multioutput"` or + `#!python "multiclass-multioutput"`. + + You can also provide `#!python "classification"` or + `#!python "regression"` for a more general hint. + + If not provided, this will be inferred from the target data. + If you know this value, it is recommended to provide it as + sometimes the target is ambiguous and sklearn may infer + incorrectly. + working_dir: The directory to use for storing data. If not provided, + a temporary directory will be used. If provided as a string + or a `Path`, it will be used as the path to the directory. + on_error: What to do if an error occurs in the task. This can be + either `#!python "raise"` or `#!python "fail"`. If `#!python "raise"`, + the error will be raised and the task will fail. If `#!python "fail"`, + the error will be caught and the task will return a failure report + with the error message stored inside. + Set this to `#!python "fail"` if you want to continue optimization + even if some trials fail. + post_split: If provided, this callable will be called with a + [`PostSplitInfo`][amltk.sklearn.evaluation.CVEvaluation.PostSplitInfo]. + + For example, this could be useful if you'd like to save out-of-fold + predictions for later use. + + ```python + def my_post_split( + trial: Trial, + split_number: int, + info: CVEvaluation.PostSplitInfo, + ) -> CVEvaluation.PostSplitInfo: + oof_preds = info.model.predict(info.X[info.i_val]) + trial.store({f"oof_predictions_{split_number}.npy": oof_preds}) + return info + ``` + + !!! warning "Run in the worker" + + This callable will be pickled and sent to the worker that is + executing an evaluation. This means that you should mitigate + relying on any large objects if your callable is an object, as + the object will get pickled and sent to the worker. This also means + you can not rely on information obtained from other trials as when + sending the callable to a worker, it is no longer updatable from the + main process. + + You should also avoid holding on to references to either the model + or large data that is passed in + [`PostSplitInfo`][amltk.sklearn.evaluation.CVEvaluation.PostSplitInfo] + to the function. 
+ + This parameter should primarily be used for callables that rely + solely on the output of the current trial and wish to store/add + additional information to the trial itself. + + post_processing: If provided, this callable will be called with all of the + evaluated splits and the final report that will be returned. + This can be used to do things such as augment the final scores + if required, cleanup any resources or any other tasks that should be + run after the evaluation has completed. This will be handed a + [`Report`][amltk.optimization.trial.Trial.Report] and a + [`CompleteEvalInfo`][amltk.sklearn.evaluation.CVEvaluation.CompleteEvalInfo], + which contains all the information about the evaluation. If your + function requires the individual models, you can set + `post_processing_requires_models=True`. By default this is `False` + as this requires having all models in memory at once. + + This can be useful when you'd like to report the score of a bagged + model, i.e. an ensemble of all validation models. Another example + is if you'd like to add to the summary, the score of what the model + would be if refit on all the data. + + ```python + from amltk.sklearn.voting import voting_with_prefitted_estimators + + # Compute the test score of all fold models bagged together + def my_post_processing( + report: Trial.Report, + pipeline: Node, + info: CVEvaluator.CompleteEvalInfo, + ) -> Trial.Report: + bagged_model = voting_with_prefitted_estimators(info.models) + acc = info.scorers["accuracy"] + bagged_score = acc(bagged_model, info.X_test, info.y_test) + report.summary["bagged_test_score"] = bagged_score + return report + ``` + + !!! warning "Run in the worker" + + This callable will be pickled and sent to the worker that is + executing an evaluation. This means that you should mitigate + relying on any large objects if your callalbe is an object, as + the object will get pickled and sent to the worker. This also means + you can not rely on information obtained from other trials as when + sending the callable to a worker, it is no longer updatable from the + main process. + + This parameter should primarily be used for callables that will + augment the report or what is stored with the trial. It should + rely solely on the current trial to prevent unexpected issues. + + post_processing_requires_models: Whether the `post_processing` function + requires the models to be passed to it. If `True`, the models will + be passed to the function in the `CompleteEvalInfo` object. If `False`, + the models will not be passed to the function. By default this is + `False` as this requires having all models in memory at once. + + """ + super().__init__() + if (X_test is not None and y_test is None) or ( + y_test is not None and X_test is None + ): + raise ValueError( + "Both `X_test`, `y_test` must be provided together if one is provided.", + ) + + match working_dir: + case None: + tmpdir = Path( + tempfile.mkdtemp( + prefix=self.TMP_DIR_PREFIX, + suffix=datetime.now().isoformat(), + ), + ) + bucket = PathBucket(tmpdir) + case str() | Path(): + bucket = PathBucket(working_dir) + case PathBucket(): + bucket = working_dir + + match task_hint: + case "classification" | "regression" | "auto": + task_type = identify_task_type(y, task_hint=task_hint) + case ( + "binary" + | "multiclass" + | "multilabel-indicator" + | "continuous" + | "continuous-multioutput" + | "multiclass-multioutput" # + ): + task_type = task_hint + case _: + raise ValueError( + f"Invalid {task_hint=} provided. 
Must be in {_valid_task_types}" + f"\n{type(task_hint)=}", + ) + + match splitter: + case "cv": + splitter = _default_cv_resampler( + task_type, + n_splits=n_splits, + random_state=random_state, + ) + + case "holdout": + splitter = _default_holdout( + task_type, + holdout_size=holdout_size, + random_state=random_state, + ) + case _: + splitter = splitter # noqa: PLW0127 + + # This whole block is to check whether we should resample for stratified + # sampling, in the case of a low minority class. + if ( + isinstance(splitter, StratifiedKFold) + and rebalance_if_required_for_stratified_splitting is not False + and task_type in ("binary", "multiclass") + ): + if rebalance_if_required_for_stratified_splitting is None: + _warning = ( + f"Labels have fewer than `{n_splits=}` instances. Resampling data" + " to ensure it's possible to have one of each label in each fold." + " Note that this may cause things to crash if you've provided extra" + " `params` as the `X` data will have gotten slightly larger. Please" + " set `rebalance_if_required_for_stratified_splitting=False` if you" + " do not wish this to be enabled automatically, in which case, you" + " may either perform resampling yourself or choose a smaller" + " `n_splits=`." + ) + else: + _warning = None + + x_is_frame = isinstance(X, pd.DataFrame) + y_is_frame = isinstance(y, pd.Series | pd.DataFrame) + + X, y = resample_if_minority_class_too_few_for_n_splits( # type: ignore + X if x_is_frame else pd.DataFrame(X), + y if y_is_frame else pd.Series(y), # type: ignore + n_splits=n_splits, + seed=random_state, + _warning_if_occurs=_warning, + ) + + if not x_is_frame: + X = X.to_numpy() # type: ignore + if not y_is_frame: + y = y.to_numpy() # type: ignore + + self.task_type = task_type + self.additional_scorers = additional_scorers + self.bucket = bucket + self.splitter = splitter + self.params = dict(params) if params is not None else {} + self.store_models = store_models + self.train_score = train_score + + self.X_stored = self.bucket[self._X_FILENAME].put(X) + self.y_stored = self.bucket[self._Y_FILENAME].put(y) + + self.X_test_stored = None + self.y_test_stored = None + if X_test is not None and y_test is not None: + self.X_test_stored = self.bucket[self._X_TEST_FILENAME].put(X_test) + self.y_test_stored = self.bucket[self._Y_TEST_FILENAME].put(y_test) + + # We apply a heuristic that "large" parameters, such as sample_weights + # should be stored to disk as transferring them directly to subprocess as + # parameters is quite expensive (they must be non-optimally pickled and + # streamed to the receiving process). By saving it to a file, we can + # make use of things like numpy/pandas specific efficient pickling + # protocols and also avoid the need to stream it to the subprocess. 
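+        # For example (illustrative only, not part of the computation below): passing a
+        # long `sample_weight` array via `params={"sample_weight": np.ones(len(X))}`
+        # would exceed `LARGE_PARAM_HEURISTIC` and be written out as `"sample_weight.npy"`
+        # in the bucket, with a cheap `Stored` handle forwarded to the task instead.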
+ storable_params = { + k: v + for k, v in self.params.items() + if hasattr(v, "__len__") and len(v) > self.LARGE_PARAM_HEURISTIC # type: ignore + } + for k, v in storable_params.items(): + match subclass_map(v, self.PARAM_EXTENSION_MAPPING, default=None): # type: ignore + case (_, extension_to_save_as): + ext = extension_to_save_as + case _: + ext = "pkl" + + self.params[k] = self.bucket[f"{k}.{ext}"].put(v) + + # This is the actual function that will be called in the task + self.fn = partial( + cross_validate_task, + X=self.X_stored, + y=self.y_stored, + X_test=self.X_test_stored, + y_test=self.y_test_stored, + splitter=self.splitter, + additional_scorers=self.additional_scorers, + params=self.params, + store_models=self.store_models, + train_score=self.train_score, + on_error=on_error, + post_split=post_split, + post_processing=post_processing, + post_processing_requires_models=post_processing_requires_models, + ) + + def cv_early_stopping_plugin( + self, + strategy: CVEarlyStoppingProtocol + | None = None, # TODO: Can provide some defaults... + *, + create_comms: Callable[[], tuple[Comm, Comm]] | None = None, + ) -> CVEvaluation._CVEarlyStoppingPlugin: + """Create a plugin for a task to allow for early stopping. + + ```python exec="true" source="material-block" result="python" html="true" + from dataclasses import dataclass + from pathlib import Path + + import sklearn.datasets + from sklearn.tree import DecisionTreeClassifier + + from amltk.sklearn import CVEvaluation + from amltk.pipeline import Component + from amltk.optimization import Metric, Trial + + working_dir = Path("./some-path") + pipeline = Component(DecisionTreeClassifier, space={"max_depth": (1, 10)}) + x, y = sklearn.datasets.load_iris(return_X_y=True) + evaluator = CVEvaluation(x, y, n_splits=3, working_dir=working_dir) + + # Our early stopping strategy, with `update()` and `should_stop()` + # signatures matching what's expected. + + @dataclass + class CVEarlyStopper: + def update(self, report: Trial.Report) -> None: + # Normally you would update w.r.t. a finished trial, such + # as updating a moving average of the scores. + pass + + def should_stop(self, trial: Trial, scores: CVEvaluation.SplitScores) -> bool | Exception: + # Return True to stop, False to continue. Alternatively, return a + # specific exception to attach to the report instead + return True + + history = pipeline.optimize( + target=evaluator.fn, + metric=Metric("accuracy", minimize=False, bounds=(0, 1)), + max_trials=1, + working_dir=working_dir, + + # Here we insert the plugin to the task that will get created + plugins=[evaluator.cv_early_stopping_plugin(strategy=CVEarlyStopper())], + + # Notably, we set `on_trial_exception="continue"` to not stop as + # we expect trials to fail given the early stopping strategy + on_trial_exception="continue", + ) + from amltk._doc import doc_print; doc_print(print, history[0]) # markdown-exec: hide + evaluator.bucket.rmdir() # markdown-exec: hide + ``` + + !!! warning "Recommended settings for `CVEvaluation`" + + When a trial is early stopped, it will be counted as a failed trial. + This can conflict with the behaviour of `pipeline.optimize` which + by default sets `on_trial_exception="raise"`, causing the optimization + to end. If using [`pipeline.optimize`][amltk.pipeline.Node.optimize], make sure + to set `on_trial_exception="continue"` to continue optimization. + + This will also add a new event to the task which you can subscribe to with + [`task.on("split-evaluated")`][amltk.sklearn.evaluation.CVEvaluation.SPLIT_EVALUATED]. 
+ It will be passed a + [`CVEvaluation.PostSplitInfo`][amltk.sklearn.evaluation.CVEvaluation.PostSplitInfo] + that you can use to make a decision on whether to continue or stop. The + passed in `strategy=` simply sets up listening to these events for you. + You can also do this manually. + + ```python + trial_scores = [] + evaluator = CVEvaluation(...) + task = scheduler.task( + evaluator.fn, + plugins=[evaluator.cv_early_stopping_plugin()] + ) + + @task.on("split-evaluated") + def should_stop(trial: Trial, scores: CVEvaluation.SplitScores) -> bool | Exception: + # Stop if the mean validation accuracy so far is below the + # average of previously completed trials + return np.mean(scores.val["accuracy"]) < np.mean(trial_scores) + + @task.on("result") + def update_scores(_, report: Trial.Report) -> None: + if report.status is Trial.Status.SUCCESS: + trial_scores.append(report.values["accuracy"]) + ``` + + Args: + strategy: The strategy to use for early stopping. Must implement the + `update()` and `should_stop()` methods of + [`CVEarlyStoppingProtocol`][amltk.sklearn.evaluation.CVEarlyStoppingProtocol]. + Please follow the documentation link to find out more. + + By default, when no `strategy=` is passed, this is `None` and + this will create a [`Comm`][amltk.scheduling.plugins.comm.Comm] object, + allowing communication between the worker running the task and the main + process. This adds a new event to the task that you can subscribe + to with + [`task.on("split-evaluated")`][amltk.sklearn.evaluation.CVEvaluation.SPLIT_EVALUATED]. + This is how a passed in strategy will be called and updated. + create_comms: A function that creates a pair of comms for the + plugin to use. This is useful if you want to create a + custom communication channel. If not provided, the default + communication channel will be used. + + !!! note "Default communication channel" + + By default we use a simple `multiprocessing.Pipe` which works + for parallel processes from + [`ProcessPoolExecutor`][concurrent.futures.ProcessPoolExecutor]. + This may not work if the task is being executed on a different + filesystem, depending on the executor which executes the task. + + Returns: + The plugin to use for the task. + """ # noqa: E501 + return CVEvaluation._CVEarlyStoppingPlugin( + self, + strategy=strategy, + create_comms=create_comms, + ) + + +class CVEarlyStoppingProtocol(Protocol): + """Protocol for early stopping in cross-validation. + + Your class should implement the + [`update()`][amltk.sklearn.evaluation.CVEarlyStoppingProtocol.update] + and [`should_stop()`][amltk.sklearn.evaluation.CVEarlyStoppingProtocol.should_stop] + methods. You can optionally inherit from this class but it is not required. + + ```python + class MyStopper: + + def update(self, report: Trial.Report) -> None: + if report.status is Trial.Status.SUCCESS: + ... # do some update logic + + def should_stop(self, trial: Trial, scores: CVEvaluation.SplitScores) -> bool | Exception: + mean_scores_up_to_current_split = np.mean(scores.val["accuracy"]) + if mean_scores_up_to_current_split > 0.9: + return False # Keep going + else: + return True # Stop evaluating + ``` + """ # noqa: E501 + + def update(self, report: Trial.Report) -> None: + """Update the protocol with a new report. + + This will be called when a trial has been completed, either successfully + or failed. You can check for successful trials by using + [`report.status`][amltk.optimization.Trial.Report.status]. + + Args: + report: The report from the trial. + """ + ... 
+ + def should_stop( + self, + trial: Trial, + scores: CVEvaluation.SplitScores, + ) -> bool | Exception: + """Determines whether the cross-validation should stop early. + + Args: + trial: The trial that is currently being evaluated. + scores: The scores from the evlauated splits. + + Returns: + `True` if the cross-validation should stop, `False` if it should + continue, or an `Exception` if it should stop and you'd like a custom + error to be registered with the trial. + """ + ... diff --git a/src/amltk/store/__init__.py b/src/amltk/store/__init__.py index 714343f5..8df1a2a8 100644 --- a/src/amltk/store/__init__.py +++ b/src/amltk/store/__init__.py @@ -12,7 +12,7 @@ TxtLoader, YAMLLoader, ) -from amltk.store.stored_value import StoredValue +from amltk.store.stored import Stored __all__ = [ "Bucket", @@ -27,5 +27,5 @@ "TxtLoader", "ByteLoader", "PathLoader", - "StoredValue", + "Stored", ] diff --git a/src/amltk/store/bucket.py b/src/amltk/store/bucket.py index 8b274e95..f0d76732 100644 --- a/src/amltk/store/bucket.py +++ b/src/amltk/store/bucket.py @@ -13,7 +13,6 @@ import re from abc import ABC, abstractmethod from collections.abc import ( - Callable, Hashable, Iterable, Iterator, @@ -243,22 +242,16 @@ def update(self, items: Mapping[KeyT, Any]) -> None: # type: ignore for key, value in items.items(): self[key].put(value) - def remove( - self, - keys: Iterable[KeyT], - *, - how: Callable[[LinkT], bool] | None = None, - ) -> dict[KeyT, bool]: + def remove(self, keys: Iterable[KeyT]) -> dict[KeyT, bool]: """Remove resources from the bucket. Args: keys: The keys to the resources. - how: A function that removes the resource. Returns: A mapping of keys to whether they were removed. """ - return {key: self[key].remove(how=how) for key in keys} + return {key: self[key].remove() for key in keys} def __truediv__(self, key: KeyT) -> Self: try: diff --git a/src/amltk/store/drop.py b/src/amltk/store/drop.py index 275dc242..2ebd4f81 100644 --- a/src/amltk/store/drop.py +++ b/src/amltk/store/drop.py @@ -10,8 +10,7 @@ from more_itertools.more import first -from amltk._functional import funcname -from amltk.store.stored_value import StoredValue +from amltk.store.stored import Stored if TYPE_CHECKING: from amltk.store.loader import Loader @@ -38,16 +37,9 @@ class Drop(Generic[KeyT]): to use. Each drop has a list of default loaders that it will try to use to load the resource. - For flexibility, you can also specify a `how` when using any - of [`load`][amltk.store.drop.Drop.load], [`get`][amltk.store.drop.Drop.get] - or [`put`][amltk.store.drop.Drop.put] to override the default loaders. - The [`remove`][amltk.store.drop.Drop.remove] and - [`exists`][amltk.store.drop.Drop.exists] method also has a `how` - incase the default methods are not sufficient. - To support well typed code, you can also specify a `check` type which will be used to checked when loading objects, to make sure - it is of the correct type. This is ignored if `how` is specified. + it is of the correct type. The primary methods of interest are @@ -56,7 +48,7 @@ class Drop(Generic[KeyT]): * [`put`][amltk.store.drop.Drop.put] * [`remove`][amltk.store.drop.Drop.remove] * [`exists`][amltk.store.drop.Drop.exists] - * [`as_stored_value`][amltk.store.drop.Drop.as_stored_value] + * [`as_stored`][amltk.store.drop.Drop.as_stored] Args: key: The key to the resource. 
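As a hedged illustration of the slimmed-down `Drop` API above (the `how=` escape hatches are removed and `put()` now hands back a `Stored` handle), here is a minimal sketch; the bucket directory and keys are made up for the example:

```python
import pandas as pd

from amltk import PathBucket

# Any writable directory works; "example-bucket" is just for illustration.
bucket = PathBucket("example-bucket")

# `put()` routes through the registered loaders and returns a `Stored` handle
# that can be shipped cheaply to a worker and loaded lazily there.
stored_df = bucket["df.csv"].put(pd.DataFrame({"a": [1, 2, 3]}))
df = stored_df.load()

# Loading now always goes through the bucket's loaders; `check=` keeps it typed.
df_again = bucket["df.csv"].load(check=pd.DataFrame)
maybe = bucket["missing.csv"].get(default=None)  # default if the key is absent

bucket.rmdir()  # cleanup
```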
@@ -68,18 +60,15 @@ class Drop(Generic[KeyT]): _remove: Callable[[KeyT], bool] = field(repr=False) _exists: Callable[[KeyT], bool] = field(repr=False) - def as_stored_value( - self, - read: Callable[[KeyT], T] | None = None, - ) -> StoredValue[KeyT, T]: - """Convert the drop to a [`StoredValue`][amltk.store.StoredValue]. + def as_stored(self, read: Callable[[KeyT], T] | None = None) -> Stored[T]: + """Convert the drop to a [`Stored`][amltk.store.Stored]. Args: read: The method to use to load the resource. If `None` then the first loader that can load the resource will be used. Returns: - The drop as a [`StoredValue`][amltk.store.StoredValue]. + The drop as a [`Stored`][amltk.store.Stored]. """ if read is None: loader = first( @@ -92,26 +81,14 @@ def as_stored_value( read = loader.load - return StoredValue(self.key, read=read) + return Stored(self.key, read=read) - def put( - self, - obj: T, - *, - how: Callable[[T], None] | None = None, - ) -> None: + def put(self, obj: T) -> Stored[T]: """Put an object into the bucket. Args: obj: The object to put into the bucket. - how: The function to use to put the object into the bucket. - If `None` then the first loader that can put the object - will be used. """ - if how: - how(obj) - return - loader = first( (_l for _l in self.loaders if _l.can_save(obj, self.key)), default=None, @@ -119,54 +96,40 @@ def put( if not loader: msg = ( f"No default way to handle {type(obj)=} objects." - " Please provide a `how` function that will save" - f" the object to {self.key}." + " Please provide a `Loader` with your `Bucket` to specify" + f" how to handle this type of object with this extension: {self.key}." ) raise ValueError(msg) - loader.save(obj, self.key) + return loader.save(obj, self.key) @overload - def load(self, *, check: None = None, how: None = None) -> Any: + def load(self, *, check: None = None) -> Any: ... @overload - def load(self, *, check: type[T], how: None = None) -> T: + def load(self, *, check: type[T]) -> T: ... - @overload - def load(self, *, check: type[T] | None = ..., how: Callable[[KeyT], T]) -> T: - ... - - def load( - self, - *, - check: type[T] | None = None, - how: Callable[[KeyT], T] | None = None, - ) -> T | Any: + def load(self, *, check: type[T] | None = None) -> T | Any: """Load the resource. Args: check: By specifying a `type` we check the loaded object of that type, to enable correctly typed checked code. - how: The function to use to load the resource. Returns: The loaded resource. """ - if not isinstance(how, type) and callable(how): - value = how(self.key) - loader_name = funcname(how) - else: - loader = first( - (_l for _l in self.loaders if _l.can_load(self.key)), - default=None, - ) - if loader is None: - raise ValueError(f"Can't load {self.key=} from {self.loaders=}") + loader = first( + (_l for _l in self.loaders if _l.can_load(self.key)), + default=None, + ) + if loader is None: + raise ValueError(f"Can't load {self.key=} from {self.loaders=}") - value = loader.load(self.key) - loader_name = loader.name + value = loader.load(self.key) + loader_name = loader.name if check is not None and not isinstance(value, check): msg = ( @@ -178,53 +141,19 @@ def load( return value @overload - def get( - self, - default: None = None, - *, - check: None = None, - how: None = None, - ) -> Any | None: - ... - - @overload - def get( - self, - default: Default, - *, - check: None = None, - how: None = None, - ) -> Default | None: + def get(self, default: None = None, *, check: None = None) -> Any | None: ... 
@overload - def get( - self, - default: None = None, - *, - check: type[T], - how: Callable[[KeyT], T] = ..., - ) -> T | None: + def get(self, default: Default, *, check: None = None) -> Any | Default: ... @overload - def get( - self, - default: Default, - *, - check: type[T], - how: Callable[[KeyT], T], - ) -> Default | T: + def get(self, default: None = None, *, check: type[T]) -> T | None: ... @overload - def get( - self, - default: Default, - *, - check: type[T], - how: Callable[[KeyT], T] | None = ..., - ) -> Default | T: + def get(self, default: Default, *, check: type[T]) -> T | Default: ... def get( @@ -232,7 +161,6 @@ def get( default: Default | None = None, *, check: type[T] | None = None, - how: Callable[[KeyT], T] | None = None, ) -> Default | T | None: """Load the resource, or return the default if it can't be loaded. @@ -249,13 +177,12 @@ def get( enable correctly typed checked code. If the default value should be returned because the resource can't be loaded, then the default value is **not** checked. - how: The function to use to load the resource. Returns: The loaded resource or the default value if it cant be loaded. """ try: - return self.load(check=check, how=how) + return self.load(check=check) except TypeError as e: raise e except FileNotFoundError: @@ -268,29 +195,19 @@ def get( return None - def remove(self, *, how: Callable[[KeyT], bool] | None = None) -> bool: + def remove(self) -> bool: """Remove the resource from the bucket. - Args: - how: The function to use to remove the resource. Returns `True` if - the resource no longer exists after the removal, `False` otherwise. + !!! note "Non-existent resources" - !!! note "Non-existent resources" - - If the resource does not exist, then the function will `True`. + If the resource does not exist, then the function will return `True`. """ - logger.debug(f"Removing {self.key=}") - if how: - return how(self.key) - return self._remove(self.key) - def exists(self, *, how: Callable[[KeyT], bool] | None = None) -> bool: + def exists(self) -> bool: """Check if the resource exists. Returns: `True` if the resource exists, `False` otherwise. """ - if how: - return how(self.key) return self._exists(self.key) diff --git a/src/amltk/store/loader.py b/src/amltk/store/loader.py index 58e6765f..f1dde32a 100644 --- a/src/amltk/store/loader.py +++ b/src/amltk/store/loader.py @@ -8,9 +8,13 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, ClassVar, Generic, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar + +if TYPE_CHECKING: + from amltk.store.stored import Stored T = TypeVar("T") +L = TypeVar("L") KeyT_contra = TypeVar("KeyT_contra", contravariant=True) @@ -56,7 +60,7 @@ def can_save(cls, obj: Any, key: KeyT_contra, /) -> bool: @classmethod @abstractmethod - def save(cls, obj: Any, key: KeyT_contra, /) -> None: + def save(cls, obj: Any, key: KeyT_contra, /) -> Stored[T]: """Save an object to under the given key. 
Args: diff --git a/src/amltk/store/paths/path_bucket.py b/src/amltk/store/paths/path_bucket.py index 577ae765..206c10b8 100644 --- a/src/amltk/store/paths/path_bucket.py +++ b/src/amltk/store/paths/path_bucket.py @@ -108,7 +108,7 @@ class PathBucket(Bucket[str, Path]): def __init__( self, - path: Path | str, + path: PathBucket | Path | str, *, loaders: Sequence[type[PathLoader]] | None = None, create: bool = True, @@ -135,6 +135,9 @@ def __init__( if isinstance(path, str): path = Path(path) + elif isinstance(path, PathBucket): + # TODO: Should we inherit the loaders? + path = path.path if clean and path.exists(): shutil.rmtree(path, ignore_errors=True) @@ -146,7 +149,7 @@ def __init__( path.mkdir(parents=True, exist_ok=True) self._create = create - self.path = path + self.path: Path = path self.loaders = _loaders def sizes(self) -> dict[str, int]: diff --git a/src/amltk/store/paths/path_loaders.py b/src/amltk/store/paths/path_loaders.py index 046dfd06..a142aaba 100644 --- a/src/amltk/store/paths/path_loaders.py +++ b/src/amltk/store/paths/path_loaders.py @@ -23,6 +23,7 @@ import pandas as pd from amltk.store.loader import Loader +from amltk.store.stored import Stored if TYPE_CHECKING: from types import ModuleType @@ -82,7 +83,7 @@ def can_save(cls, obj: Any, key: Path, /) -> bool: @override @classmethod @abstractmethod - def save(cls, obj: Any, key: Path, /) -> None: + def save(cls, obj: T, key: Path, /) -> Stored[T]: """Save an object to under the given key. Args: @@ -148,13 +149,17 @@ def load(cls, key: Path, /) -> np.ndarray: @override @classmethod - def save(cls, obj: np.ndarray, key: Path, /) -> None: + def save(cls, obj: np.ndarray, key: Path, /) -> Stored[np.ndarray]: """::: amltk.store.paths.path_loaders.PathLoader.save""" # noqa: D415 logger.debug(f"Saving {key=}") np.save(key, obj, allow_pickle=False) + return Stored(key, cls.load) -class PDLoader(PathLoader[pd.DataFrame | pd.Series]): +_DF = TypeVar("_DF", pd.DataFrame, pd.Series) + + +class PDLoader(PathLoader[_DF]): """A [`Loader`][amltk.store.paths.path_loaders.PathLoader] for loading and saving [`pd.DataFrame`][pandas.DataFrame]s. @@ -218,54 +223,57 @@ def can_save(cls, obj: Any, key: Path, /) -> bool: @override @classmethod - def load(cls, key: Path, /) -> pd.DataFrame | pd.Series: + def load(cls, key: Path, /) -> _DF: """::: amltk.store.paths.path_loaders.PathLoader.load""" # noqa: D415 logger.debug(f"Loading {key=}") if key.suffix == ".csv": - return pd.read_csv(key, index_col=0) + return pd.read_csv(key, index_col=0) # type: ignore if key.suffix == ".parquet": - return pd.read_parquet(key) + return pd.read_parquet(key) # type: ignore if key.suffix == ".pdpickle": obj = pd.read_pickle(key) # noqa: S301 - if not isinstance(obj, pd.Series | pd.DataFrame): + if not isinstance(obj, pd.Series | pd.DataFrame): # type: ignore msg = ( f"Expected `pd.Series | pd.DataFrame` from {key=}" f" but got `{type(obj).__name__}`." 
) raise TypeError(msg) - return obj + return obj # type: ignore raise ValueError(f"Unknown file extension {key.suffix}") @override @classmethod - def save(cls, obj: pd.Series | pd.DataFrame, key: Path, /) -> None: + def save(cls, obj: _DF, key: Path, /) -> Stored[_DF]: """::: amltk.store.paths.path_loaders.PathLoader.save""" # noqa: D415 # Most pandas methods only seem to support dataframes logger.debug(f"Saving {key=}") if key.suffix == ".pdpickle": obj.to_pickle(key) - return + return Stored(key, cls.load) if key.suffix == ".csv": if obj.index.name is None and obj.index.nlevels == 1: obj.index.name = "index" obj.to_csv(key, index=True) - return + return Stored(key, cls.load) if key.suffix == ".parquet": obj.to_parquet(key) - return + return Stored(key, cls.load) raise ValueError(f"Unknown extension {key.suffix=}") -class JSONLoader(PathLoader[dict | list]): +_Json = TypeVar("_Json", dict, list) + + +class JSONLoader(PathLoader[_Json]): """A [`Loader`][amltk.store.paths.path_loaders.PathLoader] for loading and saving [`dict`][dict]s and [`list`][list]s to JSON. @@ -296,7 +304,7 @@ def can_save(cls, obj: Any, key: Path, /) -> bool: @override @classmethod - def load(cls, key: Path, /) -> dict | list: + def load(cls, key: Path, /) -> _Json: """::: amltk.store.paths.path_loaders.PathLoader.load""" # noqa: D415 logger.debug(f"Loading {key=}") with key.open("r") as f: @@ -306,18 +314,22 @@ def load(cls, key: Path, /) -> dict | list: msg = f"Expected `dict | list` from {key=} but got `{type(item).__name__}`" raise TypeError(msg) - return item + return item # type: ignore @override @classmethod - def save(cls, obj: dict | list, key: Path, /) -> None: + def save(cls, obj: _Json, key: Path, /) -> Stored[_Json]: """::: amltk.store.paths.path_loaders.PathLoader.save""" # noqa: D415 logger.debug(f"Saving {key=}") with key.open("w") as f: json.dump(obj, f) + return Stored(key, cls.load) -class YAMLLoader(PathLoader[dict | list]): +_Yaml = TypeVar("_Yaml", dict, list) + + +class YAMLLoader(PathLoader[_Yaml]): """A [`Loader`][amltk.store.paths.path_loaders.PathLoader] for loading and saving [`dict`][dict]s and [`list`][list]s to YAML. 
@@ -349,7 +361,7 @@ def can_save(cls, obj: Any, key: Path, /) -> bool: @override @classmethod - def load(cls, key: Path, /) -> dict | list: + def load(cls, key: Path, /) -> _Yaml: """::: amltk.store.paths.path_loaders.PathLoader.load""" # noqa: D415 logger.debug(f"Loading {key=}") if yaml is None: @@ -362,11 +374,11 @@ def load(cls, key: Path, /) -> dict | list: msg = f"Expected `dict | list` from {key=} but got `{type(item).__name__}`" raise TypeError(msg) - return item + return item # type: ignore @override @classmethod - def save(cls, obj: dict | list, key: Path, /) -> None: + def save(cls, obj: _Yaml, key: Path, /) -> Stored[_Yaml]: """::: amltk.store.paths.path_loaders.PathLoader.save""" # noqa: D415 logger.debug(f"Saving {key=}") if yaml is None: @@ -375,6 +387,8 @@ def save(cls, obj: dict | list, key: Path, /) -> None: with key.open("w") as f: yaml.dump(obj, f) + return Stored(key, cls.load) + class PickleLoader(PathLoader[Any]): """A [`Loader`][amltk.store.paths.path_loaders.PathLoader] for loading and @@ -427,11 +441,12 @@ def load(cls, key: Path, /) -> Any: @override @classmethod - def save(cls, obj: Any, key: Path, /) -> None: + def save(cls, obj: Any, key: Path, /) -> Stored[Any]: """::: amltk.store.paths.path_loaders.PathLoader.save""" # noqa: D415 logger.debug(f"Saving {key=}") with key.open("wb") as f: pickle.dump(obj, f) + return Stored(key, cls.load) class TxtLoader(PathLoader[str]): @@ -473,11 +488,12 @@ def load(cls, key: Path, /) -> str: @override @classmethod - def save(cls, obj: str, key: Path, /) -> None: + def save(cls, obj: str, key: Path, /) -> Stored[str]: """::: amltk.store.paths.path_loaders.PathLoader.save""" # noqa: D415 logger.debug(f"Saving {key=}") with key.open("w") as f: f.write(obj) + return Stored(key, cls.load) class ByteLoader(PathLoader[bytes]): @@ -518,8 +534,9 @@ def load(cls, key: Path, /) -> bytes: @override @classmethod - def save(cls, obj: bytes, key: Path, /) -> None: + def save(cls, obj: bytes, key: Path, /) -> Stored[bytes]: """::: amltk.store.paths.path_loaders.PathLoader.save""" # noqa: D415 logger.debug(f"Saving {key=}") with key.open("wb") as f: f.write(obj) + return Stored(key, cls.load) diff --git a/src/amltk/store/stored.py b/src/amltk/store/stored.py new file mode 100644 index 00000000..9383f0b1 --- /dev/null +++ b/src/amltk/store/stored.py @@ -0,0 +1,70 @@ +"""A value that is stored on disk and loaded lazily. + +This is useful for transmitting large objects between processes. + +```python exec="true" source="material-block" result="python" title="Stored" +from amltk.store import Stored +import pandas as pd +from pathlib import Path + +df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) +path = Path("df.csv") + +df.to_csv(path) +stored_df = Stored(path, read=pd.read_csv) + +# Somewhere in a processes +df = stored_df.load() +print(df) +path.unlink() # markdown-exec: hide +``` + +You can quickly obtain these from buckets if you require using +[`put()`][amltk.store.drop.Drop.put]. 
+ +```python exec="true" source="material-block" result="python" title="Stored from bucket" +from amltk import PathBucket +import pandas as pd + +df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) +bucket = PathBucket("bucket_path") + +stored_df = bucket["df.csv"].put(df) + +# Somewhere in a processes +df = stored_df.load() +print(df) +bucket.rmdir() # markdown-exec: hide +``` +""" +from __future__ import annotations + +from collections.abc import Callable +from typing import Generic, TypeVar +from typing_extensions import override + +K = TypeVar("K") +V = TypeVar("V") + + +class Stored(Generic[V]): + """A value that is stored on disk and can be loaded when needed.""" + + def __init__(self, key: K, read: Callable[[K], V]): + """Initialize the stored value. + + Args: + key: The key to load the value from. + read: A function that takes a key and returns the value. + """ + super().__init__() + self.key = key + self.read = read + + def load(self) -> V: + """Get the value.""" + return self.read(self.key) + + @override + def __repr__(self) -> str: + return f"Stored({self.key})" diff --git a/src/amltk/store/stored_value.py b/src/amltk/store/stored_value.py deleted file mode 100644 index 3a6689e9..00000000 --- a/src/amltk/store/stored_value.py +++ /dev/null @@ -1,65 +0,0 @@ -"""A value that is stored on disk and loaded lazily. - -This is useful for transmitting large objects between processes. - -```python exec="true" source="material-block" result="python" title="StoredValue" -from amltk.store import StoredValue -import pandas as pd -from pathlib import Path - -df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) -path = Path("df.csv") -df.to_csv(path) - -stored_value = StoredValue(path, read=pd.read_csv) - -# Somewhere in a processes -df = stored_value.value() -print(df) - -path.unlink() -``` - -You can quickly obtain these from buckets if you require -```python exec="true" source="material-block" result="python" title="StoredValue from bucket" -from amltk import PathBucket -import pandas as pd - -df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) -bucket = PathBucket("bucket_path") -bucket.update({"df.csv": df}) - -stored_value = bucket["df.csv"].as_stored_value() - -# Somewhere in a processes -df = stored_value.value() -print(df) - -bucket.rmdir() -``` -""" # noqa: E501 -from __future__ import annotations - -from collections.abc import Callable -from dataclasses import dataclass -from typing import Generic, TypeVar - -K = TypeVar("K") -V = TypeVar("V") - - -@dataclass -class StoredValue(Generic[K, V]): - """A value that is stored on disk and can be loaded when needed.""" - - key: K - read: Callable[[K], V] - - _value: V | None = None - - def value(self) -> V: - """Get the value.""" - if self._value is None: - self._value = self.read(self.key) - - return self._value diff --git a/src/amltk/types.py b/src/amltk/types.py index eb9c821e..879c2e65 100644 --- a/src/amltk/types.py +++ b/src/amltk/types.py @@ -31,7 +31,7 @@ Space = TypeVar("Space") """Generic for objects that are aware of a space but not the specific kind""" -Seed: TypeAlias = int | np.integer | (np.random.RandomState | np.random.Generator) +Seed: TypeAlias = int | np.integer | np.random.RandomState | np.random.Generator """Type alias for kinds of Seeded objects.""" FidT: TypeAlias = tuple[int, int] | tuple[float, float] | list[Any] diff --git a/tests/optimizers/test_history.py b/tests/optimizers/test_history.py index 64d046d8..78135947 100644 --- a/tests/optimizers/test_history.py +++ b/tests/optimizers/test_history.py @@ -9,7 +9,7 
@@ from amltk.optimization import History, Metric, Trial -metrics = [Metric("loss", minimize=True)] +metrics = {"loss": Metric("loss", minimize=True)} def quadratic(x: float) -> float: @@ -36,7 +36,7 @@ def eval_trial( reports: list[Trial.Report] = [] for _trial in trial: - with _trial.begin(): + with _trial.profile("trial"): with _trial.profile("dummy"): pass @@ -55,35 +55,65 @@ def case_empty() -> list[Trial.Report]: @case(tags=["success"]) -def case_one_report_success() -> list[Trial.Report]: - trial: Trial = Trial(name="trial_1", config={"x": 1}, metrics=metrics) +def case_one_report_success(tmp_path: Path) -> list[Trial.Report]: + trial: Trial = Trial.create( + name="trial_1", + config={"x": 1}, + metrics=metrics, + bucket=tmp_path, + ) return eval_trial(trial) @case(tags=["fail"]) -def case_one_report_fail() -> list[Trial.Report]: - trial: Trial = Trial(name="trial_1", config={"x": 1}, metrics=metrics) +def case_one_report_fail(tmp_path: Path) -> list[Trial.Report]: + trial: Trial = Trial.create( + name="trial_1", + config={"x": 1}, + metrics=metrics, + bucket=tmp_path, + ) return eval_trial(trial, fail=100) @case(tags=["crash"]) -def case_one_report_crash() -> list[Trial.Report]: - trial: Trial = Trial(name="trial_1", config={"x": 1}, metrics=metrics) +def case_one_report_crash(tmp_path: Path) -> list[Trial.Report]: + trial: Trial = Trial.create( + name="trial_1", + config={"x": 1}, + metrics=metrics, + bucket=tmp_path, + ) return eval_trial(trial, crash=ValueError("Some Error")) @case(tags=["success", "fail", "crash"]) -def case_many_report() -> list[Trial.Report]: +def case_many_report(tmp_path: Path) -> list[Trial.Report]: success_trials: list[Trial] = [ - Trial(name=f"trial_{i+6}", config={"x": i}, metrics=metrics) + Trial.create( + name=f"trial_{i+6}", + config={"x": i}, + metrics=metrics, + bucket=tmp_path, + ) for i in range(-5, 5) ] fail_trials: list[Trial] = [ - Trial(name=f"trial_{i+16}", config={"x": i}, metrics=metrics) + Trial.create( + name=f"trial_{i+16}", + config={"x": i}, + metrics=metrics, + bucket=tmp_path, + ) for i in range(-5, 5) ] crash_trials: list[Trial] = [ - Trial(name=f"trial_{i+26}", config={"x": i}, metrics=metrics) + Trial.create( + name=f"trial_{i+26}", + config={"x": i}, + metrics=metrics, + bucket=tmp_path, + ) for i in range(-5, 5) ] @@ -165,9 +195,14 @@ def test_trace_sortby(reports: list[Trial.Report]) -> None: ) -def test_history_sortby() -> None: +def test_history_sortby(tmp_path: Path) -> None: trials: list[Trial] = [ - Trial(name=f"trial_{i+6}", metrics=metrics, config={"x": i}) + Trial.create( + name=f"trial_{i+6}", + metrics=metrics, + config={"x": i}, + bucket=tmp_path, + ) for i in range(-5, 5) ] @@ -175,16 +210,15 @@ def test_history_sortby() -> None: history = History() for trial in trials: - with trial.begin(): - if trial.name in summary_items: - trial.summary["other_loss"] = trial.config["x"] ** 2 + if trial.name in summary_items: + trial.summary["other_loss"] = trial.config["x"] ** 2 - report = trial.success(loss=trial.config["x"]) - history.add(report) + report = trial.success(loss=trial.config["x"]) + history.add(report) trace_loss = history.sortby("loss") assert len(trace_loss) == len(trials) - losses = [r.metrics["loss"] for r in trace_loss] + losses = [r.values["loss"] for r in trace_loss] assert sorted(losses) == losses trace_other = history.filter(lambda report: "other_loss" in report.summary).sortby( @@ -197,36 +231,65 @@ def test_history_sortby() -> None: assert sorted(losses) == losses -def test_history_incumbents() -> 
None: +def test_history_best(tmp_path: Path) -> None: + trials: list[Trial] = [ + Trial.create( + name=f"trial_{i}", + metrics=metrics, + config={"x": i}, + bucket=tmp_path, + ) + for i in range(10) + ] + + history = History() + + for trial in trials: + # This should have been the best but failed + if trial.name == "trial_0": + history.add(trial.fail()) + else: + history.add(trial.success(loss=trial.config["x"])) + + best = history.best("loss") + assert best.name == "trial_1" + assert best.values["loss"] == 1 + + +def test_history_incumbents(tmp_path: Path) -> None: m1 = Metric("score", minimize=False) m2 = Metric("loss", minimize=True) trials: list[Trial] = [ - Trial(name=f"trial_{i+6}", metrics=[m1, m2], config={"x": i}) + Trial.create( + name=f"trial_{i+6}", + metrics={"score": m1, "loss": m2}, + config={"x": i}, + bucket=tmp_path / "bucket", + ) for i in [0, -1, 2, -3, 4, -5, 6, -7, 8, -9] ] history = History() for trial in trials: - with trial.begin(): - x = trial.config["x"] - report = trial.success(loss=x, score=x) - history.add(report) + x = trial.config["x"] + report = trial.success(loss=x, score=x) + history.add(report) hist_1 = history.incumbents("loss", ffill=True) expected_1 = [0, -1, -1, -3, -3, -5, -5, -7, -7, -9] - assert [r.metrics["loss"] for r in hist_1] == expected_1 + assert [r.values["loss"] for r in hist_1] == expected_1 hist_2 = history.incumbents("loss", ffill=False) expected_2 = [0, -1, -3, -5, -7, -9] - assert [r.metrics["loss"] for r in hist_2] == expected_2 + assert [r.values["loss"] for r in hist_2] == expected_2 hist_3 = history.incumbents("score", ffill=True) expected_3 = [0, 0, 2, 2, 4, 4, 6, 6, 8, 8] - assert [r.metrics["score"] for r in hist_3] == expected_3 + assert [r.values["score"] for r in hist_3] == expected_3 hist_4 = history.incumbents("score", ffill=False) expected_4 = [0, 2, 4, 6, 8] - assert [r.metrics["score"] for r in hist_4] == expected_4 + assert [r.values["score"] for r in hist_4] == expected_4 @parametrize_with_cases("reports", cases=".") diff --git a/tests/optimizers/test_metric.py b/tests/optimizers/test_metric.py index c2bdadc1..7b4df3d7 100644 --- a/tests/optimizers/test_metric.py +++ b/tests/optimizers/test_metric.py @@ -2,6 +2,7 @@ from dataclasses import dataclass +import pytest from pytest_cases import case, parametrize_with_cases from amltk.optimization.metric import Metric @@ -12,10 +13,11 @@ class MetricTest: """A test case for a metric.""" metric: Metric - v: Metric.Value + value: float expected_loss: float expected_distance_from_optimal: float | None expected_score: float + expected_normalized_loss: float expected_str: str @@ -24,9 +26,10 @@ def case_metric_score_bounded() -> MetricTest: metric = Metric("score_bounded", minimize=False, bounds=(0, 1)) return MetricTest( metric=metric, - v=metric(0.3), + value=0.3, expected_loss=-0.3, expected_distance_from_optimal=0.7, + expected_normalized_loss=0.7, expected_score=0.3, expected_str="score_bounded [0.0, 1.0] (maximize)", ) @@ -37,9 +40,10 @@ def case_metric_score_unbounded() -> MetricTest: metric = Metric("score_unbounded", minimize=False) return MetricTest( metric=metric, - v=metric(0.3), + value=0.3, expected_loss=-0.3, expected_distance_from_optimal=None, + expected_normalized_loss=-0.3, expected_score=0.3, expected_str="score_unbounded (maximize)", ) @@ -50,9 +54,10 @@ def case_metric_loss_unbounded() -> MetricTest: metric = Metric("loss_unbounded", minimize=True) return MetricTest( metric=metric, - v=metric(0.8), + value=0.8, expected_loss=0.8, 
expected_distance_from_optimal=None, + expected_normalized_loss=0.8, expected_score=-0.8, expected_str="loss_unbounded (minimize)", ) @@ -63,9 +68,10 @@ def case_metric_loss_bounded() -> MetricTest: metric = Metric("loss_bounded", minimize=True, bounds=(-1, 1)) return MetricTest( metric=metric, - v=metric(0.8), + value=0.8, expected_loss=0.8, expected_distance_from_optimal=1.8, + expected_normalized_loss=0.9, expected_score=-0.8, expected_str="loss_bounded [-1.0, 1.0] (minimize)", ) @@ -73,27 +79,28 @@ def case_metric_loss_bounded() -> MetricTest: @parametrize_with_cases(argnames="C", cases=".") def test_metrics_have_expected_outputs(C: MetricTest) -> None: - assert C.v.loss == C.expected_loss - assert C.v.distance_to_optimal == C.expected_distance_from_optimal - assert C.v.score == C.expected_score + assert C.metric.loss(C.value) == C.expected_loss + if C.expected_distance_from_optimal is not None: + assert C.metric.distance_to_optimal(C.value) == C.expected_distance_from_optimal + assert C.metric.score(C.value) == C.expected_score assert str(C.metric) == C.expected_str @parametrize_with_cases(argnames="C", cases=".", has_tag=["maximize"]) def test_metric_value_is_score_if_maximize(C: MetricTest) -> None: - assert C.v.value == C.v.score - assert C.v.value == -C.v.loss + assert C.value == C.metric.score(C.value) + assert C.value == -C.metric.loss(C.value) @parametrize_with_cases(argnames="C", cases=".", has_tag=["minimize"]) def test_metric_value_is_loss_if_minimize(C: MetricTest) -> None: - assert C.v.value == C.v.loss - assert C.v.value == -C.v.score + assert C.value == C.metric.loss(C.value) + assert C.value == -C.metric.score(C.value) @parametrize_with_cases(argnames="C", cases=".") def test_metric_value_score_is_just_loss_inverted(C: MetricTest) -> None: - assert C.v.score == -C.v.loss + assert C.metric.score(C.value) == -C.metric.loss(C.value) @parametrize_with_cases(argnames="C", cases=".", has_tag=["minimize", "unbounded"]) @@ -123,14 +130,33 @@ def test_maximize_metric_worst_optimal_if_bounded(C: MetricTest) -> None: @parametrize_with_cases(argnames="C", cases=".", has_tag=["unbounded"]) -def test_distance_to_optimal_is_none_for_unbounded(C: MetricTest) -> None: - assert C.v.distance_to_optimal is None +def test_distance_to_optimal_is_raises_for_unbounded(C: MetricTest) -> None: + with pytest.raises(ValueError, match="unbounded"): + C.metric.distance_to_optimal(C.value) @parametrize_with_cases(argnames="C", cases=".", has_tag=["bounded"]) def test_distance_to_optimal_is_always_positive_for_bounded(C: MetricTest) -> None: - assert C.v.distance_to_optimal - assert C.v.distance_to_optimal >= 0 + assert C.metric.distance_to_optimal(C.value) >= 0 + + +@parametrize_with_cases(argnames="C", cases=".") +def test_normalized_loss(C: MetricTest) -> None: + assert C.metric.normalized_loss(C.value) == C.expected_normalized_loss + + +@parametrize_with_cases(argnames="C", cases=".", has_tag=["bounded"]) +def test_normalized_loss_for_bounded(C: MetricTest) -> None: + assert 0 <= C.metric.normalized_loss(C.value) <= 1 + assert C.metric.normalized_loss(C.metric.optimal) == 0 + mid = (C.metric.optimal + C.metric.worst) / 2 + assert C.metric.normalized_loss(mid) == 0.5 + assert C.metric.normalized_loss(C.metric.worst) == 1 + + +@parametrize_with_cases(argnames="C", cases=".", has_tag=["unbounded"]) +def test_normalized_loss_for_unbounded_is_loss(C: MetricTest) -> None: + assert C.metric.normalized_loss(C.value) == C.metric.loss(C.value) @parametrize_with_cases(argnames="C", cases=".") diff --git 
a/tests/optimizers/test_optimizers.py b/tests/optimizers/test_optimizers.py index 52588aa5..47caad09 100644 --- a/tests/optimizers/test_optimizers.py +++ b/tests/optimizers/test_optimizers.py @@ -5,10 +5,12 @@ from typing import TYPE_CHECKING import pytest +from more_itertools import all_unique from pytest_cases import case, parametrize, parametrize_with_cases from amltk.optimization import Metric, Optimizer, Trial from amltk.pipeline import Component +from amltk.pipeline.components import Choice from amltk.profiling import Timer if TYPE_CHECKING: @@ -23,6 +25,10 @@ class _A: pass +class _B: + pass + + metrics = [ Metric("score_bounded", minimize=False, bounds=(0, 1)), Metric("score_unbounded", minimize=False), @@ -33,20 +39,20 @@ class _A: def target_function(trial: Trial, err: Exception | None = None) -> Trial.Report: """A target function for testing optimizers.""" - with trial.begin(): + with trial.profile("trial"): # Do stuff with trail.info here logger.debug(trial.info) if err is not None: - raise err + return trial.fail(err) return trial.success( - **{metric.name: metric.optimal.value for metric in trial.metrics}, + **{ + metric_name: metric.optimal + for metric_name, metric in trial.metrics.items() + }, ) - # Should fill in metric.worst here - return trial.fail() # pyright: ignore - def valid_time_interval(interval: Timer.Interval) -> bool: """Check if the start and end time are valid.""" @@ -60,7 +66,7 @@ def opt_smac_hpo(metric: Metric, tmp_path: Path) -> SMACOptimizer: except ImportError: pytest.skip("SMAC is not installed") - pipeline = Component(_A, name="hi", space={"a": (1, 10)}) + pipeline = Component(_A, name="hi", space={"a": (1.0, 10.0)}) return SMACOptimizer.create( space=pipeline, bucket=tmp_path, @@ -86,6 +92,25 @@ def opt_optuna(metric: Metric, tmp_path: Path) -> OptunaOptimizer: ) +@case +@parametrize("metric", [*metrics, metrics]) # Single obj and multi +def opt_optuna_choice_hierarchical(metric: Metric, tmp_path: Path) -> OptunaOptimizer: + try: + from amltk.optimization.optimizers.optuna import OptunaOptimizer + except ImportError: + pytest.skip("Optuna is not installed") + + c1 = Component(_A, name="hi1", space={"a": [1, 2, 3]}) + c2 = Component(_B, name="hi2", space={"b": [4, 5, 6]}) + pipeline = Choice(c1, c2, name="hi") + return OptunaOptimizer.create( + space=pipeline, + metrics=metric, + seed=42, + bucket=tmp_path, + ) + + @case @parametrize("metric", [*metrics]) # Single obj def opt_neps(metric: Metric, tmp_path: Path) -> NEPSOptimizer: @@ -112,9 +137,11 @@ def test_report_success(optimizer: Optimizer) -> None: optimizer.tell(report) assert report.status == Trial.Status.SUCCESS - assert valid_time_interval(report.time) + assert valid_time_interval(report.profiles["trial"].time) assert report.trial.info is trial.info - assert report.metric_values == tuple(metric.optimal for metric in optimizer.metrics) + assert report.values == { + name: metric.optimal for name, metric in optimizer.metrics.items() + } @parametrize_with_cases("optimizer", cases=".", prefix="opt_") @@ -124,7 +151,41 @@ def test_report_failure(optimizer: Optimizer): optimizer.tell(report) assert report.status is Trial.Status.FAIL - assert valid_time_interval(report.time) + assert valid_time_interval(report.profiles["trial"].time) assert isinstance(report.exception, ValueError) assert isinstance(report.traceback, str) - assert report.metric_values == tuple(metric.worst for metric in optimizer.metrics) + assert report.values == {} + + +@parametrize_with_cases("optimizer", cases=".", 
prefix="opt_") +def test_batched_ask_generates_unique_configs(optimizer: Optimizer): + """Test that batched ask generates unique configs.""" + # NOTE: This was tested with up to 100, at least from SMAC and Optuna. + # It was quite slow for smac so I've reduced it to 10. + # This is not a hard requirement of optimizers (maybe it should be?) + batch = list(optimizer.ask(10)) + assert len(batch) == 10 + assert all_unique(batch) + + +@parametrize_with_cases("optimizer", cases=".", prefix="opt_optuna_choice") +def test_optuna_choice_output(optimizer: Optimizer): + trial = optimizer.ask() + keys = list(trial.config.keys()) + assert any("__choice__" in k for k in keys), trial.config + + +@parametrize_with_cases("optimizer", cases=".", prefix="opt_optuna_choice") +def test_optuna_choice_no_params_left(optimizer: Optimizer): + trial = optimizer.ask() + keys_without_choices = [ + k for k in list(trial.config.keys()) if "__choice__" not in k + ] + for k, v in trial.config.items(): + if "__choice__" in k: + name_without_choice = k.removesuffix("__choice__") + params_for_choice = [ + k for k in keys_without_choices if k.startswith(name_without_choice) + ] + # Check that only params for the chosen choice are left + assert all(v in k for k in params_for_choice), params_for_choice diff --git a/tests/pipeline/parsing/test_optuna_parser.py b/tests/pipeline/parsing/test_optuna_parser.py index ba098ff7..789ab549 100644 --- a/tests/pipeline/parsing/test_optuna_parser.py +++ b/tests/pipeline/parsing/test_optuna_parser.py @@ -1 +1,219 @@ # TODO: Fill this in +from __future__ import annotations + +from dataclasses import dataclass + +import pytest +from pytest_cases import case, parametrize_with_cases + +from amltk.pipeline import Component, Fixed, Node +from amltk.pipeline.components import Choice, Split + +try: + from optuna.distributions import CategoricalDistribution, IntDistribution + + from amltk.pipeline.parsers.optuna import OptunaSearchSpace +except ImportError: + pytest.skip("Optuna not installed", allow_module_level=True) + + +FLAT = True +NOT_FLAT = False +CONDITIONED = True +NOT_CONDITIONED = False + + +@dataclass +class Params: + """A test case for parsing a Node into a ConfigurationSpace.""" + + root: Node + expected: dict[tuple[bool, bool], OptunaSearchSpace] + + +@case +def case_single_frozen() -> Params: + item = Fixed(object(), name="a") + space = OptunaSearchSpace() + expected = { + (NOT_FLAT, CONDITIONED): space, + (NOT_FLAT, NOT_CONDITIONED): space, + (FLAT, CONDITIONED): space, + (FLAT, NOT_CONDITIONED): space, + } + return Params(item, expected) # type: ignore + + +@case +def case_single_component() -> Params: + item = Component(object, name="a", space={"hp": [1, 2, 3]}) + space = OptunaSearchSpace({"a:hp": CategoricalDistribution([1, 2, 3])}) + expected = { + (NOT_FLAT, CONDITIONED): space, + (NOT_FLAT, NOT_CONDITIONED): space, + (FLAT, CONDITIONED): space, + (FLAT, NOT_CONDITIONED): space, + } + return Params(item, expected) # type: ignore + + +@case +def case_single_step_two_hp() -> Params: + item = Component(object, name="a", space={"hp": [1, 2, 3], "hp2": [1, 2, 3]}) + space = OptunaSearchSpace( + { + "a:hp": CategoricalDistribution([1, 2, 3]), + "a:hp2": CategoricalDistribution([1, 2, 3]), + }, + ) + + expected = { + (NOT_FLAT, CONDITIONED): space, + (NOT_FLAT, NOT_CONDITIONED): space, + (FLAT, CONDITIONED): space, + (FLAT, NOT_CONDITIONED): space, + } + return Params(item, expected) # type: ignore + + +@case +def case_single_step_two_hp_different_types() -> Params: + item = 
Component(object, name="a", space={"hp": [1, 2, 3], "hp2": (1, 10)}) + space = OptunaSearchSpace( + {"a:hp": CategoricalDistribution([1, 2, 3]), "a:hp2": IntDistribution(1, 10)}, + ) + expected = { + (NOT_FLAT, CONDITIONED): space, + (NOT_FLAT, NOT_CONDITIONED): space, + (FLAT, CONDITIONED): space, + (FLAT, NOT_CONDITIONED): space, + } + return Params(item, expected) # type: ignore + + +@case +def case_choice() -> Params: + item = Choice( + Component(object, name="a", space={"hp": [1, 2, 3]}), + Component(object, name="b", space={"hp2": (1, 10)}), + name="choice1", + space={"hp3": (1, 10)}, + ) + + expected = {} + + # Not Flat and without conditions + space = OptunaSearchSpace( + { + "choice1:a:hp": CategoricalDistribution([1, 2, 3]), + "choice1:b:hp2": IntDistribution(1, 10), + "choice1:hp3": IntDistribution(1, 10), + "choice1:__choice__": CategoricalDistribution(["a", "b"]), + }, + ) + expected[(NOT_FLAT, NOT_CONDITIONED)] = space + + # Flat and without conditions + space = OptunaSearchSpace( + { + "a:hp": CategoricalDistribution([1, 2, 3]), + "b:hp2": IntDistribution(1, 10), + "choice1:hp3": IntDistribution(1, 10), + "choice1:__choice__": CategoricalDistribution(["a", "b"]), + }, + ) + expected[(FLAT, NOT_CONDITIONED)] = space + return Params(item, expected) # type: ignore + + +@case +def case_nested_choices_with_split_and_choice() -> Params: + item = Choice( + Split( + Choice( + Component(object, name="a", space={"hp": [1, 2, 3]}), + Component(object, name="b", space={"hp2": (1, 10)}), + name="choice3", + ), + Component(object, name="c", space={"hp3": (1, 10)}), + name="split2", + ), + Component(object, name="d", space={"hp4": (1, 10)}), + name="choice1", + ) + expected = {} + + # Not flat and without conditions + space = OptunaSearchSpace( + { + "choice1:split2:choice3:a:hp": CategoricalDistribution([1, 2, 3]), + "choice1:split2:choice3:b:hp2": IntDistribution(1, 10), + "choice1:split2:c:hp3": IntDistribution(1, 10), + "choice1:d:hp4": IntDistribution(1, 10), + "choice1:__choice__": CategoricalDistribution(["d", "split2"]), + "choice1:split2:choice3:__choice__": CategoricalDistribution(["a", "b"]), + }, + ) + expected[(NOT_FLAT, NOT_CONDITIONED)] = space + + # Flat and without conditions + space = OptunaSearchSpace( + { + "a:hp": CategoricalDistribution([1, 2, 3]), + "b:hp2": IntDistribution(1, 10), + "c:hp3": IntDistribution(1, 10), + "d:hp4": IntDistribution(1, 10), + "choice1:__choice__": CategoricalDistribution(["d", "split2"]), + "choice3:__choice__": CategoricalDistribution(["a", "b"]), + }, + ) + expected[(FLAT, NOT_CONDITIONED)] = space + return Params(item, expected) + + +@parametrize_with_cases("test_case", cases=".") +def test_parsing_pipeline(test_case: Params) -> None: + pipeline = test_case.root + + for (flat, conditioned), expected in test_case.expected.items(): + parsed_space = pipeline.search_space( + "optuna", + flat=flat, + conditionals=conditioned, + ) + assert ( + parsed_space == expected + ), f"Failed for {flat=}, {conditioned=}.\n{parsed_space}\n{expected}" + + +@parametrize_with_cases("test_case", cases=".") +def test_parsing_does_not_mutate_space_of_nodes(test_case: Params) -> None: + pipeline = test_case.root + spaces_before = {tuple(path): step.space for path, step in pipeline.walk()} + + for (flat, conditioned), _ in test_case.expected.items(): + pipeline.search_space( + "optuna", + flat=flat, + conditionals=conditioned, + ) + spaces_after = {tuple(path): step.space for path, step in pipeline.walk()} + assert spaces_before == spaces_after + + 
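Distilled from the cases above, a sketch of how the two parser knobs behave for a simple `Choice`; the commented output is the expected shape rather than a verbatim repr:

```python
from amltk.pipeline import Component
from amltk.pipeline.components import Choice

pipeline = Choice(
    Component(object, name="a", space={"hp": [1, 2, 3]}),
    Component(object, name="b", space={"hp2": (1, 10)}),
    name="choice1",
)

# Hierarchical names by default, plus one `__choice__` categorical per Choice:
space = pipeline.search_space("optuna", flat=False, conditionals=False)
# ~ {"choice1:a:hp": CategoricalDistribution([1, 2, 3]),
#    "choice1:b:hp2": IntDistribution(1, 10),
#    "choice1:__choice__": CategoricalDistribution(["a", "b"])}

# flat=True drops the parent prefix from leaf hyperparameters,
# while the Choice's own `__choice__` keeps its node name:
flat_space = pipeline.search_space("optuna", flat=True, conditionals=False)
# ~ {"a:hp": ..., "b:hp2": ..., "choice1:__choice__": ...}
print(space, flat_space)
```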
+@parametrize_with_cases("test_case", cases=".") +def test_parsing_twice_produces_same_space(test_case: Params) -> None: + pipeline = test_case.root + + for (flat, conditioned), _ in test_case.expected.items(): + parsed_space = pipeline.search_space( + "optuna", + flat=flat, + conditionals=conditioned, + ) + parsed_space2 = pipeline.search_space( + "optuna", + flat=flat, + conditionals=conditioned, + ) + assert parsed_space == parsed_space2 diff --git a/tests/pipeline/test_node.py b/tests/pipeline/test_node.py index 0295fa26..eee83fdc 100644 --- a/tests/pipeline/test_node.py +++ b/tests/pipeline/test_node.py @@ -5,7 +5,7 @@ import pytest -from amltk.exceptions import RequestNotMetError +from amltk.exceptions import DuplicateNamesError, RequestNotMetError from amltk.pipeline import Choice, Join, Node, Sequential, request @@ -172,3 +172,13 @@ def test_walk() -> None: ): assert node == _exp_node assert path == _exp_path + + +def test_node_fails_if_children_with_duplicate_name() -> None: + with pytest.raises(DuplicateNamesError): + Node(Node(name="child1"), Node(name="child1"), name="node") + + +def test_node_fails_if_child_has_same_name() -> None: + with pytest.raises(DuplicateNamesError): + Node(Node(name="child1"), Node(name="node"), name="node") diff --git a/tests/pipeline/test_optimize.py b/tests/pipeline/test_optimize.py new file mode 100644 index 00000000..84403b74 --- /dev/null +++ b/tests/pipeline/test_optimize.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +from collections.abc import Sequence +from pathlib import Path +from typing import Any +from typing_extensions import override + +import pytest +import threadpoolctl + +from amltk import Component, Metric, Node, Trial +from amltk.optimization.history import History +from amltk.optimization.optimizers.smac import SMACOptimizer +from amltk.scheduling.scheduler import Scheduler +from amltk.scheduling.task import Task +from amltk.store import PathBucket +from amltk.types import Seed + +METRIC = Metric("acc", minimize=False, bounds=(0.0, 1.0)) + + +class _CustomError(Exception): + pass + + +def target_funtion(trial: Trial, pipeline: Node) -> Trial.Report: # noqa: ARG001 + # We don't really care here + threadpool_info = threadpoolctl.threadpool_info() + trial.summary["num_threads"] = threadpool_info[0]["num_threads"] + return trial.success(acc=0.5) + + +def test_custom_callback_used(tmp_path: Path) -> None: + def my_callback(task: Task, scheduler: Scheduler, history: History) -> None: # noqa: ARG001 + raise _CustomError() + + component = Component(object, space={"a": (0.0, 1.0)}) + + with pytest.raises(_CustomError): + component.optimize( + target_funtion, + metric=METRIC, + on_begin=my_callback, + max_trials=1, + working_dir=tmp_path, + ) + + +def test_populates_given_history(tmp_path: Path) -> None: + history = History() + component = Component(object, space={"a": (0.0, 1.0)}) + trial = Trial.create( + name="test_trial", + config={}, + bucket=PathBucket(tmp_path) / "trial", + ) + report = trial.success() + history.add(report) + + component.optimize( + target_funtion, + metric=METRIC, + history=history, + max_trials=1, + working_dir=tmp_path, + ) + + +def test_custom_create_optimizer_signature(tmp_path: Path) -> None: + component = Component(object, space={"a": (0.0, 1.0)}) + + def my_custom_optimizer_creator( + *, + space: Node, + metrics: Metric | Sequence[Metric], + bucket: PathBucket | None = None, + seed: Seed | None = None, + ) -> SMACOptimizer: + assert space is component + assert metrics is METRIC + assert bucket is 
not None + assert bucket.path == tmp_path + assert seed == 1 + + raise _CustomError() + + with pytest.raises(_CustomError): + component.optimize( + target_funtion, + metric=METRIC, + optimizer=my_custom_optimizer_creator, + max_trials=1, + seed=1, + working_dir=tmp_path, + ) + + +def test_history_populated_with_exactly_maximum_trials(tmp_path: Path) -> None: + component = Component(object, space={"a": (0.0, 1.0)}) + history = component.optimize( + target_funtion, + metric=METRIC, + max_trials=10, + working_dir=tmp_path, + ) + assert len(history) == 10 + + +def test_sklearn_head_triggers_triggers_threadpoolctl(tmp_path: Path) -> None: + from sklearn.ensemble import RandomForestClassifier + + info = threadpoolctl.threadpool_info() + num_threads = info[0]["num_threads"] + + component = Component(RandomForestClassifier, space={"a": (0.0, 1.0)}) + history = component.optimize( + target_funtion, + metric=METRIC, + max_trials=1, + working_dir=tmp_path, + ) + + report = history[0] + # Should have a different number of threads in there. By default 1 + assert report.summary["num_threads"] != num_threads + assert report.summary["num_threads"] == 1 + + +def test_no_sklearn_head_does_not_trigger_threadpoolctl(tmp_path: Path) -> None: + info = threadpoolctl.threadpool_info() + num_threads = info[0]["num_threads"] + + component = Component(object, space={"a": (0.0, 1.0)}) + history = component.optimize( + target_funtion, + metric=METRIC, + max_trials=1, + working_dir=tmp_path, + ) + + report = history[0] + assert report.summary["num_threads"] == num_threads + + +def test_optimizer_is_reported_to(tmp_path: Path) -> None: + class MyOptimizer(SMACOptimizer): + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.told_report: Trial.Report | None = None + super().__init__(*args, **kwargs) + + @override + def tell(self, report: Trial.Report) -> None: + self.told_report = report + return super().tell(report) + + component = Component(object, space={"a": (0.0, 1.0)}) + optimizer = MyOptimizer.create( + space=component, + metrics=METRIC, + bucket=PathBucket(tmp_path), + ) + + history = component.optimize( + target_funtion, + metric=METRIC, + optimizer=optimizer, + max_trials=1, + working_dir=tmp_path, + ) + + assert optimizer.told_report is history[0] diff --git a/tests/pytorch/__init__.py b/tests/pytorch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pytorch/common.py b/tests/pytorch/common.py new file mode 100644 index 00000000..2606bd3d --- /dev/null +++ b/tests/pytorch/common.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from amltk import Metric, Node +from amltk.optimization.optimizers.smac import SMACOptimizer + + +def create_optimizer(pipeline: Node) -> SMACOptimizer: + """Create optimizer for the given pipeline.""" + metric = Metric("accuracy", minimize=False, bounds=(0, 1)) + return SMACOptimizer.create( + space=pipeline, + metrics=metric, + seed=1, + bucket="pytorch-experiments", + ) diff --git a/tests/pytorch/test_build_model_from_pipeline.py b/tests/pytorch/test_build_model_from_pipeline.py new file mode 100644 index 00000000..4693a364 --- /dev/null +++ b/tests/pytorch/test_build_model_from_pipeline.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import torch + +from amltk import Component, Fixed, Sequential +from amltk.pytorch import build_model_from_pipeline + + +def test_build_model_from_pipeline(): + # Define a simple pipeline for a multi-layer perceptron + pipeline = Sequential( + torch.nn.Flatten(start_dim=1), + Component( 
+ torch.nn.Linear, + config={"in_features": 784, "out_features": 100}, + name="fc1", + ), + Fixed(torch.nn.ReLU(), name="activation"), + Component( + torch.nn.Linear, + config={"in_features": 100, "out_features": 10}, + name="fc2", + ), + torch.nn.LogSoftmax(dim=1), + name="my-mlp-pipeline", + ) + + # Build the model from the pipeline + model = build_model_from_pipeline(pipeline) + + # Verify that the model is constructed correctly + assert isinstance(model, torch.nn.Sequential) + assert len(model) == 5 # Check the number of layers in the model + + # Check the layer types and dimensions + assert isinstance(model[0], torch.nn.Flatten) + assert isinstance(model[1], torch.nn.Linear) + assert model[1].in_features == 784 + assert model[1].out_features == 100 + + assert isinstance(model[2], torch.nn.ReLU) + + assert isinstance(model[3], torch.nn.Linear) + assert model[3].in_features == 100 + assert model[3].out_features == 10 + + assert isinstance(model[4], torch.nn.LogSoftmax) + assert model[4].dim == 1 diff --git a/tests/pytorch/test_match_chosen_dimensions.py b/tests/pytorch/test_match_chosen_dimensions.py new file mode 100644 index 00000000..a41b0dd7 --- /dev/null +++ b/tests/pytorch/test_match_chosen_dimensions.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from collections.abc import Callable + +import pytest +import torch + +from amltk import ( + Choice, + Component, + Fixed, + Sequential, +) +from amltk.exceptions import MatchChosenDimensionsError +from amltk.pytorch import MatchChosenDimensions, build_model_from_pipeline +from tests.pytorch.common import create_optimizer + + +class TestMatchChosenDimensions: + @pytest.fixture(scope="class") + def common_pipeline(self) -> Callable[..., Sequential]: + def _create_pipeline(choices: dict) -> Sequential: + # Define a pipeline with a Choice class + return Sequential( + Choice( + Sequential( + torch.nn.Linear(in_features=10, out_features=20), + name="choice1", + ), + Sequential( + torch.nn.Linear(in_features=5, out_features=10), + name="choice2", + ), + name="my_choice", + ), + Component( + torch.nn.Linear, + config={ + "in_features": MatchChosenDimensions( + choice_name="my_choice", + choices=choices, + ), + "out_features": 30, + }, + name="fc1", + ), + Choice(torch.nn.ReLU(), torch.nn.Sigmoid(), name="activation"), + Component( + torch.nn.Linear, + config={"in_features": 30, "out_features": 10}, + name="fc2", + ), + Fixed(torch.nn.LogSoftmax(dim=1), name="log_softmax"), + name="my-pipeline", + ) + + return _create_pipeline + + def test_valid_pipeline(self, common_pipeline: Callable[..., Sequential]) -> None: + valid_pipeline = common_pipeline(choices={"choice1": 20, "choice2": 10}) + + optimizer = create_optimizer(valid_pipeline) + trial = optimizer.ask() + model = valid_pipeline.configure(trial.config).build( + builder=build_model_from_pipeline, + ) + + # Verify that the model is constructed correctly + assert isinstance(model, torch.nn.Sequential) + + # Conditional check for the Choice node + assert model[0].out_features == model[1].in_features + + assert model[1].out_features == 30 + + assert isinstance(model[2], torch.nn.ReLU | torch.nn.Sigmoid) + + assert isinstance(model[4], torch.nn.LogSoftmax) + assert model[4].dim == 1 + + def test_invalid_pipeline(self, common_pipeline: Callable[..., Sequential]) -> None: + # Modify the common pipeline to create a pipeline with invalid choices + invalid_pipeline = common_pipeline(choices={"choice123": 123, "choice321": 321}) + + optimizer = create_optimizer(invalid_pipeline) + trial = 
optimizer.ask() + + with pytest.raises(MatchChosenDimensionsError): + invalid_pipeline.configure(trial.config).build( + builder=build_model_from_pipeline, + ) diff --git a/tests/pytorch/test_match_dimensions.py b/tests/pytorch/test_match_dimensions.py new file mode 100644 index 00000000..98372d60 --- /dev/null +++ b/tests/pytorch/test_match_dimensions.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import pytest +import torch + +from amltk import Component, Fixed, Node, Sequential +from amltk.exceptions import MatchDimensionsError +from amltk.pytorch import MatchDimensions, build_model_from_pipeline +from tests.pytorch.common import create_optimizer + + +class TestMatchDimensions: + @pytest.fixture(scope="class") + def valid_pipeline(self) -> Node: + """Fixture to create a valid pipeline.""" + return Sequential( + Component( + torch.nn.Linear, + config={"in_features": 784}, + space={"out_features": (64, 128)}, + name="fc1", + ), + Fixed(torch.nn.Sigmoid(), name="activation"), + Component( + torch.nn.Linear, + config={ + "in_features": MatchDimensions("fc1", param="out_features"), + "out_features": MatchDimensions("fc3", param="in_features"), + }, + name="fc2", + ), + Component( + torch.nn.Linear, + space={"in_features": (128, 256)}, + config={"out_features": 10}, + name="fc3", + ), + Fixed(torch.nn.LogSoftmax(dim=1), name="log_softmax"), + name="my-pipeline", + ) + + @pytest.fixture( + scope="class", + params=[ + MatchDimensions("non-existing-layer", param="out_features"), + MatchDimensions("fc1", param="non-existing-param"), + MatchDimensions("fc1", param=None), # type: ignore + MatchDimensions(layer_name="", param="out_features"), + ], + ) + def invalid_pipeline(self, request: pytest.FixtureRequest) -> Node: + """Fixture to create several invalid pipelines.""" + return Sequential( + torch.nn.Flatten(start_dim=1), + Component( + torch.nn.Linear, + config={"in_features": 784}, + space={"out_features": (64, 128)}, + name="fc1", + ), + Fixed(torch.nn.Sigmoid(), name="activation"), + Component( + torch.nn.Linear, + config={ + "in_features": request.param, + "out_features": 10, + }, + name="fc2", + ), + Fixed(torch.nn.LogSoftmax(dim=1), name="log_softmax"), + name="my-pipeline", + ) + + def test_match_dimensions_valid(self, valid_pipeline: Node) -> None: + """Test for valid pipeline.""" + optimizer = create_optimizer(valid_pipeline) + trial = optimizer.ask() + model = valid_pipeline.configure(trial.config).build( + builder=build_model_from_pipeline, + ) + + assert isinstance(model, torch.nn.Sequential) + assert len(model) == 5 + + assert isinstance(model[0], torch.nn.Linear) + assert model[0].in_features == 784 + + assert isinstance(model[1], torch.nn.Sigmoid) + + assert isinstance(model[2], torch.nn.Linear) + assert model[2].in_features == model[0].out_features + assert model[2].out_features == model[3].in_features + + assert isinstance(model[3], torch.nn.Linear) + assert model[3].out_features == 10 + + def test_match_dimensions_invalid(self, invalid_pipeline: Node) -> None: + """Test for invalid pipeline.""" + optimizer = create_optimizer(invalid_pipeline) + trial = optimizer.ask() + + with pytest.raises((MatchDimensionsError, KeyError)): + invalid_pipeline.configure(trial.config).build( + builder=build_model_from_pipeline, + ) diff --git a/tests/scheduling/plugins/test_call_limiter_plugin.py b/tests/scheduling/plugins/test_call_limiter_plugin.py index e0dfeb47..c63628bf 100644 --- a/tests/scheduling/plugins/test_call_limiter_plugin.py +++ 
b/tests/scheduling/plugins/test_call_limiter_plugin.py @@ -55,6 +55,8 @@ def case_loky_executor() -> ProcessPoolExecutor: @parametrize_with_cases("executor", cases=".", has_tag="executor") def scheduler(executor: Executor) -> Iterator[Scheduler]: yield Scheduler(executor) + if isinstance(executor, ClientExecutor): + executor._client.close() def time_wasting_function(duration: int) -> int: diff --git a/tests/scheduling/plugins/test_comm_plugin.py b/tests/scheduling/plugins/test_comm_plugin.py index 75f16e66..334159b3 100644 --- a/tests/scheduling/plugins/test_comm_plugin.py +++ b/tests/scheduling/plugins/test_comm_plugin.py @@ -18,19 +18,20 @@ logger = logging.getLogger(__name__) -def sending_worker(comm: Comm, replies: list[Any]) -> None: +def sending_worker(replies: list[Any], *, comm: Comm | None = None) -> None: """A worker that responds to messages. Args: comm: The communication channel to use. replies: A list of replies to send to the client. """ + assert comm is not None with comm.open(): for reply in replies: comm.send(reply) -def requesting_worker(comm: Comm, requests: list[Any]) -> None: +def requesting_worker(requests: list[Any], *, comm: Comm | None = None) -> None: """A worker that waits for messages. This will send a request, waiting for a response, finally @@ -41,6 +42,7 @@ def requesting_worker(comm: Comm, requests: list[Any]) -> None: comm: The communication channel to use. requests: A list of requests to receive from the client. """ + assert comm is not None with comm.open(): for request in requests: response = comm.request(request) @@ -90,6 +92,8 @@ def case_dask_executor() -> ClientExecutor: @parametrize_with_cases("executor", cases=".", has_tag="executor") def scheduler(executor: Executor) -> Iterator[Scheduler]: yield Scheduler(executor) + if isinstance(executor, ClientExecutor): + executor._client.close() def test_sending_worker(scheduler: Scheduler) -> None: diff --git a/tests/scheduling/plugins/test_pynisher_plugin.py b/tests/scheduling/plugins/test_pynisher_plugin.py index 91156285..00aab12e 100644 --- a/tests/scheduling/plugins/test_pynisher_plugin.py +++ b/tests/scheduling/plugins/test_pynisher_plugin.py @@ -6,6 +6,7 @@ from collections import Counter from collections.abc import Iterator from concurrent.futures import Executor, ProcessPoolExecutor +from pathlib import Path import pytest from dask.distributed import Client, LocalCluster, Worker @@ -52,6 +53,8 @@ def case_dask_executor() -> ClientExecutor: @parametrize_with_cases("executor", cases=".", has_tag="executor") def scheduler(executor: Executor) -> Iterator[Scheduler]: yield Scheduler(executor) + if isinstance(executor, ClientExecutor): + executor._client.close() def big_memory_function(mem_in_bytes: int) -> bytearray: @@ -60,20 +63,11 @@ def big_memory_function(mem_in_bytes: int) -> bytearray: def trial_with_big_memory(trial: Trial, mem_in_bytes: int) -> Trial.Report: - with trial.begin(): - pass - - # We're particularly interested when the memory error happens during the - # task execution, not during the trial begin period big_memory_function(mem_in_bytes) - return trial.success() def trial_with_time_wasting(trial: Trial, duration: int) -> Trial.Report: - with trial.begin(): - time_wasting_function(duration) - time_wasting_function(duration) return trial.success() @@ -275,7 +269,7 @@ def start_task() -> None: assert isinstance(end_status.exception, PynisherPlugin.WallTimeoutException) -def test_trial_gets_autodetect_memory(scheduler: Scheduler) -> None: +def test_trial_gets_autodetect_memory(scheduler: 
Scheduler, tmp_path: Path) -> None: if not PynisherPlugin.supports("memory"): pytest.skip("Pynisher does not support memory limits on this system") @@ -288,7 +282,7 @@ def test_trial_gets_autodetect_memory(scheduler: Scheduler) -> None: disable_trial_handling=False, ), ) - trial = Trial(name="test_trial", config={}) + trial = Trial.create(name="test_trial", config={}, bucket=tmp_path / "trial") @scheduler.on_start def start_task() -> None: @@ -323,7 +317,7 @@ def trial_report(_, report: Trial.Report) -> None: assert isinstance(reports[0].exception, PynisherPlugin.MemoryLimitException) -def test_trial_gets_autodetect_time(scheduler: Scheduler) -> None: +def test_trial_gets_autodetect_time(scheduler: Scheduler, tmp_path: Path) -> None: if not PynisherPlugin.supports("wall_time"): pytest.skip("Pynisher does not support wall_time limits on this system") @@ -334,7 +328,7 @@ def test_trial_gets_autodetect_time(scheduler: Scheduler) -> None: disable_trial_handling=False, ), ) - trial = Trial(name="test_trial", config={}) + trial = Trial.create(name="test_trial", config={}, bucket=tmp_path / "trial") @scheduler.on_start def start_task() -> None: diff --git a/tests/scheduling/plugins/test_threadpoolctl_plugin.py b/tests/scheduling/plugins/test_threadpoolctl_plugin.py index 7ae169d4..7e6925c6 100644 --- a/tests/scheduling/plugins/test_threadpoolctl_plugin.py +++ b/tests/scheduling/plugins/test_threadpoolctl_plugin.py @@ -4,6 +4,7 @@ import sys import warnings from collections import Counter +from collections.abc import Iterator from concurrent.futures import Executor, ProcessPoolExecutor from typing import Any @@ -61,8 +62,10 @@ def case_loky_executor() -> ProcessPoolExecutor: @fixture(scope="function") @parametrize_with_cases("executor", cases=".", has_tag="executor") -def scheduler(executor: Executor) -> Scheduler: - return Scheduler(executor) +def scheduler(executor: Executor) -> Iterator[Scheduler]: + yield Scheduler(executor) + if isinstance(executor, ClientExecutor): + executor._client.close() def f() -> list[Any]: diff --git a/tests/scheduling/test_scheduler.py b/tests/scheduling/test_scheduler.py index 60025ef0..22aabd46 100644 --- a/tests/scheduling/test_scheduler.py +++ b/tests/scheduling/test_scheduler.py @@ -5,6 +5,7 @@ import warnings from asyncio import Future from collections import Counter +from collections.abc import Iterator from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor from typing import TYPE_CHECKING @@ -82,8 +83,10 @@ def case_sequential_executor() -> SequentialExecutor: @fixture(scope="function") @parametrize_with_cases("executor", cases=".", has_tag="executor") -def scheduler(executor: Executor) -> Scheduler: - return Scheduler(executor) +def scheduler(executor: Executor) -> Iterator[Scheduler]: + yield Scheduler(executor) + if isinstance(executor, ClientExecutor): + executor._client.close() def test_scheduler_with_timeout_and_wait_for_tasks(scheduler: Scheduler) -> None: diff --git a/tests/sklearn/test_evaluation.py b/tests/sklearn/test_evaluation.py new file mode 100644 index 00000000..e4c59701 --- /dev/null +++ b/tests/sklearn/test_evaluation.py @@ -0,0 +1,1006 @@ +from __future__ import annotations + +import warnings +from collections.abc import Mapping +from dataclasses import dataclass +from functools import cache +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import pandas as pd +import pytest +import sklearn.datasets +import sklearn.pipeline +from pytest_cases import case, parametrize, 
parametrize_with_cases +from sklearn import config_context as sklearn_config_context +from sklearn.base import check_is_fitted +from sklearn.cluster import KMeans +from sklearn.datasets import make_classification, make_regression +from sklearn.dummy import DummyClassifier +from sklearn.metrics import get_scorer, make_scorer +from sklearn.metrics._scorer import _Scorer +from sklearn.model_selection import ( + GroupKFold, + KFold, + ShuffleSplit, + StratifiedKFold, + StratifiedShuffleSplit, +) +from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor + +from amltk.exceptions import TaskTypeWarning, TrialError +from amltk.optimization.trial import Metric, Trial +from amltk.pipeline import Component, Node, request +from amltk.pipeline.builders.sklearn import build as sklearn_pipeline_builder +from amltk.sklearn.evaluation import ( + CVEvaluation, + ImplicitMetricConversionWarning, + TaskTypeName, + _default_cv_resampler, + _default_holdout, + identify_task_type, +) + + +# NOTE: We can cache this as it doesn't get changed +@cache +def data_for_task_type(task_type: TaskTypeName) -> tuple[np.ndarray, np.ndarray]: + match task_type: + case "binary": + return make_classification( + random_state=42, + n_samples=20, + n_classes=2, + n_informative=3, + ) # type: ignore + case "multiclass": + return make_classification( + random_state=42, + n_samples=20, + n_classes=4, + n_informative=3, + ) # type: ignore + case "multilabel-indicator": + x, y = make_classification( + random_state=42, + n_samples=20, + n_classes=2, + n_informative=3, + ) + y = np.vstack([y, y]).T + return x, y # type: ignore + case "multiclass-multioutput": + x, y = make_classification( + random_state=42, + n_samples=20, + n_classes=4, + n_informative=3, + ) + y = np.vstack([y, y, y]).T + return x, y # type: ignore + case "continuous": + return make_regression(random_state=42, n_samples=20, n_targets=1) # type: ignore + case "continuous-multioutput": + return make_regression(random_state=42, n_samples=20, n_targets=2) # type: ignore + + raise ValueError(f"Unknown task type {task_type}") + + +def _sample_y(task_type: TaskTypeName) -> np.ndarray: + return data_for_task_type(task_type)[1] + + +@parametrize( + "real, task_hint, expected", + [ + ("binary", "auto", "binary"), + ("binary", "classification", "binary"), + ("binary", "regression", "continuous"), + # + ("multiclass", "auto", "multiclass"), + ("multiclass", "classification", "multiclass"), + ("multiclass", "regression", "continuous"), + # + ("multilabel-indicator", "auto", "multilabel-indicator"), + ("multilabel-indicator", "classification", "multilabel-indicator"), + ("multilabel-indicator", "regression", "continuous-multioutput"), + # + ("multiclass-multioutput", "auto", "multiclass-multioutput"), + ("multiclass-multioutput", "classification", "multiclass-multioutput"), + ("multiclass-multioutput", "regression", "continuous-multioutput"), + # + ("continuous", "auto", "continuous"), + ("continuous", "classification", "multiclass"), + ("continuous", "regression", "continuous"), + # + ("continuous-multioutput", "auto", "continuous-multioutput"), + ("continuous-multioutput", "classification", "multiclass-multioutput"), + ("continuous-multioutput", "regression", "continuous-multioutput"), + ], +) +def test_identify_task_type( + real: TaskTypeName, + task_hint: Literal["classification", "regression", "auto"], + expected: TaskTypeName, +) -> None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=TaskTypeWarning) + + if real == 
"continuous-multioutput" and task_hint == "classification": + # Special case since we have to check when the y with multiple values. + A, B, C = 0.1, 2.1, 3.1 + + # <=2 unique values per column + y = np.array( + [ + [A, B], + [A, B], + [A, C], + ], + ) + + identified = identify_task_type(y, task_hint=task_hint) + assert identified == "multilabel-indicator" + + # >2 unique values per column + y = np.array( + [ + [A, A], + [B, B], + [C, C], + ], + ) + identified = identify_task_type(y, task_hint=task_hint) + assert identified == "multiclass-multioutput" + + elif real == "continuous" and task_hint == "classification": + # Special case since we have to check when the y with multiple values. + A, B, C = 0.1, 2.1, 3.1 + + y = np.array([A, A, B, B]) + identified = identify_task_type(y, task_hint=task_hint) + assert identified == "binary" + + y = np.array([A, B, C]) + identified = identify_task_type(y, task_hint=task_hint) + assert identified == "multiclass" + else: + y = _sample_y(expected) + identified = identify_task_type(y, task_hint=task_hint) + + assert identified == expected + + +@parametrize( + "task_type, expected", + [ + # Holdout + ("binary", StratifiedShuffleSplit), + ("multiclass", StratifiedShuffleSplit), + ("multilabel-indicator", ShuffleSplit), + ("multiclass-multioutput", ShuffleSplit), + ("continuous", ShuffleSplit), + ("continuous-multioutput", ShuffleSplit), + ], +) +def test_default_holdout(task_type: TaskTypeName, expected: type) -> None: + sampler = _default_holdout(task_type, holdout_size=0.387, random_state=42) + assert isinstance(sampler, expected) + assert sampler.n_splits == 1 # type: ignore + assert sampler.random_state == 42 # type: ignore + assert sampler.test_size == 0.387 # type: ignore + + +@parametrize( + "task_type, expected", + [ + # CV - Notable, only binary and multiclass can be stratified + ("binary", StratifiedKFold), + ("multiclass", StratifiedKFold), + ("multilabel-indicator", KFold), + ("multiclass-multioutput", KFold), + ("continuous", KFold), + ("continuous-multioutput", KFold), + ], +) +def test_default_resampling(task_type: TaskTypeName, expected: type) -> None: + sampler = _default_cv_resampler(task_type, n_splits=2, random_state=42) + assert isinstance(sampler, expected) + assert sampler.n_splits == 2 # type: ignore + assert sampler.random_state == 42 # type: ignore + + +@dataclass +class _EvalKwargs: + trial: Trial + pipeline: Component + additional_scorers: Mapping[str, _Scorer] | None + params: Mapping[str, Any] | None + task_type: TaskTypeName + working_dir: Path + X: pd.DataFrame | np.ndarray + y: pd.Series | np.ndarray | pd.DataFrame + + +@case +@parametrize( + "metric", + [ + # Single ob + Metric("accuracy", minimize=False, bounds=(0, 1)), + # Mutli obj + [ + Metric("custom", minimize=False, bounds=(0, 1), fn=get_scorer("accuracy")), + Metric("roc_auc_ovr", minimize=False, bounds=(0, 1)), + ], + ], +) +@parametrize( + "additional_scorers", + [ + None, + {"acc": get_scorer("accuracy"), "roc": get_scorer("roc_auc_ovr")}, + ], +) +@parametrize( + "task_type", + ["binary", "multiclass", "multilabel-indicator"], +) +def case_classification( + tmp_path: Path, + metric: Metric | list[Metric], + additional_scorers: Mapping[str, _Scorer] | None, + task_type: TaskTypeName, +) -> _EvalKwargs: + x, y = data_for_task_type(task_type) + return _EvalKwargs( + trial=Trial.create( + name="test", + config={}, + seed=42, + bucket=tmp_path / "trial", + metrics=metric, + ), + task_type=task_type, + pipeline=Component(DecisionTreeClassifier, config={"max_depth": 
1}), + additional_scorers=additional_scorers, + params=None, + working_dir=tmp_path / "data", + X=x, + y=y, + ) + + +@case +@parametrize( + "metric", + [ + # Single ob + Metric("neg_mean_absolute_error", minimize=True, bounds=(-np.inf, 0)), + # Mutli obj + [ + Metric("custom", minimize=False, bounds=(-np.inf, 1), fn=get_scorer("r2")), + Metric("neg_mean_squared_error", minimize=False, bounds=(-np.inf, 0)), + ], + ], +) +@parametrize( + "additional_scorers", + [ + None, + { + "rmse": get_scorer("neg_root_mean_squared_error"), + "err": get_scorer("neg_mean_absolute_error"), + }, + ], +) +@parametrize("task_type", ["continuous", "continuous-multioutput"]) +def case_regression( + tmp_path: Path, + metric: Metric | list[Metric], + additional_scorers: Mapping[str, _Scorer] | None, + task_type: TaskTypeName, +) -> _EvalKwargs: + x, y = data_for_task_type(task_type) + + return _EvalKwargs( + trial=Trial.create( + name="test", + config={}, + seed=42, + bucket=tmp_path / "trial", + metrics=metric, + ), + pipeline=Component(DecisionTreeRegressor, config={"max_depth": 1}), + additional_scorers=additional_scorers, + task_type=task_type, + params=None, + working_dir=tmp_path / "data", + X=x, + y=y, + ) + + +@parametrize("as_pd", [True, False]) +@parametrize("store_models", [True, False]) +@parametrize("train_score", [True, False]) +@parametrize("test_data", [True, False]) +@parametrize_with_cases("item", cases=".", prefix="case_") +@parametrize("cv_value, splitter", [(2, "cv"), (0.3, "holdout")]) +def test_evaluator( # noqa: PLR0912 + as_pd: bool, + store_models: bool, + train_score: bool, + item: _EvalKwargs, + test_data: bool, + cv_value: int | float, + splitter: str, +) -> None: + x = pd.DataFrame(item.X) if as_pd else item.X + y = ( + item.y + if not as_pd + else (pd.DataFrame(item.y) if np.ndim(item.y) > 1 else pd.Series(item.y)) + ) + trial = item.trial + if splitter == "cv": + cv_kwargs = {"n_splits": cv_value, "splitter": "cv"} + else: + cv_kwargs = {"holdout_size": cv_value, "splitter": "holdout"} + + x_test = None + y_test = None + if test_data: + x_test = x.iloc[:20] if isinstance(x, pd.DataFrame) else x[:20] + y_test = y.iloc[:20] if isinstance(y, pd.DataFrame | pd.Series) else y[:20] + + evaluator = CVEvaluation( + X=x, + y=y, + X_test=x_test, + y_test=y_test, + working_dir=item.working_dir, + train_score=train_score, + store_models=store_models, + params=item.params, + additional_scorers=item.additional_scorers, + task_hint=item.task_type, + random_state=42, + on_error="raise", + **cv_kwargs, # type: ignore + ) + n_splits = evaluator.splitter.get_n_splits(x, y) + assert n_splits is not None + + report = evaluator.fn(trial, item.pipeline) + + # ------- Property testing + + # Model should be stored + if store_models: + for i in range(n_splits): + assert f"model_{i}.pkl" in report.storage + + # All metrics should be recorded and valid + for metric_name, metric in trial.metrics.items(): + assert metric_name in report.values + value = report.values[metric_name] + # ... 
in correct bounds + if metric.bounds is not None: + assert metric.bounds[0] <= value <= metric.bounds[1] + + # Summary should contain all optimization metrics + expected_summary_scorers = [ + *trial.metrics.keys(), + *(item.additional_scorers.keys() if item.additional_scorers else []), + ] + for metric_name in expected_summary_scorers: + for i in range(n_splits): + assert f"split_{i}:val_{metric_name}" in report.summary + assert f"val_mean_{metric_name}" in report.summary + assert f"val_std_{metric_name}" in report.summary + + if train_score: + assert "cv:train_score" in report.profiles + for metric_name in expected_summary_scorers: + for i in range(n_splits): + assert f"split_{i}:train_{metric_name}" in report.summary + assert f"train_mean_{metric_name}" in report.summary + assert f"train_std_{metric_name}" in report.summary + + if test_data: + assert "cv:test_score" in report.profiles + for metric_name in expected_summary_scorers: + for i in range(n_splits): + assert f"split_{i}:test_{metric_name}" in report.summary + assert f"test_mean_{metric_name}" in report.summary + assert f"test_std_{metric_name}" in report.summary + + # All folds are profiled + assert "cv" in report.profiles + for i in range(n_splits): + assert f"cv:split_{i}" in report.profiles + + +@parametrize( + "task_type", + [ + "binary", + "multiclass", + "multilabel-indicator", + "multiclass-multioutput", + "continuous", + "continuous-multioutput", + ], +) +@parametrize("cv_value, splitter", [(2, "cv"), (0.3, "holdout")]) +def test_consistent_results_across_seeds( + tmp_path: Path, + cv_value: int | float, + splitter: Literal["cv", "holdout"], + task_type: TaskTypeName, +) -> None: + x, y = data_for_task_type(task_type) + match task_type: + case "binary" | "multiclass" | "multilabel-indicator": + pipeline = Component( + DecisionTreeClassifier, + config={"max_depth": 1, "random_state": request("random_state")}, + ) + metric = Metric("accuracy", minimize=False, bounds=(0, 1)) + case "continuous" | "continuous-multioutput": + pipeline = Component( + DecisionTreeRegressor, + config={"max_depth": 1, "random_state": request("random_state")}, + ) + metric = Metric("r2", minimize=True, bounds=(-np.inf, 1)) + case "multiclass-multioutput": + pipeline = Component( + DecisionTreeClassifier, + config={"max_depth": 1, "random_state": request("random_state")}, + ) + # Sklearn doesn't have any multiclass-multioutput metrics + metric = Metric( + "custom", + minimize=False, + bounds=(0, 1), + fn=lambda y_true, y_pred: (y_pred == y_true).mean().mean(), + ) + + if splitter == "cv": + cv_kwargs = {"n_splits": cv_value, "splitter": "cv"} + else: + cv_kwargs = {"holdout_size": cv_value, "splitter": "holdout"} + + evaluator_1 = CVEvaluation( + X=x, + y=y, + working_dir=tmp_path, + random_state=42, + train_score=True, + store_models=False, + task_hint=task_type, + params=None, + on_error="raise", + **cv_kwargs, # type: ignore + ) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=ImplicitMetricConversionWarning) + + report_1 = evaluator_1.fn( + Trial.create( + name="trial-name", + config={}, + seed=50, + bucket=tmp_path / "trial-name", + metrics=metric, + ), + pipeline, + ) + + # Make sure to clean up the bucket for the second + # trial as it will have the same name + report_1.bucket.rmdir() + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=ImplicitMetricConversionWarning) + + report_2 = evaluator_1.fn( + Trial.create( + name="trial-name", + config={}, + seed=50, + bucket=tmp_path / 
"trial-name", # We give a different dir + metrics=metric, + ), + pipeline, + ) + + # We ignore profiles because they will be different timings + # We ignore trial.created_at and report.reported_at as they will naturally + # be different + df_1 = report_1.df(profiles=False).drop(columns=["reported_at", "created_at"]) + df_2 = report_2.df(profiles=False).drop(columns=["reported_at", "created_at"]) + pd.testing.assert_frame_equal(df_1, df_2) + + +def test_scoring_params_get_forwarded(tmp_path: Path) -> None: + with sklearn_config_context(enable_metadata_routing=True): + pipeline = Component(DecisionTreeClassifier, config={"max_depth": 1}) + x, y = data_for_task_type("binary") + + # This custom metrics requires a custom parameter + def custom_metric( + y_true: np.ndarray, # noqa: ARG001 + y_pred: np.ndarray, # noqa: ARG001 + *, + scorer_param_required: float, + ): + return scorer_param_required + + custom_scorer = ( + make_scorer(custom_metric, response_method="predict") + # Here we specify that it needs this parameter routed to it + .set_score_request(scorer_param_required=True) + ) + + value = 0.123 + evaluator = CVEvaluation( + x, + y, + params={"scorer_param_required": value}, # Pass it in here + working_dir=tmp_path, + on_error="raise", + ) + trial = Trial.create( + name="test", + bucket=tmp_path / "trial", + metrics=Metric(name="custom_metric", fn=custom_scorer), + ) + report = evaluator.fn(trial, pipeline) + + assert report.values["custom_metric"] == value + + +def test_splitter_params_get_forwarded(tmp_path: Path) -> None: + with sklearn_config_context(enable_metadata_routing=True): + # A DecisionTreeClassifier by default allows for sample_weight as a parameter + # request + pipeline = Component(DecisionTreeClassifier, config={"max_depth": 1}) + x, y = data_for_task_type("binary") + + # Make a group which is half 0 and half 1 + _half = len(x) // 2 + fake_groups = np.asarray([0] * _half + [1] * (len(x) - _half)) + + trial = Trial.create(name="test", bucket=tmp_path / "trial") + + # First make sure it errors if groups is not provided to the splitter + evaluator = CVEvaluation( + x, + y, + # params={"groups": fake_groups}, # noqa: ERA001 + splitter=GroupKFold(n_splits=2), + working_dir=tmp_path, + on_error="raise", + ) + with pytest.raises( + TrialError, + match=r"The 'groups' parameter should not be None.", + ): + evaluator.fn(trial, pipeline) + + # Now make sure it works + evaluator = CVEvaluation( + x, + y, + splitter=GroupKFold(n_splits=2), # We specify a group splitter + params={"groups": fake_groups}, # Pass it in here + working_dir=tmp_path, + on_error="raise", + ) + evaluator.fn(trial, pipeline) + + +def test_estimator_params_get_forward(tmp_path: Path) -> None: + with sklearn_config_context(enable_metadata_routing=True): + # NOTE: There is no way to explcitly check that metadata was indeed + # routed to the estimator, e.g. through an error. Please see this + # issue + # https://github.com/scikit-learn/scikit-learn/issues/23920 + + # We'll test this using the DummyClassifier with a Prior config. + # Thankfully this is deterministic so it's attributes_ should + # only get modified based on it's input. + # One attribute_ that gets modified depending on sample_weight + # is estimator.class_prior_ which we can check pretty easily. 
+        x, y = data_for_task_type("binary")
+        sample_weight = np.random.rand(len(x))  # noqa: NPY002
+
+        def create_dummy_classifier_with_sample_weight_request(
+            *args: Any,
+            **kwargs: Any,
+        ) -> DummyClassifier:
+            est = DummyClassifier(*args, **kwargs)
+            # https://scikit-learn.org/stable/metadata_routing.html#api-interface
+            est.set_fit_request(sample_weight=True)
+            return est
+
+        pipeline = Component(
+            create_dummy_classifier_with_sample_weight_request,
+            config={"strategy": "prior"},
+        )
+
+        # First we use an evaluator without sample_weight
+        trial = Trial.create(name="test", bucket=tmp_path / "trial_1")
+        evaluator = CVEvaluation(
+            x,
+            y,
+            holdout_size=0.3,
+            working_dir=tmp_path,
+            store_models=True,
+            # params={"sample_weight": sample_weight},  # noqa: ERA001
+            on_error="raise",
+        )
+        report = evaluator.fn(trial, pipeline)
+
+        # Load the pipeline, get the 0th model, get its class_prior_
+        class_weights_1 = report.retrieve("model_0.pkl")[0].class_prior_
+
+        # To make sure that our tests are correct, we repeat this without
+        # sample weights; the result should remain the same
+        trial = Trial.create(name="test", bucket=tmp_path / "trial_2")
+        report = evaluator.fn(trial, pipeline)
+        class_weights_2 = report.retrieve("model_0.pkl")[0].class_prior_
+
+        np.testing.assert_array_equal(class_weights_1, class_weights_2)
+
+        # Now with the sample weights, the class_prior_ should be different
+        trial = Trial.create(name="test", bucket=tmp_path / "trial_3")
+        evaluator = CVEvaluation(
+            x,
+            y,
+            holdout_size=0.3,
+            working_dir=tmp_path,
+            store_models=True,
+            params={"sample_weight": sample_weight},  # Passed in this time
+            on_error="raise",
+        )
+        report = evaluator.fn(trial, pipeline)
+        class_weights_3 = report.retrieve("model_0.pkl")[0].class_prior_
+
+        with pytest.raises(AssertionError):
+            np.testing.assert_array_equal(class_weights_1, class_weights_3)
+
+
+def test_evaluator_with_clustering(tmp_path: Path) -> None:
+    x, y = sklearn.datasets.make_blobs(  # type: ignore
+        n_samples=20,
+        centers=2,
+        n_features=2,
+        random_state=42,
+    )
+    pipeline = Component(KMeans, config={"n_clusters": 2, "random_state": 42})
+
+    metrics = Metric("adjusted_rand_score", minimize=False, bounds=(-0.5, 1))
+    trial = Trial.create(name="test", bucket=tmp_path / "trial", metrics=metrics)
+
+    evaluator = CVEvaluation(
+        x,  # type: ignore
+        y,  # type: ignore
+        working_dir=tmp_path,
+        on_error="raise",
+        random_state=42,
+    )
+    report = evaluator.fn(trial, pipeline)
+
+    # We are not really trying to detect the score of the algorithm, just to ensure
+    # that it did indeed train with the data and does not error.
+    # If it seems to get a slightly lower score than 1.0 then that's okay,
+    # just change the value. It should not change due to the seeding, but
+    # maybe sklearn changes something.
+    assert "adjusted_rand_score" in report.values
+    assert report.values["adjusted_rand_score"] == pytest.approx(1.0)
+
+
+def test_custom_configure_gets_forwarded(tmp_path: Path) -> None:
+    with sklearn_config_context(enable_metadata_routing=True):
+        # Pipeline requests a max_depth, defaulting to 1
+        pipeline = Component(
+            DecisionTreeClassifier,
+            config={
+                "max_depth": request("max_depth", default=1),
+            },
+        )
+
+        # We pass in 2 explicitly through the configure params.
+        # This can be useful for estimators that require explicit information
+        # about the dataset
+        configure_params = {"max_depth": 2}
+
+        x, y = data_for_task_type("binary")
+        evaluator = CVEvaluation(
+            x,
+            y,
+            params={"configure": configure_params},
+            working_dir=tmp_path,
+            splitter="holdout",
+            holdout_size=0.3,
+            store_models=True,
+            on_error="raise",
+        )
+        trial = Trial.create(
+            name="test",
+            bucket=tmp_path / "trial",
+            metrics=Metric("accuracy"),
+        )
+        report = evaluator.fn(trial, pipeline)
+        model = report.retrieve("model_0.pkl")[0]
+        assert model.max_depth == 2
+
+
+# Used in the test below
+class _MyPipeline(sklearn.pipeline.Pipeline):
+    # Have to explicitly list out all parameters per the sklearn API
+    def __init__(
+        self,
+        steps: Any,
+        *,
+        memory: None = None,
+        verbose: bool = False,
+        bamboozled: str = "no",
+    ):
+        super().__init__(steps, memory=memory, verbose=verbose)
+        self.bamboozled = bamboozled
+
+
+# Used in the test below; builds a _MyPipeline
+# with a custom parameter that will also get passed in
+def _my_custom_builder(
+    *args: Any,
+    bamboozled: str = "no",
+    **kwargs: Any,
+) -> _MyPipeline:
+    return sklearn_pipeline_builder(
+        *args,
+        pipeline_type=_MyPipeline,
+        bamboozled=bamboozled,
+        **kwargs,
+    )
+
+
+def test_custom_builder_can_be_forwarded(tmp_path: Path) -> None:
+    with sklearn_config_context(enable_metadata_routing=True):
+        pipeline = Component(DecisionTreeClassifier, config={"max_depth": 1})
+
+        x, y = data_for_task_type("binary")
+        evaluator = CVEvaluation(
+            x,
+            y,
+            params={"build": {"builder": _my_custom_builder, "bamboozled": "yes"}},
+            working_dir=tmp_path,
+            store_models=True,
+            on_error="raise",
+        )
+        trial = Trial.create(
+            name="test",
+            bucket=tmp_path / "trial",
+            metrics=Metric("accuracy"),
+        )
+
+        report = evaluator.fn(trial, pipeline)
+        model = report.retrieve("model_0.pkl")
+        assert isinstance(model, _MyPipeline)
+        assert hasattr(model, "bamboozled")
+        assert model.bamboozled == "yes"
+
+
+def test_early_stopping_plugin(tmp_path: Path) -> None:
+    pipeline = Component(DecisionTreeClassifier, space={"max_depth": (1, 10)})
+    x, y = data_for_task_type("binary")
+    evaluator = CVEvaluation(
+        x,
+        y,
+        splitter="cv",
+        n_splits=2,  # Notably 2 folds
+        working_dir=tmp_path,
+    )
+
+    @dataclass
+    class CVEarlyStopper:
+        def update(self, report: Trial.Report) -> None:
+            pass  # Normally you would update w.r.t. a finished trial
+
+        def should_stop(
+            self,
+            trial: Trial,  # noqa: ARG002
+            scores: CVEvaluation.SplitScores,  # noqa: ARG002
+        ) -> bool:
+            # Just say yes, should stop
+            return True
+
+    history = pipeline.optimize(
+        target=evaluator.fn,
+        metric=Metric("accuracy", minimize=False, bounds=(0, 1)),
+        working_dir=tmp_path,
+        plugins=[evaluator.cv_early_stopping_plugin(strategy=CVEarlyStopper())],
+        max_trials=1,
+        on_trial_exception="continue",
+    )
+    assert len(history) == 1
+    report = history.reports[0]
+
+    assert report.status is Trial.Status.FAIL
+    assert not any(report.values)
+    assert report.exception is not None
+    assert "Early stop" in str(report.exception)
+
+    # Only the first fold should have been run and put in the summary
+    assert "split_0:val_accuracy" in report.summary
+    assert "split_1:val_accuracy" not in report.summary
+
+
+def test_that_test_scorer_params_can_be_forwarded(tmp_path: Path) -> None:
+    """Not the biggest fan of this test, apologies.
+
+    Main concerns are just to ensure that the correct parameters get forwarded
+    to `custom_metric` and that the data used in the test remains in the assumed
+    state.
+    """
+    with sklearn_config_context(enable_metadata_routing=True):
+        pipeline = Component(DecisionTreeClassifier, config={"max_depth": 1})
+
+        x, y = data_for_task_type("binary")
+
+        # We do some sanity checking that this test is doing what's
+        # intended and doesn't silently break, namely we want to ensure that
+        # the scorer gets two differently sized inputs, one for the splits
+        # themselves and one for the test data. This assumption is required
+        # for the test to work
+        N_SPLITS = 2
+        assert len(x) % N_SPLITS == 0, "Need to have equal sized splits"
+
+        EXPECTED_FOLD_SIZE = len(x) // N_SPLITS
+        TEST_SIZE = 2
+        assert EXPECTED_FOLD_SIZE != TEST_SIZE, "Test size and fold size matched"
+
+        x_test, y_test = x[:TEST_SIZE], y[:TEST_SIZE]
+
+        def custom_metric(
+            y_true: np.ndarray,
+            y_pred: np.ndarray,
+            *,
+            data_independant: float,  # e.g. pos_label
+            data_dependant: np.ndarray,  # e.g. sample_weight
+        ):
+            assert len(data_dependant) in (EXPECTED_FOLD_SIZE, TEST_SIZE)
+
+            # Just ensure shapes match
+            if len(data_dependant) == EXPECTED_FOLD_SIZE:
+                assert all(len(p) == EXPECTED_FOLD_SIZE for p in (y_pred, y_true))
+
+            if len(data_dependant) == TEST_SIZE:
+                assert all(len(p) == TEST_SIZE for p in (y_pred, y_true))
+
+            # Return the fake score, i.e. the injected data_independant value
+            return data_independant
+
+        custom_scorer = (
+            make_scorer(custom_metric, response_method="predict")
+            # Here we specify that it needs these parameters routed to it
+            # NOTE: We don't specify that we need the test variations, that
+            # will be handled by the evaluator by prefixing `test_` to the
+            # parameters
+            .set_score_request(data_independant=True, data_dependant=True)
+        )
+
+        evaluator = CVEvaluation(
+            x,
+            y,
+            X_test=x_test,
+            y_test=y_test,
+            n_splits=N_SPLITS,
+            params={
+                "data_independant": 1,
+                "data_dependant": np.ones(len(x)),
+                # Here we provide the test-specific scorer params
+                "test_data_independant": 2,
+                "test_data_dependant": np.ones(len(x_test)),
+            },
+            working_dir=tmp_path,
+            on_error="raise",
+        )
+        trial = Trial.create(
+            name="test",
+            bucket=tmp_path / "trial",
+            metrics=Metric(name="custom_metric", fn=custom_scorer),
+        )
+        report = evaluator.fn(trial, pipeline)
+
+        assert report.values["custom_metric"] == 1
+        assert report.summary["test_mean_custom_metric"] == 2
+
+
+def record_split_number(
+    trial: Trial,
+    split_number: int,
+    info: CVEvaluation.PostSplitInfo,
+) -> CVEvaluation.PostSplitInfo:
+    # Should get the test data if it was passed in, as it is in the test below
+    assert info.X_test is not None
+    assert info.y_test is not None
+    check_is_fitted(info.model)
+
+    trial.summary[f"post_split_{split_number}"] = split_number
+    return info
+
+
+def test_post_split(tmp_path: Path) -> None:
+    pipeline = Component(DecisionTreeClassifier, config={"max_depth": 1})
+    x, y = data_for_task_type("binary")
+    TEST_SIZE = 2
+    x_test, y_test = x[:TEST_SIZE], y[:TEST_SIZE]
+
+    NSPLITS = 3
+    evaluator = CVEvaluation(
+        x,
+        y,
+        X_test=x_test,
+        y_test=y_test,
+        n_splits=NSPLITS,
+        working_dir=tmp_path,
+        on_error="raise",
+        post_split=record_split_number,
+    )
+    trial = Trial.create("test", bucket=tmp_path / "trial", metrics=Metric("accuracy"))
+    report = evaluator.fn(trial, pipeline)
+
+    for i in range(NSPLITS):
+        assert f"post_split_{i}" in report.summary
+        assert report.summary[f"post_split_{i}"] == i
+
+
+def chaotic_post_processing(
+    report: Trial.Report,
+    pipeline: Node,  # noqa: ARG001
+    eval_info: CVEvaluation.CompleteEvalInfo,
+) -> Trial.Report:
+    # We should have no models passed to our post-processing since we didn't ask
+    # for them in the init.
+    assert eval_info.models is None
+
+    # We told it to store models, so we should have models in the storage
+    for i in range(eval_info.max_splits):
+        assert f"model_{i}.pkl" in report.storage
+
+    # Delete the models
+    trial = report.trial
+
+    trial.delete_from_storage(
+        [f"model_{i}.pkl" for i in range(eval_info.max_splits)],
+    )
+
+    # Return some bogus number as the metric value
+    return trial.success(accuracy=0.123)
+
+
+def test_post_processing_no_models(tmp_path: Path) -> None:
+    pipeline = Component(DecisionTreeClassifier, config={"max_depth": 1})
+    x, y = data_for_task_type("binary")
+    evaluator = CVEvaluation(
+        x,
+        y,
+        working_dir=tmp_path,
+        on_error="raise",
+        post_processing=chaotic_post_processing,
+        store_models=True,
+    )
+    trial = Trial.create("test", bucket=tmp_path / "trial", metrics=Metric("accuracy"))
+    report = evaluator.fn(trial, pipeline)
+
+    # The chaotic post-processing replaced the metric value and deleted the models
+    assert report.values["accuracy"] == 0.123
+    assert len(report.storage) == 0