mlr-org/mlr3torch

mlr3torch

Lifecycle: experimental

The goal of {mlr3torch} is to connect {mlr3} with {torch}.

It is in the very early stages of development, and its future and scope are yet to be determined.

Installation

remotes::install_github("mlr-org/mlr3torch")

tabnet Example

Using the {tabnet} learner for classification:

library(mlr3)
library(mlr3viz)
library(mlr3torch)

task = tsk("german_credit")

# Set up the learner
lrn_tabnet = lrn("classif.tabnet", epochs = 5)

# Train on the first 900 rows, then predict on the remaining 100
lrn_tabnet$train(task, row_ids = 1:900)

preds = lrn_tabnet$predict(task, row_ids = 901:1000)

# Investigate predictions
preds$confusion
preds$score(msr("classif.acc"))

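# Predict probabilities instead, again on the held-out rows only,
# so the ROC curve is not computed on training data
lrn_tabnet$predict_type = "prob"
preds_prob = lrn_tabnet$predict(task, row_ids = 901:1000)

# ROC curve (this plot type requires the precrec package)
autoplot(preds_prob, type = "roc")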

# Examine variable importance scores
lrn_tabnet$importance()
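The example above scores the learner on a single fixed split (the first 900 rows for training, the last 100 for testing). Since `lrn("classif.tabnet")` is retrieved from the mlr3 learner dictionary, it should also work with mlr3's resampling interface, which gives a less split-dependent performance estimate. The following is a minimal sketch under that assumption; the choice of 3-fold cross-validation and the seed value are arbitrary and not taken from this README.

```r
library(mlr3)
library(mlr3torch)

# Fix the R seed so the assignment of rows to folds is reproducible
# (this does not necessarily make the network training itself deterministic)
set.seed(1)

task = tsk("german_credit")
lrn_tabnet = lrn("classif.tabnet", epochs = 5)

# 3-fold cross-validation: every row is used for testing exactly once
resampling = rsmp("cv", folds = 3)
rr = resample(task, lrn_tabnet, resampling)

# Accuracy for each fold, then the aggregated (mean) accuracy
rr$score(msr("classif.acc"))
rr$aggregate(msr("classif.acc"))
```

Cross-validation trains the network once per fold, so this takes roughly three times as long as the single train/predict call above.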

Using TorchOps

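library(mlr3pipelines) # provides the %>>% operator

task = tsk("iris")

# Define the network as a sequence of TorchOps; the last TorchOp
# sets the training configuration (epochs, batch size, loss, optimizer)
graph = top("input") %>>%
  top("tokenizer_tabular", d_token = 1) %>>%
  top("flatten") %>>%
  top("relu_1") %>>%
  top("linear_1", out_features = 10) %>>%
  top("relu_2") %>>%
  top("output") %>>%
  top("model.classif", epochs = 10L, batch_size = 16L, .loss = "cross_entropy", .optimizer = "adam")

# Convert the graph into a learner and train it
glrn = as_learner_torch(graph)
glrn$train(task)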

Credit

Some parts of the implementation are inspired by other deep learning libraries:

  • Keras - Building networks using TorchOps feels similar to using Keras.
  • Luz - Our implementation of callbacks is inspired by the R package luz.

About

Deep learning framework for the mlr3 ecosystem based on torch