Description
I would like to implement mini-batch stochastic gradient descent learning using a custom dataloader / sampler / batch_sampler (I am not sure which of these is the best way to do this; please advise).
I see in the LearnerTorch docs that it is possible to override the private .dataloader method: https://github.com/mlr-org/mlr3torch/blob/main/R/LearnerTorch.R#L12
I also see in the torch docs that it should be possible to define a custom sampler (or batch_sampler?), but there are no examples: mlverse/torch#1344
I see that @sebffischer also had a related issue: mlverse/torch#808
I have imbalanced data, and I would like mini-batches to down-sample according to the stratum col_role specified in the task.
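For reference, this is how I assign the stratum col_role in my task (simplified here to a binary classification task matching the example below; the use of the target itself as the stratum is just my setup):

library(mlr3)
# simplified task: binary target y with 10% 0 and 90% 1
task <- as_task_classif(
  data.frame(y=factor(c(rep(0, 100), rep(1, 900))), x=rnorm(1000)),
  target="y")
# use the target itself as the stratum
task$set_col_roles("y", add_to="stratum")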
Here is a simple example with two classes: 10% y=0, 90% y=1.
library(data.table)
set.seed(1)
# imbalanced data: 100 rows with y=0, 900 rows with y=1
dt <- data.table(y=c(rep(0, 100), rep(1, 900)))
# minimal torch dataset that returns the y value at row index i
ds_gen <- torch::dataset(
  initialize=function(){},
  .getitem=function(i)dt[i]$y,
  .length=function()nrow(dt))
ds <- ds_gen()
dl <- torch::dataloader(ds, batch_size=100, shuffle=TRUE)
# print the class counts within each batch
coro::loop(for(batch in dl){
  print(table(torch::as_array(batch)))
})
Using the standard dataloader code with batch_size=100, I get 10 batches, each with roughly 10 y=0 and 90 y=1:
> coro::loop(for(batch in dl){
+ print(table(torch::as_array(batch)))
+ })
0 1
7 93
0 1
8 92
0 1
10 90
0 1
8 92
0 1
11 89
0 1
11 89
0 1
12 88
0 1
13 87
0 1
16 84
0 1
4 96
My issue is that I would like these counts to always be exactly 10 and 90.
Some code that can do that is below.
stratum <- "y"
min_samples_per_stratum <- 10
# shuffle the rows, then number each row within its stratum
shuffle_dt <- dt[
  , row.id := 1:.N
][sample(.N)][
  , i.in.stratum := 1:.N, by=stratum
][]
# size of the smallest stratum (here: 100 rows with y=0)
count_dt <- shuffle_dt[, .(max.i=max(i.in.stratum)), by=stratum][order(max.i)]
count_min <- count_dt$max.i[1]
# rescale the within-stratum index to [0, count_min], then cut into batches,
# each taking min_samples_per_stratum rows from the smallest stratum and
# proportionally more from the larger strata
shuffle_dt[
  , n.samp := i.in.stratum/max(i.in.stratum)*count_min, by=stratum
][
  , batch.i := ceiling(n.samp/min_samples_per_stratum)
]
dcast(shuffle_dt, batch.i ~ y, length)
The output counts show the expected numbers:
> dcast(shuffle_dt, batch.i ~ y, length)
Key: <batch.i>
batch.i 0 1
<num> <int> <int>
1: 1 10 90
2: 2 10 90
3: 3 10 90
4: 4 10 90
5: 5 10 90
6: 6 10 90
7: 7 10 90
8: 8 10 90
9: 9 10 90
10: 10 10 90
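For reference, here is my attempt at wrapping this logic into a standalone torch batch sampler via torch::sampler(). This is only a sketch: since there are no examples, I am guessing from torch's built-in samplers that a custom sampler implements .iter() (returning a coro-style iterator, here over integer vectors of row indices, one vector per batch) and .length() (the number of batches per epoch); the name stratified_batch_sampler is mine.

library(data.table)
stratified_batch_sampler <- torch::sampler(
  "stratified_batch_sampler",
  initialize=function(stratum_vec, min_samples_per_stratum=10){
    self$stratum_vec <- stratum_vec
    self$min_samples_per_stratum <- min_samples_per_stratum
  },
  .iter=function(){
    # same logic as the data.table code above
    shuffle_dt <- data.table(
      stratum=self$stratum_vec,
      row.id=seq_along(self$stratum_vec)
    )[sample(.N)][
      , i.in.stratum := 1:.N, by=stratum]
    count_min <- shuffle_dt[, max(i.in.stratum), by=stratum][, min(V1)]
    shuffle_dt[
      , n.samp := i.in.stratum/max(i.in.stratum)*count_min, by=stratum
    ][
      , batch.i := ceiling(n.samp/self$min_samples_per_stratum)
    ]
    # one integer vector of row indices per batch
    coro::as_iterator(split(shuffle_dt$row.id, shuffle_dt$batch.i))
  },
  .length=function(){
    as.integer(ceiling(min(table(self$stratum_vec))/self$min_samples_per_stratum))
  }
)
# standalone usage with the dataset from above; if my guess about the
# protocol is right, every batch should have exactly 10 y=0 and 90 y=1
dl_strat <- torch::dataloader(ds, batch_sampler=stratified_batch_sampler(dt$y))
coro::loop(for(batch in dl_strat){
  print(table(torch::as_array(batch)))
})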
My question is: how do I plug this code into the dataloader used by my mlr3torch learner?
I see there are two related parameters in the LearnerTorch param_set (sampler and batch_sampler), and I get the feeling that setting one of these two is the solution, but I do not see any examples of what these should be. Can you please advise?
> mlr3torch::LearnerTorchMLP$new(task_type="regr")$param_set
<ParamSetCollection(37)>
id class lower upper nlevels default value
<char> <char> <num> <num> <num> <list> <list>
1: epochs ParamInt 0e+00 Inf Inf <NoDefault[0]> [NULL]
2: device ParamFct NA NA 12 <NoDefault[0]> auto
3: num_threads ParamInt 1e+00 Inf Inf <NoDefault[0]> 1
4: num_interop_threads ParamInt 1e+00 Inf Inf <NoDefault[0]> 1
5: seed ParamInt -Inf Inf Inf <NoDefault[0]> random
6: eval_freq ParamInt 1e+00 Inf Inf <NoDefault[0]> 1
7: measures_train ParamUty NA NA Inf <NoDefault[0]> <list[0]>
8: measures_valid ParamUty NA NA Inf <NoDefault[0]> <list[0]>
9: patience ParamInt 0e+00 Inf Inf <NoDefault[0]> 0
10: min_delta ParamDbl 0e+00 Inf Inf <NoDefault[0]> 0
11: batch_size ParamInt 1e+00 Inf Inf <NoDefault[0]> [NULL]
12: shuffle ParamLgl NA NA 2 FALSE TRUE
13: sampler ParamUty NA NA Inf <NoDefault[0]> [NULL]
14: batch_sampler ParamUty NA NA Inf <NoDefault[0]> [NULL]
15: num_workers ParamInt 0e+00 Inf Inf 0 [NULL]
16: collate_fn ParamUty NA NA Inf [NULL] [NULL]
17: pin_memory ParamLgl NA NA 2 FALSE [NULL]
18: drop_last ParamLgl NA NA 2 FALSE [NULL]
19: timeout ParamDbl -Inf Inf Inf -1 [NULL]
20: worker_init_fn ParamUty NA NA Inf <NoDefault[0]> [NULL]
21: worker_globals ParamUty NA NA Inf <NoDefault[0]> [NULL]
22: worker_packages ParamUty NA NA Inf <NoDefault[0]> [NULL]
23: tensor_dataset ParamFct NA NA 1 <NoDefault[0]> FALSE
24: jit_trace ParamLgl NA NA 2 <NoDefault[0]> FALSE
25: neurons ParamUty NA NA Inf <NoDefault[0]>
26: p ParamDbl 0e+00 1e+00 Inf <NoDefault[0]> 0.5
27: n_layers ParamInt 1e+00 Inf Inf <NoDefault[0]> [NULL]
28: activation ParamUty NA NA Inf <NoDefault[0]> <nn_relu[1]>
29: activation_args ParamUty NA NA Inf <NoDefault[0]> <list[0]>
30: shape ParamUty NA NA Inf <NoDefault[0]> [NULL]
31: opt.lr ParamDbl 0e+00 Inf Inf 0.001 [NULL]
32: opt.betas ParamUty NA NA Inf 0.900,0.999 [NULL]
33: opt.eps ParamDbl 1e-16 1e-04 Inf 1e-08 [NULL]
34: opt.weight_decay ParamDbl 0e+00 1e+00 Inf 0 [NULL]
35: opt.amsgrad ParamLgl NA NA 2 FALSE [NULL]
36: opt.param_groups ParamUty NA NA Inf <NoDefault[0]> [NULL]
37: loss.reduction ParamFct NA NA 2 mean [NULL]
id class lower upper nlevels default value
<char> <char> <num> <num> <num> <list> <list>
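In case it clarifies what I am after, this is the kind of call I imagined (an untested guess on my part; I do not know whether batch_sampler expects the sampler generator or an instance, given that mlr3torch constructs the dataset internally, nor how it interacts with batch_size, which I believe torch::dataloader() treats as mutually exclusive with batch_sampler):

learner <- mlr3torch::LearnerTorchMLP$new(task_type="regr")
# untested guess: pass the sampler generator defined above via the param_set
learner$param_set$set_values(
  epochs=10,
  batch_sampler=stratified_batch_sampler
)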