
Support CIs for non-decomposable measures #9

Merged
merged 6 commits into from
Jan 8, 2025
9 changes: 5 additions & 4 deletions DESCRIPTION
@@ -27,7 +27,8 @@ Imports:
     R6,
     withr
 Suggests:
-    testthat (>= 3.0.0)
+    testthat (>= 3.0.0),
+    rpart
 Remotes:
     mlr-org/mlr3
 Config/testthat/edition: 3
@@ -38,12 +39,12 @@ RoxygenNote: 7.3.2
 Collate:
     'MeasureAbstractCi.R'
     'aaa.R'
+    'MeasureCI.R'
     'MeasureCIConZ.R'
     'MeasureCICorT.R'
     'MeasureCIHoldout.R'
-    'MeasureCINaiveCV.R'
-    'MeasureCi.R'
-    'MeasureCiNestedCV.R'
+    'MeasureCINestedCV.R'
+    'MeasureCIWaldCV.R'
     'ResamplingNestedCV.R'
     'ResamplingPairedSubsampling.R'
     'bibentries.R'
2 changes: 1 addition & 1 deletion NAMESPACE
@@ -5,8 +5,8 @@ export(MeasureCi)
 export(MeasureCiConZ)
 export(MeasureCiCorrectedT)
 export(MeasureCiHoldout)
-export(MeasureCiNaiveCV)
 export(MeasureCiNestedCV)
+export(MeasureCiWaldCV)
 export(ResamplingNestedCV)
 export(ResamplingPairedSubsampling)
 import(checkmate)
39 changes: 28 additions & 11 deletions R/MeasureAbstractCi.R
@@ -18,6 +18,8 @@
 #' The measure for which to calculate a confidence interval. Must have `$obs_loss`.
 #' @param resamplings (`character()`)\cr
 #' To which resampling classes this measure can be applied.
+#' @param requires_obs_loss (`logical(1)`)\cr
+#' Whether the inference method requires a pointwise loss function.
 #' @template param_param_set
 #' @template param_packages
 #' @template param_label
@@ -28,7 +30,8 @@
 #' @section Inheriting:
 #' To define a new CI method, inherit from the abstract base class and implement the private method:
 #' `ci: function(tbl: data.table, rr: ResampleResult, param_vals: named `list()`) -> numeric(3)`
-#' Here, `tbl` contains the columns `loss`, `row_id` and `iteration`, which are the pointwise loss,
+#' If `requires_obs_loss` is set to `TRUE`, `tbl` contains the columns `loss`, `row_id` and `iteration`, which are the pointwise loss,
 #' the identifier of the observation and the resampling iteration.
+#' Otherwise, `tbl` contains the result of `rr$score()` with the name of the loss column set to `"loss"`.
 #' It should return a vector containing the `estimate`, `lower` and `upper` boundary in that order.
 #'
@@ -49,19 +52,28 @@ MeasureAbstractCi = R6Class("MeasureAbstractCi",
     measure = NULL,
     #' @description
     #' Creates a new instance of this [R6][R6::R6Class] class.
-    initialize = function(measure = NULL, param_set = ps(), packages = character(), resamplings, label, delta_method = FALSE) {
+    initialize = function(measure = NULL, param_set = ps(), packages = character(), resamplings, label, delta_method = FALSE,
+      requires_obs_loss = TRUE) { # nolint
       private$.delta_method = assert_flag(delta_method, na.ok = TRUE)
-      self$measure = if (test_string(measure)) {
-        msr(measure)
-      } else {
+      private$.requires_obs_loss = assert_flag(requires_obs_loss)
+      if (test_string(measure)) measure = msr(measure)
+      self$measure = measure
+
+      if (private$.requires_obs_loss) {
         assert(
           check_class(measure, "Measure"),
           check_false(inherits(measure, "MeasureCi")),
           check_function(measure$obs_loss),
           combine = "and",
           .var.name = "Argument measure must be a scalar Measure with a pointwise loss function (has $obs_loss field)"
         )
-        measure
+      } else {
+        assert(
+          check_class(measure, "Measure"),
+          check_false(inherits(measure, "MeasureCi")),
+          combine = "and",
+          .var.name = "Argument measure must be a scalar Measure."
+        )
       }

       param_set = c(param_set,
@@ -108,10 +120,15 @@ MeasureAbstractCi = R6Class("MeasureAbstractCi",
       }

       param_vals = self$param_set$get_values()
-      tbl = rr$obs_loss(self$measure)
-      names(tbl)[names(tbl) == self$measure$id] = "loss"
+      tbl = if (private$.requires_obs_loss) {
+        rr$obs_loss(self$measure)
+      } else {
+        rr$score(self$measure)
+      }
+      setnames(tbl, self$measure$id, "loss")
+
       ci = private$.ci(tbl, rr, param_vals)
-      if (!is.null(self$measure$trafo)) {
+      if (!is.null(self$measure$trafo) && private$.requires_obs_loss) {
         ci = private$.trafo(ci)
       }
       if (param_vals$within_range) {
@@ -121,15 +138,15 @@ MeasureAbstractCi = R6Class("MeasureAbstractCi",
     }
   ),
   private = list(
+    .requires_obs_loss = NULL,
     .delta_method = FALSE,
     .trafo = function(ci) {
       if (!private$.delta_method) {
-        stopf("Measure '%s' has a trafo, but the CI does handle it", self$measure$id)
+        stopf("Measure '%s' has a trafo, but the CI does not handle it", self$measure$id)
       }
       measure = self$measure
       # delta-rule
       multiplier = measure$trafo$deriv(ci[[1]])
-      ci[[1]] = measure$trafo$fn(ci[[1]])
       halfwidth = (ci[[3]] - ci[[1]])
       est_t = measure$trafo$fn(ci[[1]])
       ci_t = c(est_t, est_t - halfwidth * multiplier, est_t + halfwidth * multiplier)
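For readers following the delta-rule lines in `.trafo`, here is a minimal standalone sketch of the same arithmetic in base R. The helper name `delta_ci` and the example numbers are hypothetical and not part of the package; the sketch assumes `ci` is ordered as estimate, lower, upper on the untransformed scale.

```r
# Hypothetical helper mirroring the delta-rule arithmetic shown in the hunk.
# `fn` is the transformation, `deriv` its first derivative.
delta_ci = function(ci, fn, deriv) {
  multiplier = deriv(ci[[1]])    # slope of fn at the untransformed estimate
  halfwidth = ci[[3]] - ci[[1]]  # half-width measured on the untransformed scale
  est_t = fn(ci[[1]])            # transformed point estimate
  ci_t = c(est_t, est_t - halfwidth * multiplier, est_t + halfwidth * multiplier)
  setNames(ci_t, names(ci))
}

# Made-up example: map an interval for a squared-error scale to its square root.
ci_mse = c(estimate = 4, lower = 3, upper = 5)
ci_rmse = delta_ci(ci_mse, fn = sqrt, deriv = function(x) 0.5 / sqrt(x))
print(ci_rmse)
```

The half-width has to be computed before the estimate is transformed; otherwise it would mix the two scales.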
File renamed without changes.
3 changes: 2 additions & 1 deletion R/MeasureCIConZ.R
@@ -3,6 +3,7 @@
 #' @description
 #' The conservative-z confidence intervals based on the [`ResamplingPairedSubsampling`].
 #' Because the variance estimate is obtained using only `n / 2` observations, it tends to be conservative.
+#' This inference method can also be applied to non-decomposable losses.
 #' @section Parameters:
 #' Only those from [`MeasureAbstractCi`].
 #' @template param_measure
@@ -22,6 +23,7 @@ MeasureCiConZ = R6Class("MeasureCiConZ",
         measure = measure,
         resamplings = "ResamplingPairedSubsampling",
         label = "Conservative-Z CI",
+        requires_obs_loss = FALSE,
         delta_method = TRUE
       )
     }
@@ -30,7 +32,6 @@ MeasureCiConZ = R6Class("MeasureCiConZ",
     .ci = function(tbl, rr, param_vals) {
       repeats_in = rr$resampling$param_set$values$repeats_in
       repeats_out = rr$resampling$param_set$values$repeats_out
-      tbl = tbl[, list(loss = mean(get("loss"))), by = "iteration"]

       estimate = tbl[get("iteration") <= repeats_in, mean(get("loss"))]

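The `tbl[, list(loss = mean(get("loss"))), by = "iteration"]` step is no longer needed once `tbl` comes from `rr$score()`, since that table already holds one value per iteration. As a rough illustration only, the base-R sketch below shows what that step computed for a decomposable loss; the data frame `obs` is made-up toy data, not package output.

```r
# Toy pointwise losses: 3 resampling iterations with 4 test observations each.
obs = data.frame(
  iteration = rep(1:3, each = 4),
  row_id = 1:12,
  loss = c(0, 1, 0, 1,  1, 1, 0, 1,  0, 0, 0, 1)
)

# Average the pointwise losses within each iteration, giving one row per iteration.
per_iter = aggregate(loss ~ iteration, data = obs, FUN = mean)
print(per_iter)

# The overall estimate is then the mean of the per-iteration values.
estimate = mean(per_iter$loss)
print(estimate)
```

For a non-decomposable measure there are no pointwise losses to average, which is why the per-iteration scores are taken directly.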
4 changes: 3 additions & 1 deletion R/MeasureCICorT.R
@@ -4,6 +4,7 @@
 #' Corrected-T confidence intervals based on [`ResamplingSubsampling`][mlr3::ResamplingSubsampling].
 #' A heuristic factor is applied to correct for the dependence between the iterations.
 #' The confidence intervals tend to be liberal.
+#' This inference method can also be applied to non-decomposable losses.
 #' @section Parameters:
 #' Only those from [`MeasureAbstractCi`].
 #' @template param_measure
@@ -29,6 +30,7 @@ MeasureCiCorrectedT = R6Class("MeasureCiCorrectedT",
         measure = measure,
         resamplings = "ResamplingSubsampling",
         label = "Corrected-T CI",
+        requires_obs_loss = FALSE,
         delta_method = TRUE
       )
     }
@@ -45,7 +47,7 @@ MeasureCiCorrectedT = R6Class("MeasureCiCorrectedT",
       n2 = n - n1

       # the different mu in the rows are the mu_j
-      mus = tbl[, list(estimate = mean(get("loss"))), by = "iteration"]$estimate
+      mus = tbl$loss
       # the global estimator
       estimate = mean(mus)
       # The naive SD estimate (does not take correlation between folds into account)
1 change: 1 addition & 0 deletions R/MeasureCIHoldout.R
@@ -2,6 +2,7 @@
 #' @name mlr_measures_ci_holdout
 #' @description
 #' Standard holdout CI.
+#' This inference method can only be applied to decomposable losses.
 #' @section Parameters:
 #' Only those from [`MeasureAbstractCi`].
 #' @template param_measure
15 changes: 8 additions & 7 deletions R/MeasureCINaiveCV.R → R/MeasureCIWaldCV.R
@@ -1,9 +1,10 @@
-#' @title Naive Cross-Validation CI
-#' @name mlr_measures_ci_naive_cv
+#' @title Cross-Validation CI
+#' @name mlr_measures_ci_wald_cv
 #' @description
 #' Confidence intervals for cross-validation.
 #' The method is asymptotically exact for the so called *Test Error* as defined by Bayle et al. (2020).
 #' For the (expected) risk, the confidence intervals tend to be too liberal.
+#' This inference method can only be applied to decomposable losses.
 #' @section Parameters:
 #' Those from [`MeasureAbstractCi`], as well as:
 #' * `variance` :: `"all-pairs"` or `"within-fold"`\cr
@@ -13,11 +14,11 @@
 #' `r format_bib("bayle2020cross")`
 #' @export
 #' @examples
-#' m_naivecv = msr("ci.naive_cv", "classif.ce")
-#' m_naivecv
+#' m_waldcv = msr("ci.wald_cv", "classif.ce")
+#' m_waldcv
 #' rr = resample(tsk("sonar"), lrn("classif.featureless"), rsmp("cv"))
-#' rr$aggregate(m_naivecv)
-MeasureCiNaiveCV = R6Class("MeasureCiNaiveCV",
+#' rr$aggregate(m_waldcv)
+MeasureCiWaldCV = R6Class("MeasureCiWaldCV",
   inherit = MeasureAbstractCi,
   public = list(
     #' @description
@@ -60,4 +61,4 @@ MeasureCiNaiveCV = R6Class("MeasureCiNaiveCV",
 )

 #' @include aaa.R
-measures[["ci.naive_cv"]] = list(MeasureCiNaiveCV, .prototype_args = list(measure = "classif.acc"))
+measures[["ci.wald_cv"]] = list(MeasureCiWaldCV, .prototype_args = list(measure = "classif.acc"))
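The body of `.ci` for `ci.wald_cv` is collapsed in this view, so the following is only a minimal sketch of a textbook Wald interval built from pooled pointwise losses, not the package's implementation. The function name `wald_ci` and its `alpha` argument are hypothetical, and the plain sample variance over all observations is an assumption about what a pooled variance estimate might look like.

```r
# Hypothetical helper: textbook Wald interval for the mean of pointwise losses.
wald_ci = function(loss, alpha = 0.05) {
  n = length(loss)
  estimate = mean(loss)
  se = sd(loss) / sqrt(n)    # standard error from the pooled sample variance
  z = qnorm(1 - alpha / 2)   # normal quantile for a two-sided interval
  c(estimate = estimate, lower = estimate - z * se, upper = estimate + z * se)
}

# Made-up 0/1 losses, standing in for the pointwise classification error.
set.seed(1)
loss = rbinom(200, size = 1, prob = 0.3)
ci = wald_ci(loss)
print(ci)
```

Because the interval is built from one loss value per observation, a sketch like this only makes sense for decomposable losses, which matches the note added to the documentation.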
1 change: 1 addition & 0 deletions R/MeasureCiNestedCV.R
@@ -2,6 +2,7 @@
 #' @name mlr_measures_ci_ncv
 #' @description
 #' Confidence Intervals based on [`ResamplingNestedCV`][ResamplingNestedCV], including bias-correction.
+#' This inference method can only be applied to decomposable losses.
 #' @section Parameters:
 #' Those from [`MeasureAbstractCi`], as well as:
 #' * `bias` :: `logical(1)`\cr
2 changes: 1 addition & 1 deletion R/zzz.R
@@ -23,7 +23,7 @@ register_mlr3 = function(...) {
   mlr_reflections = mlr3::mlr_reflections
   mlr_reflections$default_ci_methods = list(
     ResamplingHoldout = "ci.holdout",
-    ResamplingCV = "ci.naive_cv",
+    ResamplingCV = "ci.wald_cv",
     ResamplingSubsampling = "ci.cor_t",
     ResamplingPairedSubsampling = "ci.con_z",
     ResamplingNestedCV = "ci.ncv"
11 changes: 7 additions & 4 deletions README.Rmd
@@ -82,22 +82,25 @@ autoplot(bmr, "ci", msr("ci", "classif.ce"))

 Note that:

-* Confidence Intervals can only be obtained for measures that are based on pointwise loss functions, i.e. have an `$obs_loss` field.
+* Some methods require measures that are based on pointwise loss functions, i.e. have an `$obs_loss` field.
 * Not for every resampling method exists an inference method.
 * There are combinations of datasets and learners, where inference methods can fail.

 ## Features

 * Additional Resampling Methods
-* Confidence Intervals for the Generalization Error for some resampling methods
+* Confidence Intervals for the Generalization Error for some resampling methods


 ## Inference Methods

 ```{r, echo = FALSE}
 content = as.data.table(mlr3::mlr_measures, objects = TRUE)[startsWith(get("key"), "ci."),]
-content$resamplings = map(content$object, "resamplings")
-content = content[, c("key", "label", "resamplings")]
+content$resamplings = map(content$object, function(x) paste0(gsub("Resampling", "", x$resamplings), collapse = ", "))
+content[["only pointwise loss"]] = map_chr(content$object, function(object) {
+  if (get_private(object)$.requires_obs_loss) "yes" else "false"
+})
+content = content[, c("key", "label", "resamplings", "only pointwise loss")]
 knitr::kable(content, format = "markdown", col.names = tools::toTitleCase(names(content)))
 ```

28 changes: 19 additions & 9 deletions README.md
@@ -75,8 +75,8 @@ autoplot(bmr, "ci", msr("ci", "classif.ce"))

 Note that:

-- Confidence Intervals can only be obtained for measures that are based
-  on pointwise loss functions, i.e. have an `$obs_loss` field.
+- Some methods require measures that are based on pointwise loss
+  functions, i.e. have an `$obs_loss` field.
 - Not for every resampling method exists an inference method.
 - There are combinations of datasets and learners, where inference
   methods can fail.
@@ -89,13 +89,23 @@ Note that:

 ## Inference Methods

-| Key         | Label             | Resamplings                  |
-|:------------|:------------------|:-----------------------------|
-| ci.con_z    | Conservative-Z CI | ResamplingPairedSubsampling  |
-| ci.cor_t    | Corrected-T CI    | ResamplingSubsampling        |
-| ci.holdout  | Holdout CI        | ResamplingHoldout            |
-| ci.naive_cv | Naive CV CI       | ResamplingCV , ResamplingLOO |
-| ci.ncv      | Nested CV CI      | ResamplingNestedCV           |
+``` r
+content = as.data.table(mlr3::mlr_measures, objects = TRUE)[startsWith(get("key"), "ci."),]
+content$resamplings = map(content$object, function(x) paste0(gsub("Resampling", "", x$resamplings), collapse = ", "))
+content[["only pointwise loss"]] = map_chr(content$object, function(object) {
+  if (get_private(object)$.requires_obs_loss) "yes" else "false"
+})
+content = content[, c("key", "label", "resamplings", "only pointwise loss")]
+knitr::kable(content, format = "markdown", col.names = tools::toTitleCase(names(content)))
+```
+
+| Key         | Label             | Resamplings       | Only Pointwise Loss |
+|:------------|:------------------|:------------------|:--------------------|
+| ci.con_z    | Conservative-Z CI | PairedSubsampling | false               |
+| ci.cor_t    | Corrected-T CI    | Subsampling       | false               |
+| ci.holdout  | Holdout CI        | Holdout           | yes                 |
+| ci.wald_cv  | Naive CV CI       | CV, LOO           | yes                 |
+| ci.ncv      | Nested CV CI      | NestedCV          | yes                 |

 ## Bugs, Questions, Feedback

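To make the "Only Pointwise Loss" column concrete, the base-R sketch below contrasts a decomposable measure (classification error, a mean of per-observation losses) with a non-decomposable one (AUC, which depends on the ranking of the whole test set). The data and the helper `auc_rank` are made up for illustration and are not part of the package.

```r
# Made-up binary labels and predicted probabilities for the positive class.
truth = c(1, 0, 1, 1, 0, 0)
prob = c(0.9, 0.4, 0.35, 0.8, 0.1, 0.6)

# Decomposable: classification error is the mean of per-observation 0/1 losses.
pred = as.integer(prob > 0.5)
pointwise = as.integer(pred != truth)
ce = mean(pointwise)

# Non-decomposable: AUC via the rank-sum (Mann-Whitney) formula needs all
# observations at once; there is no per-observation loss to average.
auc_rank = function(truth, prob) {
  r = rank(prob)
  n_pos = sum(truth == 1)
  n_neg = sum(truth == 0)
  (sum(r[truth == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

print(ce)
print(auc_rank(truth, prob))
```

Methods marked "yes" in the table need a vector like `pointwise`; methods marked "false" only need one aggregated score per resampling iteration.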
9 changes: 7 additions & 2 deletions man/mlr_measures_abstract_ci.Rd
2 changes: 1 addition & 1 deletion man/mlr_measures_ci.Rd
1 change: 1 addition & 0 deletions man/mlr_measures_ci_con_z.Rd
1 change: 1 addition & 0 deletions man/mlr_measures_ci_cor_t.Rd
1 change: 1 addition & 0 deletions man/mlr_measures_ci_holdout.Rd
3 changes: 2 additions & 1 deletion man/mlr_measures_ci_ncv.Rd

Some generated files are not rendered by default.