
Conversation

@Innixma Innixma commented Jul 18, 2024

Issue #, if available:

Description of changes:

  • Add initial callbacks support to TabularPredictor
  • Callbacks allow users to inject custom logic into the training process, and in theory let the user completely replace the training loop with their own implementation.

Example Code:

import pandas as pd
from autogluon.tabular import TabularPredictor

from autogluon.core.callbacks import EarlyStoppingCallback


if __name__ == '__main__':
    label = 'class'
    train_data = pd.read_csv('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')

    callbacks = [EarlyStoppingCallback(patience=3)]
    predictor = TabularPredictor(label=label).fit(train_data=train_data, callbacks=callbacks)
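
Beyond the built-in EarlyStoppingCallback, the point of the feature is user-defined callbacks. As a rough illustration of the contract, a minimal custom callback might look like the sketch below; the class and exact signature are hypothetical stand-ins mirroring the after_fit hook shown in the review snippets, not the real AbstractCallback API:

```python
# Minimal sketch of a custom callback (hypothetical stand-in; the real base
# class is autogluon.core.callbacks.AbstractCallback, and exact signatures
# may differ from this illustration).
class SimpleLoggingCallback:
    def after_fit(self, trainer, model_names, logger=None, stack_name="core", level=1) -> bool:
        # Called after each model fit; returning True requests early stopping.
        print(f"Fitted {model_names} (stack={stack_name}, level={level})")
        return False


cb = SimpleLoggingCallback()
early_stop = cb.after_fit(trainer=None, model_names=["LightGBM"])
```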

Example Output:

...
User-specified callbacks (1): ['EarlyStoppingCallback']
Fitting 13 L1 models ...
Fitting model: KNeighborsUnif ...
	0.7736	 = Validation score   (accuracy)
	1.63s	 = Training   runtime
	0.02s	 = Validation runtime
EarlyStoppingCallback: Best Score: 0.7736 | Patience: 0/3 | Best Model: KNeighborsUnif (New Best)
Fitting model: KNeighborsDist ...
	0.7652	 = Validation score   (accuracy)
	0.2s	 = Training   runtime
	0.01s	 = Validation runtime
EarlyStoppingCallback: Best Score: 0.7736 | Patience: 1/3 | Best Model: KNeighborsUnif
Fitting model: LightGBMXT ...
	0.8792	 = Validation score   (accuracy)
	1.74s	 = Training   runtime
	0.0s	 = Validation runtime
EarlyStoppingCallback: Best Score: 0.8792 | Patience: 0/3 | Best Model: LightGBMXT (New Best)
Fitting model: LightGBM ...
	0.8824	 = Validation score   (accuracy)
	1.37s	 = Training   runtime
	0.0s	 = Validation runtime
EarlyStoppingCallback: Best Score: 0.8824 | Patience: 0/3 | Best Model: LightGBM (New Best)
Fitting model: RandomForestGini ...
	0.8612	 = Validation score   (accuracy)
	0.91s	 = Training   runtime
	0.08s	 = Validation runtime
EarlyStoppingCallback: Best Score: 0.8824 | Patience: 1/3 | Best Model: LightGBM
Fitting model: RandomForestEntr ...
	0.8584	 = Validation score   (accuracy)
	1.0s	 = Training   runtime
	0.09s	 = Validation runtime
EarlyStoppingCallback: Best Score: 0.8824 | Patience: 2/3 | Best Model: LightGBM
Fitting model: CatBoost ...
	0.8824	 = Validation score   (accuracy)
	6.89s	 = Training   runtime
	0.01s	 = Validation runtime
EarlyStoppingCallback: Best Score: 0.8824 | Patience: 3/3 | Best Model: LightGBM
EarlyStoppingCallback: Early stopping trainer fit for level=1. Reason: No score_val improvement in the past 3 models.
Fitting model: WeightedEnsemble_L2 ...
	Ensemble Weights: {'CatBoost': 0.5, 'LightGBM': 0.45, 'KNeighborsUnif': 0.05}
	0.886	 = Validation score   (accuracy)
	0.06s	 = Training   runtime
	0.0s	 = Validation runtime
EarlyStoppingCallback: Best Score: 0.8860 | Patience: 0/3 | Best Model: WeightedEnsemble_L2 (New Best)
AutoGluon training complete, total runtime = 15.16s ... Best model: WeightedEnsemble_L2 | Estimated inference throughput: 69200.6 rows/s (2500 batch size)
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("AutogluonModels/ag-20240730_021147")

TODO:

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@Innixma Innixma added API & Doc Improvements or additions to documentation enhancement New feature or request module: tabular labels Jul 18, 2024
@Innixma Innixma added this to the 1.2 Release milestone Jul 18, 2024
@github-actions

Job PR-4327-aeb82c1 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-4327/aeb82c1/index.html

@github-actions

Job PR-4327-7c407de is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-4327/7c407de/index.html

@eddiebergman eddiebergman left a comment

Sorry for the unsolicited PR review; I was just curious to see how callbacks get implemented!

time_limit: float | None = None,
stack_name: str = "core",
level: int = 1,
) -> Tuple[bool, bool]:


If you want, you can put from __future__ import annotations at the top of the file and then use tuple[bool, bool]

Contributor Author

Nice! Updated

return early_stop

def _calc_new_best(self, trainer: AbstractTrainer):
leaderboard = trainer.leaderboard()


Probably not critical, but this call to trainer.leaderboard() is not the cheapest, given that it does a lot of dict parsing, DAG processing, and sorting.

https://github.com/Innixma/autogluon/blob/7c407de63254ff7f8a09c63be1adb495bd392229/core/src/autogluon/core/trainer/abstract_trainer.py#L3151

Alternatives:

model, score = max(trainer.get_model_attributes("val_score").items(), key=lambda t: t[1]) # t = (model_name, score) 

Contributor Author

Good call! I actually needed to revamp this logic to make it work in cases where the user specified infer_limit to ensure that the model returned satisfies the infer_limit.


# self._exceptions_list = [] # TODO: Keep exceptions list for debugging during benchmarking.

self.callbacks: List[callable] = []


Lowercase callable is actually not a type, it's a function. You need typing.Callable.

(...oh how I wish it was though, like list and dict are)
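
A small sketch of the distinction: the builtin callable is a predicate function, while typing.Callable is the annotation that type checkers understand:

```python
from typing import Callable, List

def run_callbacks(callbacks: List[Callable[[], bool]]) -> bool:
    # builtin callable() tests whether an object is callable; it is not a type
    assert all(callable(cb) for cb in callbacks)
    return any(cb() for cb in callbacks)

stopped = run_callbacks([lambda: False, lambda: True])
```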

Contributor Author

Good catch! I updated the PR and I created a GitHub issue to track replacement throughout the project: #4349

level_time_modifier=0.333,
infer_limit=None,
infer_limit_batch_size=None,
callbacks: List[callable] = None,


Likewise callable

Contributor Author

ditto

if self.model_best is None and len(model_names_fit) != 0:
self.model_best = self.get_model_best(can_infer=True, infer_limit=infer_limit, infer_limit_as_child=True)
self._time_limit = None
self._fit_cleanup()

@eddiebergman eddiebergman Jul 25, 2024

Probably not worth the complexity if this is the only usage, but _fit_setup() and _fit_cleanup() can be removed in favor of a @contextmanager, i.e. with self._fitting_context():. It can make the intent a bit clearer, and there's one less completion your editor gives you.

from contextlib import contextmanager
from typing import Iterator
import time

class WhateverThisClassIs:

    @contextmanager
    def _fitting_context(self, time_limit=None, callbacks=None) -> Iterator[None]:
        # Previously `_fit_setup()`
        self._time_train_start = time.time()
        self._time_train_start_last = self._time_train_start
        self._time_limit = time_limit

        # Might be able to just lift logic here and remove one
        # more method
        self.reset_callbacks()

        if callbacks is not None:
            assert isinstance(callbacks, list), f"`callbacks` must be a list. Found invalid type: `{type(callbacks)}`."
        else:
            callbacks = []
        self.callbacks = callbacks

        try:
            yield
        finally:
            # Previously `_fit_cleanup()`; `finally` ensures cleanup runs even if fitting raises
            self._time_limit = None
            self._time_train_start = None

            # Likewise with lifting logic
            self.reset_callbacks()

Contributor Author

This is an interesting idea. I think for now I will keep it as is, but this could be something worth adopting in future. I've made essentially zero use of yield in AutoGluon so far, mostly due to my unfamiliarity.

) -> Tuple[bool, bool]:
time_limit_trainer = trainer._time_limit
if time_limit_trainer is not None and trainer._time_train_start is not None:
time_left_total = time_limit_trainer - (time.time() - trainer._time_train_start)


Recently learnt that time.time() is not the best for measuring durations, since it follows the wall clock and can jump when the system clock is adjusted. Most systems are fine, and it rarely leads to a bug until it does.

https://docs.python.org/3/library/time.html#time.monotonic
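
A minimal sketch of the difference: time.monotonic() can never go backwards, which makes it the safe choice for duration measurement:

```python
import time

start = time.monotonic()  # immune to system clock adjustments
time.sleep(0.05)
elapsed = time.monotonic() - start  # safe for measuring durations
```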

Contributor Author

Nice find! #4342

Comment on lines +696 to +748
if self._callback_early_stop:
return []


Worth considering if you want some overwrite ability for this, i.e. for testing.

Contributor Author

Interesting idea, will probably keep as is for now and add overwrite logic if it becomes useful.

Comment on lines 2427 to 2528
def _callbacks_after_fit(
self,
*,
model_names: List[str],
stack_name: str,
level: int,
):
for callback in self.callbacks:
callback_early_stop = callback.after_fit(
self,
model_names=model_names,
logger=logger,
stack_name=stack_name,
level=level,
)
if callback_early_stop:
self._callback_early_stop = True

@eddiebergman eddiebergman Jul 25, 2024

If there's no side effect to care about with the .after_fit(), i.e. you don't care about callbacks state after this call, then you could break early. Might not be important if callbacks are super cheap.

    def _callbacks_after_fit(
        self,
        *,
        model_names: List[str],
        stack_name: str,
        level: int,
    ):
        for callback in self.callbacks:
            callback_early_stop = callback.after_fit(
                self,
                model_names=model_names,
                logger=logger,
                stack_name=stack_name,
                level=level,
            )
            if callback_early_stop:
                self._callback_early_stop = True
                break

Or if you're feeling functional ;)

    def _callbacks_after_fit(
        self,
        *,
        model_names: List[str],
        stack_name: str,
        level: int,
    ):
        should_stop_itr = (
            callback.after_fit(
                self,
                model_names=model_names,
                logger=logger,
                stack_name=stack_name,
                level=level,
            )
            for callback in self.callbacks
        )
        self._callback_early_stop = any(should_stop_itr)
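
Note that any() over a generator also short-circuits, so this variant, like the explicit break, stops invoking callbacks at the first True. A standalone sketch:

```python
calls = []

def make_cb(name, result):
    def cb():
        calls.append(name)
        return result
    return cb

callbacks = [make_cb("log", False), make_cb("early_stop", True), make_cb("never_runs", False)]
# any() consumes the generator lazily and stops at the first True,
# so the third callback is never invoked
stopped = any(cb() for cb in callbacks)
```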

Contributor Author

My thought process is that some callbacks could simply be logging related, and I wouldn't want a callback that early stops to prevent a logging callback from doing its logging.

The other thing is that, technically, later callbacks can check whether trainer._callback_early_stop == True and skip their logic.

This is a good idea though, and I thought about it too while implementing. I've added a self.skip_if_trainer_stopped parameter to callbacks. When true, the callback logic is skipped when trainer._callback_early_stop == True. This should give the best of both worlds.


@eddiebergman eddiebergman Jul 25, 2024

Sounds good! If you want to make it clearer to users, you could make a shallow class LoggingCallback, or conversely EarlyStoppingCallback, which just internally sets this flag, i.e. the user never sets it. I use this pattern a good bit, where the flag is actually set as a class variable, e.g. LoggingCallback sets skip = False and EarlyStoppingCallback sets skip = True.

It gives you the freedom to later change this to a dict or whatever, as long as LoggingCallback is never skipped and EarlyStoppingCallback is. Also, it's one less parameter for someone to think about.
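
A sketch of the class-variable pattern being described (attribute name borrowed from the discussion above; the actual implementation may differ):

```python
class AbstractCallback:
    # Subclasses override this class variable instead of exposing a constructor arg
    skip_if_trainer_stopped: bool = False

class LoggingCallback(AbstractCallback):
    skip_if_trainer_stopped = False  # always runs, even after an early stop

class EarlyStoppingCallback(AbstractCallback):
    skip_if_trainer_stopped = True  # no point evaluating stopping logic after a stop
```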

Contributor Author

That is a great idea; updated the code. I think this is a good practice I should adopt more often: separating user-level args from developer-level args (i.e., those that are only relevant when subclassing).

@Innixma Innixma commented Jul 25, 2024

@eddiebergman unsolicited PR reviews are among my favorite kinds of code reviews :)

@github-actions

Job PR-4327-7f9a6e6 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-4327/7f9a6e6/index.html

@Innixma Innixma changed the title [WIP] [tabular] Add initial callbacks support [tabular] Add initial callbacks support Jul 30, 2024
@Innixma Innixma requested review from prateekdesai04 and shchur July 30, 2024 02:19
@github-actions

Job PR-4327-c3b7fbf is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-4327/c3b7fbf/index.html

trainer: AbstractTrainer,
model: AbstractModel,
time_limit: float | None = None,
stack_name: str = "core",

Collaborator

Should this be a typing.Literal?

Contributor Author

In this case, no, because stack_name can be any valid string.
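
For context, a sketch of the tradeoff: typing.Literal pins a parameter to an enumerated set of values, whereas the plain str kept here accepts arbitrary stack names:

```python
from typing import Literal

def fit_literal(stack_name: Literal["core", "aux"] = "core") -> str:
    # a type checker would flag fit_literal("my_custom_stack")
    return stack_name

def fit_str(stack_name: str = "core") -> str:
    return stack_name  # any stack name is valid

name = fit_str("my_custom_stack")
```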

cgroups). Otherwise, AutoGluon might wrongly assume more resources are available for fitting a model than the operating system allows,
which can result in model training failing or being very inefficient.
callbacks : List[AbstractCallback], default = None
[Experimental] Callback support is preliminary, targeted towards developers, and is subject to change.

Collaborator

Nit: Should we use admonitions to highlight experimental features?

:::{warning}
This is an experimental feature and may change in the future releases without warning.
:::

Contributor Author

Nice! Added

self._time_limit = time_limit
self.reset_callbacks()
if callbacks is not None:
assert isinstance(callbacks, list), f"`callbacks` must be a list. Found invalid type: `{type(callbacks)}`."

Collaborator

Does it make sense to verify that callbacks are of type AbstractCallback?

Contributor Author

Good idea! Added

# self._exceptions_list = [] # TODO: Keep exceptions list for debugging during benchmarking.

self.callbacks: List[AbstractCallback] = []
self._callback_early_stop = False

Collaborator

I'm not very familiar with the rest of the AbstractTrainer code, so this might be a bad suggestion, but would it be feasible to communicate the early stopping / interruption without altering the state of the trainer? I can imagine that in some scenarios, such as distributed training, it would be really tricky to reason about the state of this variable.

Contributor Author

Tough to say, but I think I would prefer to edit the trainer state, as the intention is to not pass the trainer object to worker threads, and therefore it remains a singular source of truth. We can change it later if we find this has limitations.

@github-actions

Job PR-4327-c34b62a is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-4327/c34b62a/index.html

@github-actions

Job PR-4327-78d82af is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-4327/78d82af/index.html

)
model_names_fit += base_model_names + aux_models
if self.model_best is None and len(model_names_fit) != 0:
if (self.model_best is None or infer_limit is not None) and len(model_names_fit) != 0:

Contributor

qq: could you explain this check?

@Innixma Innixma Aug 28, 2024

Pseudocode breakdown of the check:


if (user specified infer limit) and (any model exists):
    select the best valid model based on validation score that satisfies the infer limit
elif (trainer has not specified which model is best) and (any model exists):
    select the best valid model based on validation score

The added infer_limit check ensures that we always return a model that satisfies the infer_limit constraints. Sometimes a non-ensemble model can achieve a better score while satisfying infer_limit than any ensemble model, and by default only ensemble models are set to self.model_best prior to this call. This call ensures that the best model is picked even if it isn't an ensemble.
