The goal of {mlr3torch} is to connect {mlr3} with {torch}.
It is in a very early stage of development, and its future and scope are yet to be determined.
The development version can be installed from GitHub:

```r
# install.packages("remotes")
remotes::install_github("mlr-org/mlr3torch")
```
Using the {tabnet} learner for classification:
```r
library(mlr3)
library(mlr3viz)
library(mlr3torch)

task = tsk("german_credit")

# Set up the learner
lrn_tabnet = lrn("classif.tabnet", epochs = 5)

# Train on the first 900 rows and predict the remaining 100
lrn_tabnet$train(task, row_ids = 1:900)
preds = lrn_tabnet$predict(task, row_ids = 901:1000)

# Inspect the predictions: confusion matrix and accuracy
preds$confusion
preds$score(msr("classif.acc"))

# Predict class probabilities instead, again only on the held-out rows
lrn_tabnet$predict_type = "prob"
preds_prob = lrn_tabnet$predict(task, row_ids = 901:1000)
# ROC curves in mlr3viz require the {precrec} package
autoplot(preds_prob, type = "roc")

# Examine variable importance scores
lrn_tabnet$importance()
```
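As a sketch of the same train/predict/score workflow with a random rather than a fixed row split, the example below uses mlr3's `partition()`. The `classif.rpart` learner is swapped in only so that the snippet runs without {torch}; assuming `classif.tabnet` follows the standard mlr3 learner interface, it should slot in the same way. The sketch also checks that `classif.acc` is simply the share of observations on the diagonal of the confusion matrix.

```r
library(mlr3)

set.seed(1)
task = tsk("german_credit")

# Random 90/10 holdout split instead of fixed row ranges
split = partition(task, ratio = 0.9)

# Stand-in learner (rpart) so the sketch runs without {torch}
learner = lrn("classif.rpart")
learner$train(task, row_ids = split$train)
preds = learner$predict(task, row_ids = split$test)

# Accuracy via the mlr3 measure ...
acc = preds$score(msr("classif.acc"))

# ... equals the correctly classified share on the confusion matrix diagonal
conf = preds$confusion
acc_manual = sum(diag(conf)) / sum(conf)
all.equal(unname(acc), acc_manual)
```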
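The network can also be assembled layer by layer from `top()` building blocks chained with `%>>%`, and the resulting graph converted into a learner with `as_learner_torch()`:

```r
library(mlr3pipelines) # provides the %>>% operator

task = tsk("iris")

graph = top("input") %>>%
  top("tokenizer_tabular", d_token = 1) %>>% # turn each feature into a token of size d_token
  top("flatten") %>>%
  top("relu_1") %>>%
  top("linear_1", out_features = 10) %>>%
  top("relu_2") %>>%
  top("output") %>>%
  # training settings: epochs, batch size, loss and optimizer
  top("model.classif", epochs = 10L, batch_size = 16L, .loss = "cross_entropy", .optimizer = "adam")

glrn = as_learner_torch(graph)
glrn$train(task)
```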
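To make the shapes in this graph concrete, here is a hypothetical base-R forward pass for one batch of 16 iris rows with random weights. It is not how {mlr3torch} implements these layers: it assumes that `tokenizer_tabular` with `d_token = 1` maps each numeric feature x_j to x_j * w_j + b_j, and that the output layer is linear with one unit per class. All weight names (`w_tok`, `W1`, `W2`, ...) are illustrative only.

```r
# Hypothetical illustration, NOT mlr3torch code: tensor shapes through the graph
set.seed(42)
x = as.matrix(iris[1:16, 1:4])            # one batch: 16 rows x 4 features
n_features = ncol(x)
n_classes = nlevels(iris$Species)         # 3 classes
relu = function(z) pmax(z, 0)             # pmax keeps the matrix dimensions

# tokenizer_tabular (assumed): one scalar token x_j * w_j + b_j per feature
w_tok = rnorm(n_features)
b_tok = rnorm(n_features)
tokens = sweep(sweep(x, 2, w_tok, `*`), 2, b_tok, `+`)

# flatten: with d_token = 1 the tokens already form a 16 x 4 matrix
h = relu(tokens)                          # relu_1

# linear_1 with 10 output features, followed by relu_2
W1 = matrix(rnorm(n_features * 10), n_features, 10)
b1 = rnorm(10)
h = relu(sweep(h %*% W1, 2, b1, `+`))     # 16 x 10

# output (assumed): linear layer with one score (logit) per class
W2 = matrix(rnorm(10 * n_classes), 10, n_classes)
b2 = rnorm(n_classes)
logits = sweep(h %*% W2, 2, b2, `+`)
dim(logits)                               # 16 rows, 3 class scores
```

During training, the `cross_entropy` loss would compare such per-class scores with the true species labels.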
Some parts of the implementation are inspired by other deep learning libraries: