Hüseyin Tunç, Doğanay Özese, Ş. İlker Birbil, Donato Maragno, Marco Caserta, Mustafa Baydoğan
Incorporating domain-specific constraints into machine learning models is essential for generating predictions that are both accurate and feasible in real-world applications. This paper introduces new methods for training Output-Constrained Regression Trees (OCRT), addressing the limitations of traditional decision trees in constrained multi-target regression tasks. We propose three approaches: M-OCRT, which uses split-based mixed integer programming to enforce constraints; E-OCRT, which employs an exhaustive search for optimal splits and solves constrained prediction problems at each decision node; and EP-OCRT, which applies post-hoc constrained optimization to tree predictions. To illustrate their potential uses in ensemble learning, we also introduce a random forest framework working under convex feasible sets. We validate the proposed methods through a computational study on both synthetic and industry-driven hierarchical time series datasets. Our results demonstrate that imposing constraints on decision tree training yields accurate and feasible predictions.
The details of this work are available in our paper.
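To give a feel for what "post-hoc constrained optimization" of a prediction can look like, here is a minimal sketch, not the solver used in the paper or in this repository. It assumes the feasible set is defined by linear equality constraints A y = b, in which case the closest feasible point has a closed form. The function name project_onto_affine_set and the toy hierarchy (a total that must equal the sum of two children) are illustrative assumptions.

```python
import numpy as np


def project_onto_affine_set(y_pred, A, b):
    """Euclidean projection of y_pred onto {y : A y = b}.

    Hypothetical helper: assumes A has full row rank so that A A^T is invertible.
    """
    residual = A @ y_pred - b                       # how much each constraint is violated
    correction = A.T @ np.linalg.solve(A @ A.T, residual)
    return y_pred - correction


# Toy hierarchy: y = [total, child_1, child_2] with total - child_1 - child_2 = 0.
A = np.array([[1.0, -1.0, -1.0]])
b = np.array([0.0])

raw_prediction = np.array([10.0, 4.0, 5.0])         # infeasible: 10 != 4 + 5
feasible_prediction = project_onto_affine_set(raw_prediction, A, b)
```

For general convex feasible sets there is usually no closed form, and an optimization solver would be needed instead.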
The repository contains a YAML file (named ocdt.yml) to create the environment necessary to train the OCDT model. To create the environment, you can use the following command:
$ conda env create --file=ocdt.yml
This will create a virtual environment with the name ocdt. You should be able to see this environment if you run the following command:
$ conda env list
If ocdt is listed among the virtual environments, the environment was installed successfully. You can activate the environment using the command below:
$ conda activate ocdt
Currently, the repository is available for training the OCDT model.
To start the runs, we need data. Two kinds of datasets are used in the repository, and both are available in the data folder. The first kind consists of synthetic datasets (generated by running the generate_constrained_dataset_with_nonlinearity() function in library/Constrained_Data_Generation.py), which are named according to the number of targets and the dataset size: df_size_<DATASET_SIZE>_targets_<NUMBER_OF_TARGETS>_seed_<SEED>. The second kind consists of hierarchical time series datasets.
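The naming format of the synthetic datasets can be sketched in code as follows. This is only an illustration: the helper names are hypothetical and not part of the repository, the sizes and seed are made-up values, and the file extension is deliberately left out because it is not stated here.

```python
import re

# Pattern mirroring df_size_<DATASET_SIZE>_targets_<NUMBER_OF_TARGETS>_seed_<SEED>
NAME_PATTERN = re.compile(r"^df_size_(\d+)_targets_(\d+)_seed_(\d+)$")


def synthetic_dataset_name(dataset_size, number_of_targets, seed):
    """Build a dataset name (without extension) in the format described above."""
    return f"df_size_{dataset_size}_targets_{number_of_targets}_seed_{seed}"


def parse_synthetic_dataset_name(name):
    """Recover size, number of targets, and seed from a dataset name."""
    match = NAME_PATTERN.match(name)
    if match is None:
        raise ValueError(f"Unexpected dataset name: {name!r}")
    size, targets, seed = (int(group) for group in match.groups())
    return {"dataset_size": size, "number_of_targets": targets, "seed": seed}


example_name = synthetic_dataset_name(1000, 3, 0)   # illustrative values only
example_info = parse_synthetic_dataset_name(example_name)
```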
- ocdt_min_samples_split: Minimum number of instances that a decision node must have in order to be split.
- ocdt_min_samples_leaf: Minimum number of instances that a node must have in order to become a leaf node.
- ocdt_depth: Maximum depth of the OCDT.
- evaluation_method: Evaluation metric used to calculate the gains of the split candidates. Available values are mse (i.e., Mean Squared Error), mad (i.e., Mean Absolute Deviation), and poisson (i.e., Poisson Deviation).
- prediction_method: Prediction approach used in splitting. Available values are mean (i.e., to return the mean of target values as the prediction), medoid (i.e., to return the median of target values as the prediction), and optimal (i.e., to return the optimal solution of the optimization problem that minimizes the MSE objective function).
- prediction_method_leaf: Prediction approach used in leaves. Available values are medoid (i.e., to return the median of target values as the prediction) and optimal (i.e., to return the optimal solution of the optimization problem that minimizes the MSE objective function).
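As a rough illustration of how these parameters fit together, the sketch below collects them in a dictionary and checks the values against the options listed above. The checker function, the use of a plain dictionary, and the numeric values are assumptions made for this example; the repository may pass and validate the parameters differently.

```python
# Allowed values, copied from the parameter descriptions above.
ALLOWED_VALUES = {
    "evaluation_method": {"mse", "mad", "poisson"},
    "prediction_method": {"mean", "medoid", "optimal"},
    "prediction_method_leaf": {"medoid", "optimal"},
}

INTEGER_PARAMS = ("ocdt_min_samples_split", "ocdt_min_samples_leaf", "ocdt_depth")


def check_ocdt_params(params):
    """Hypothetical helper: raise ValueError if a parameter is missing or invalid."""
    for name in INTEGER_PARAMS:
        value = params.get(name)
        if not isinstance(value, int) or value < 1:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")
    for name, allowed in ALLOWED_VALUES.items():
        value = params.get(name)
        if value not in allowed:
            raise ValueError(f"{name} must be one of {sorted(allowed)}, got {value!r}")
    return params


# Illustrative values only; they are not the settings used in the paper.
example_params = check_ocdt_params({
    "ocdt_min_samples_split": 10,
    "ocdt_min_samples_leaf": 5,
    "ocdt_depth": 8,
    "evaluation_method": "mse",
    "prediction_method": "mean",
    "prediction_method_leaf": "optimal",
})
```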
- ocdt_params: All OCDT parameters.
- n_estimators: The number of regressors used to construct the random forest.
- max_features: Maximum number of features to be used while training each regressor.
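One common way n_estimators and max_features interact in a random forest is that each regressor is trained on its own random subset of the features. The sketch below shows that idea only; the function name is hypothetical, and the sampling strategy actually used in this repository may differ.

```python
import numpy as np


def draw_feature_subsets(n_features, n_estimators, max_features, seed=0):
    """Draw one random subset of feature indices per regressor (illustrative only)."""
    rng = np.random.default_rng(seed)
    subset_size = min(max_features, n_features)     # cannot use more features than exist
    return [
        np.sort(rng.choice(n_features, size=subset_size, replace=False))
        for _ in range(n_estimators)
    ]


# Five regressors, each restricted to 3 of 10 features (made-up numbers).
feature_subsets = draw_feature_subsets(n_features=10, n_estimators=5, max_features=3)
```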
After the environment is activated, you can replicate the runs whose results are presented in the paper. To select one of the datasets presented in the paper, set the dataset parameter to either synthetic_manifold or hts. To retrieve multiple results at once, some of the parameters mentioned above are collected in variables with the suffix _list, which the runs iterate over.
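Iterating over the _list variables amounts to running every combination of the listed values. A minimal sketch of that pattern is given below; the specific variable names (dataset_list, ocdt_depth_list, prediction_method_list) and their values are assumptions chosen for illustration and may not match the names used in the repository's scripts.

```python
from itertools import product

# Hypothetical *_list variables; each holds the values to try for one parameter.
dataset_list = ["synthetic_manifold", "hts"]
ocdt_depth_list = [5, 10]
prediction_method_list = ["mean", "optimal"]

# One configuration per combination of values (2 * 2 * 2 = 8 runs here).
runs = [
    {"dataset": dataset, "ocdt_depth": depth, "prediction_method": method}
    for dataset, depth, method in product(
        dataset_list, ocdt_depth_list, prediction_method_list
    )
]

for run in runs:
    # A real script would train and evaluate a model here; this sketch only prints.
    print(run)
```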
Contributions are always welcome.
If you are reporting a bug, please include:
- Any details about your local setup that might be helpful in troubleshooting.
- Detailed steps to reproduce the bug.