TorchLearner: support / example for custom sampler #417

@tdhock

I would like to implement mini-batch stochastic gradient descent using a custom dataloader, sampler, or batch_sampler. I am not sure which of these is the best way to do it; please advise.

I see in the LearnerTorch docs that it is possible to override the private .dataloader method: https://github.com/mlr-org/mlr3torch/blob/main/R/LearnerTorch.R#L12

I also see in the torch docs that it should be possible to define a custom sampler (or batch_sampler?), but there are no examples; see mlverse/torch#1344.

I see that @sebffischer also had a related issue, mlverse/torch#808

I have imbalanced data, and I would like the mini-batches to be down-sampled according to the stratum col_role specified in the task.

A simple example with two classes: 10% of the labels are 0 and 90% are 1.

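library(data.table)
set.seed(1)
# 1000 labels: 100 zeros (10%) and 900 ones (90%).
dt <- data.table(y=c(rep(0, 100), rep(1, 900)))
# Minimal dataset that returns the label of row i.
ds_gen <- torch::dataset(
  initialize=function(){},
  .getitem=function(i)dt[i]$y,
  .length=function()nrow(dt))
ds <- ds_gen()
dl <- torch::dataloader(ds, batch_size=100, shuffle=TRUE)
# Print the class counts in each batch.
coro::loop(for(batch in dl){
  print(table(torch::as_array(batch)))
})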

Using the standard dataloader code with batch_size 100, I get 10 batches, each with roughly 10 y=0 and 90 y=1:

> coro::loop(for(batch in dl){
+ print(table(torch::as_array(batch)))
+ })

 0  1 
 7 93 

 0  1 
 8 92 

 0  1 
10 90 

 0  1 
 8 92 

 0  1 
11 89 

 0  1 
11 89 

 0  1 
12 88 

 0  1 
13 87 

 0  1 
16 84 

 0  1 
 4 96 
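The spread in these counts is about what shuffling alone should produce. As a rough check, assuming each batch behaves like a draw of 100 rows without replacement from the 1000 rows, the number of y=0 rows per batch is hypergeometric, and its mean and standard deviation can be computed directly:

```r
# Number of y=0 rows in a batch of n rows drawn without replacement
# from N rows, K of which have y=0 (hypergeometric distribution).
N <- 1000  # total rows
K <- 100   # rows with y=0
n <- 100   # batch size
mean_zeros <- n * K / N
sd_zeros <- sqrt(n * (K / N) * (1 - K / N) * (N - n) / (N - 1))
c(mean = mean_zeros, sd = sd_zeros)  # mean 10, sd about 2.85
```

So batches with as few as 4 or as many as 16 zeros are within roughly two standard deviations of the mean.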

My issue is that I would like these counts to always be exactly 10 and 90.
Some code that can do that is below:

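stratum <- "y"
min_samples_per_stratum <- 10
# Shuffle the rows, then number them within each stratum.
shuffle_dt <- dt[
, row.id := 1:.N
][sample(.N)][
, i.in.stratum := 1:.N, by=stratum
][]
# Size of the smallest stratum.
count_dt <- shuffle_dt[, .(max.i=max(i.in.stratum)), by=stratum][order(max.i)]
count_min <- count_dt$max.i[1]
# Rescale each stratum's positions to the smallest stratum's size,
# then cut into batches of min_samples_per_stratum on that scale.
shuffle_dt[
, n.samp := i.in.stratum/max(i.in.stratum)*count_min, by=stratum
][
, batch.i := ceiling(n.samp/min_samples_per_stratum)
]
# Class counts per batch.
dcast(shuffle_dt, batch.i ~ y, length)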

The output counts show the expected numbers:

> dcast(shuffle_dt, batch.i ~ y, length)
Key: <batch.i>
    batch.i     0     1
      <num> <int> <int>
 1:       1    10    90
 2:       2    10    90
 3:       3    10    90
 4:       4    10    90
 5:       5    10    90
 6:       6    10    90
 7:       7    10    90
 8:       8    10    90
 9:       9    10    90
10:      10    10    90
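Whichever of sampler or batch_sampler turns out to be the right hook, it will presumably need the row indices of each batch rather than a table of counts. Below is a minimal sketch of the same logic packaged as a function that returns one vector of row ids per batch. The function name stratified_batches and its arguments are hypothetical, not part of torch or mlr3torch.

```r
library(data.table)

# Hypothetical helper (not from torch or mlr3torch): assign each row to a
# batch so that every batch has the same class proportions as the full data.
stratified_batches <- function(strata, min_samples_per_stratum) {
  idx_dt <- data.table(stratum = strata, row.id = seq_along(strata))
  # Shuffle the rows, then number them within each stratum.
  idx_dt <- idx_dt[sample(.N)]
  idx_dt[, i.in.stratum := seq_len(.N), by = stratum]
  # Size of the smallest stratum.
  count_min <- idx_dt[, .N, by = stratum][, min(N)]
  # Rescale positions to the smallest stratum's size, then cut into batches.
  idx_dt[, n.samp := i.in.stratum / .N * count_min, by = stratum]
  idx_dt[, batch.i := ceiling(n.samp / min_samples_per_stratum)]
  # One integer vector of row ids per batch.
  split(idx_dt$row.id, idx_dt$batch.i)
}

set.seed(1)
y <- c(rep(0, 100), rep(1, 900))
batch_list <- stratified_batches(y, min_samples_per_stratum = 10)
length(batch_list)         # number of batches
table(y[batch_list[[1]]])  # class counts in the first batch
```

Each element of batch_list is a set of row ids for one mini-batch, and together the elements cover every row exactly once.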

My question is: how do I plug this code into the dataloader for use with my mlr3torch learner?
I see there are two related parameters in the LearnerTorch param_set (sampler and batch_sampler), and I suspect that setting one of them is the solution, but I do not see any examples of what their values should be. Can you please advise?

> mlr3torch::LearnerTorchMLP$new(task_type="regr")$param_set
<ParamSetCollection(37)>
                     id    class lower upper nlevels        default        value
                 <char>   <char> <num> <num>   <num>         <list>       <list>
 1:              epochs ParamInt 0e+00   Inf     Inf <NoDefault[0]>       [NULL]
 2:              device ParamFct    NA    NA      12 <NoDefault[0]>         auto
 3:         num_threads ParamInt 1e+00   Inf     Inf <NoDefault[0]>            1
 4: num_interop_threads ParamInt 1e+00   Inf     Inf <NoDefault[0]>            1
 5:                seed ParamInt  -Inf   Inf     Inf <NoDefault[0]>       random
 6:           eval_freq ParamInt 1e+00   Inf     Inf <NoDefault[0]>            1
 7:      measures_train ParamUty    NA    NA     Inf <NoDefault[0]>    <list[0]>
 8:      measures_valid ParamUty    NA    NA     Inf <NoDefault[0]>    <list[0]>
 9:            patience ParamInt 0e+00   Inf     Inf <NoDefault[0]>            0
10:           min_delta ParamDbl 0e+00   Inf     Inf <NoDefault[0]>            0
11:          batch_size ParamInt 1e+00   Inf     Inf <NoDefault[0]>       [NULL]
12:             shuffle ParamLgl    NA    NA       2          FALSE         TRUE
13:             sampler ParamUty    NA    NA     Inf <NoDefault[0]>       [NULL]
14:       batch_sampler ParamUty    NA    NA     Inf <NoDefault[0]>       [NULL]
15:         num_workers ParamInt 0e+00   Inf     Inf              0       [NULL]
16:          collate_fn ParamUty    NA    NA     Inf         [NULL]       [NULL]
17:          pin_memory ParamLgl    NA    NA       2          FALSE       [NULL]
18:           drop_last ParamLgl    NA    NA       2          FALSE       [NULL]
19:             timeout ParamDbl  -Inf   Inf     Inf             -1       [NULL]
20:      worker_init_fn ParamUty    NA    NA     Inf <NoDefault[0]>       [NULL]
21:      worker_globals ParamUty    NA    NA     Inf <NoDefault[0]>       [NULL]
22:     worker_packages ParamUty    NA    NA     Inf <NoDefault[0]>       [NULL]
23:      tensor_dataset ParamFct    NA    NA       1 <NoDefault[0]>        FALSE
24:           jit_trace ParamLgl    NA    NA       2 <NoDefault[0]>        FALSE
25:             neurons ParamUty    NA    NA     Inf <NoDefault[0]>             
26:                   p ParamDbl 0e+00 1e+00     Inf <NoDefault[0]>          0.5
27:            n_layers ParamInt 1e+00   Inf     Inf <NoDefault[0]>       [NULL]
28:          activation ParamUty    NA    NA     Inf <NoDefault[0]> <nn_relu[1]>
29:     activation_args ParamUty    NA    NA     Inf <NoDefault[0]>    <list[0]>
30:               shape ParamUty    NA    NA     Inf <NoDefault[0]>       [NULL]
31:              opt.lr ParamDbl 0e+00   Inf     Inf          0.001       [NULL]
32:           opt.betas ParamUty    NA    NA     Inf    0.900,0.999       [NULL]
33:             opt.eps ParamDbl 1e-16 1e-04     Inf          1e-08       [NULL]
34:    opt.weight_decay ParamDbl 0e+00 1e+00     Inf              0       [NULL]
35:         opt.amsgrad ParamLgl    NA    NA       2          FALSE       [NULL]
36:    opt.param_groups ParamUty    NA    NA     Inf <NoDefault[0]>       [NULL]
37:      loss.reduction ParamFct    NA    NA       2           mean       [NULL]
                     id    class lower upper nlevels        default        value
                 <char>   <char> <num> <num>   <num>         <list>       <list>
