The goal of this vignette is to explain how to quantify the extent to which it is possible to train on one data subset, and predict on another data subset. This kind of problem occurs frequently in many different problem domains.

The ideas are similar to my previous blog posts about how to do this in Python and R. Below we explain how to use mlr3resampling for this purpose, in simulated regression and classification problems. To use this method on real data, the important sections to read below are named “Benchmark: computing test error,” which show how to create these cross-validation experiments using mlr3 code.
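
For readers who want to apply the method to their own data right away, the sketch below summarizes the essential steps. The names my.dt (a data table), y (the target column), and group (the column identifying the subsets) are hypothetical placeholders; each step is explained in detail in the sections below.

## Minimal sketch: same/other cross-validation on your own data.
## my.dt, y, and group are hypothetical names, to be replaced.
task <- mlr3::TaskRegr$new("my_task", my.dt, target="y")
task$col_roles$subset <- "group"   # column defining the subsets.
task$col_roles$stratum <- "group"  # stratify folds within each subset.
task$col_roles$feature <- setdiff( # exclude group from the features.
  names(my.dt), c("y", "group"))
same_other <- mlr3resampling::ResamplingSameOtherCV$new()
bench.grid <- mlr3::benchmark_grid(
  task, mlr3::LearnerRegrRpart$new(), same_other)
bench.result <- mlr3::benchmark(bench.grid, store_models=TRUE)
score.dt <- mlr3resampling::score(bench.result)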

Simulated regression problems

We begin by generating some data which can be used with regression algorithms. Assume there is a data set with some rows from one person and some rows from another.

N <- 300
library(data.table)
set.seed(1)
abs.x <- 2
reg.dt <- data.table(
  x=runif(N, -abs.x, abs.x),
  person=rep(1:2, each=0.5*N))
reg.pattern.list <- list(
  easy=function(x, person)x^2,
  impossible=function(x, person)(x^2+person*3)*(-1)^person)
reg.task.list <- list()
for(task_id in names(reg.pattern.list)){
  f <- reg.pattern.list[[task_id]]
  yname <- paste0("y_",task_id)
  reg.dt[, (yname) := f(x,person)+rnorm(N)][]
  task.dt <- reg.dt[, c("x","person",yname), with=FALSE]
  reg.task <- mlr3::TaskRegr$new(
    task_id, task.dt, target=yname)
  reg.task$col_roles$subset <- "person"  # subsets to train/test across.
  reg.task$col_roles$stratum <- "person" # stratify folds within each person.
  reg.task$col_roles$feature <- "x"      # person is not used as a feature.
  reg.task.list[[task_id]] <- reg.task
}
reg.dt
#>               x person      y_easy y_impossible
#>           <num>  <int>       <num>        <num>
#>   1: -0.9379653      1  1.32996609    -2.918082
#>   2: -0.5115044      1  0.24307692    -3.866062
#>   3:  0.2914135      1 -0.23314657    -3.837799
#>   4:  1.6328312      1  1.73677545    -7.221749
#>   5: -1.1932723      1 -0.06356159    -5.877792
#>  ---                                           
#> 296:  0.7257701      2 -2.48130642     5.180948
#> 297: -1.6033236      2  1.20453459     9.604312
#> 298: -1.5243898      2  1.89966190     7.511988
#> 299: -1.7982414      2  3.47047566    11.035397
#> 300:  1.7170157      2  0.60541972    10.719685

The table above shows some simulated data for two regression problems:

  • the easy pattern, y_easy = x^2 + noise, which is the same function for both people.
  • the impossible pattern, y_impossible = (x^2 + 3*person)*(-1)^person + noise, which is a different function for each person.

Static visualization of simulated data

First we reshape the data using the code below,

(reg.tall <- nc::capture_melt_single(
  reg.dt,
  task_id="easy|impossible",
  value.name="y"))
#>               x person    task_id           y
#>           <num>  <int>     <char>       <num>
#>   1: -0.9379653      1       easy  1.32996609
#>   2: -0.5115044      1       easy  0.24307692
#>   3:  0.2914135      1       easy -0.23314657
#>   4:  1.6328312      1       easy  1.73677545
#>   5: -1.1932723      1       easy -0.06356159
#>  ---                                         
#> 596:  0.7257701      2 impossible  5.18094849
#> 597: -1.6033236      2 impossible  9.60431191
#> 598: -1.5243898      2 impossible  7.51198770
#> 599: -1.7982414      2 impossible 11.03539747
#> 600:  1.7170157      2 impossible 10.71968480

The table above is in a more convenient form for the visualization, which we create using the code below,

if(require(animint2)){
  ggplot()+
    geom_point(aes(
      x, y),
      data=reg.tall)+
    facet_grid(
      task_id ~ person,
      labeller=label_both,
      space="free",
      scales="free")+
    scale_y_continuous(
      breaks=seq(-100, 100, by=2))
}
#> Loading required package: animint2

In the simulated data above, we can see that

  • for the easy pattern, the function is the same for both people, so it should be possible/easy to train on one person, and accurately predict on the other.
  • for the impossible pattern, the function is different for each person, so it should not be possible to train on one person, and accurately predict on the other, as the quick numeric check below confirms.
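
The claim above can be verified numerically; the quick check below (using the reg.dt table created above) computes the mean of each simulated output, by person. The means of y_easy should be similar for the two people, whereas the means of y_impossible should be very different.

## group means by person: similar for y_easy (same function for both
## people), very different for y_impossible (different functions).
reg.dt[, .(
  mean_easy=mean(y_easy),
  mean_impossible=mean(y_impossible)
), by=person]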

Benchmark: computing test error

In the code below, we define the resampling method, same/other K-fold cross-validation (with the default of 3 folds).

(reg_same_other <- mlr3resampling::ResamplingSameOtherCV$new())
#> <ResamplingSameOtherCV> : Same versus Other Cross-Validation
#> * Iterations:
#> * Instantiated: FALSE
#> * Parameters:
#> List of 1
#>  $ folds: int 3
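
The output above shows the default of 3 folds. To use a different number of folds, the folds parameter can be changed before the resampling is used; the sketch below assumes the standard mlr3 param_set interface, and is not run, since the rest of this vignette keeps the default.

if(FALSE){# not run: this vignette keeps the default of 3 folds.
  reg_same_other$param_set$values$folds <- 5
}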

In the code below, we define two learners to compare: a regression tree (rpart), and a featureless baseline, which ignores the features and always predicts the mean of the train labels,

(reg.learner.list <- list(
  if(requireNamespace("rpart"))mlr3::LearnerRegrRpart$new(),
  mlr3::LearnerRegrFeatureless$new()))
#> Loading required namespace: rpart
#> [[1]]
#> <LearnerRegrRpart:regr.rpart>: Regression Tree
#> * Model: -
#> * Parameters: xval=0
#> * Packages: mlr3, rpart
#> * Predict Types:  [response]
#> * Feature Types: logical, integer, numeric, factor, ordered
#> * Properties: importance, missings, selected_features, weights
#> 
#> [[2]]
#> <LearnerRegrFeatureless:regr.featureless>: Featureless Regression Learner
#> * Model: -
#> * Parameters: robust=FALSE
#> * Packages: mlr3, stats
#> * Predict Types:  [response], se
#> * Feature Types: logical, integer, numeric, character, factor, ordered,
#>   POSIXct
#> * Properties: featureless, importance, missings, selected_features

In the code below, we define the benchmark grid, which is all combinations of tasks (easy and impossible), learners (rpart and featureless), and the one resampling method.

(reg.bench.grid <- mlr3::benchmark_grid(
  reg.task.list,
  reg.learner.list,
  reg_same_other))
#>          task          learner    resampling
#>        <char>           <char>        <char>
#> 1:       easy       regr.rpart same_other_cv
#> 2:       easy regr.featureless same_other_cv
#> 3: impossible       regr.rpart same_other_cv
#> 4: impossible regr.featureless same_other_cv

In the code below, we execute the benchmark experiment (optionally in parallel using the multisession future plan; the parallel code is disabled here, for CRAN).

if(FALSE){#for CRAN.
  if(require(future))plan("multisession")
}
if(require(lgr))get_logger("mlr3")$set_threshold("warn")
#> Loading required package: lgr
(reg.bench.result <- mlr3::benchmark(
  reg.bench.grid, store_models = TRUE))
#> <BenchmarkResult> of 72 rows with 4 resampling runs
#>  nr    task_id       learner_id resampling_id iters warnings errors
#>   1       easy       regr.rpart same_other_cv    18        0      0
#>   2       easy regr.featureless same_other_cv    18        0      0
#>   3 impossible       regr.rpart same_other_cv    18        0      0
#>   4 impossible regr.featureless same_other_cv    18        0      0

The code below computes the test error for each split. There are 18 splits per benchmark run: 3 test folds × 2 test subsets × 3 train subset types (same, other, all). Each row of the score table describes one split, via the columns train.subsets, test.fold, and test.subset.

reg.bench.score <- mlr3resampling::score(reg.bench.result)
reg.bench.score[1]
#>    train.subsets test.fold test.subset person iteration                  test
#>           <char>     <int>       <int>  <int>     <int>                <list>
#> 1:           all         1           1      1         1  1, 3, 5, 6,12,13,...
#>                    train                                uhash    nr
#>                   <list>                               <char> <int>
#> 1:  4, 7, 9,10,18,20,... 9d0598d4-4e81-4885-9be4-c6e919c8602e     1
#>               task task_id                       learner learner_id
#>             <list>  <char>                        <list>     <char>
#> 1: <TaskRegr:easy>    easy <LearnerRegrRpart:regr.rpart> regr.rpart
#>                 resampling resampling_id       prediction regr.mse algorithm
#>                     <list>        <char>           <list>    <num>    <char>
#> 1: <ResamplingSameOtherCV> same_other_cv <PredictionRegr> 1.638015     rpart

The code below visualizes the resulting test error numbers.

if(require(animint2)){
  ggplot()+
    scale_x_log10()+
    geom_point(aes(
      regr.mse, train.subsets, color=algorithm),
      shape=1,
      data=reg.bench.score)+
    facet_grid(
      task_id ~ person,
      labeller=label_both,
      scales="free")
}

It is clear from the plot above (and from the numeric summary after this list) that

  • for the easy task, training on same is just as good as all or other subsets. rpart has much lower test error than featureless, in all three train subsets.
  • for the impossible task, the least test error is using rpart with same train subsets; featureless with same train subsets is next best; training on all is substantially worse (for both featureless and rpart); training on other is even worse (patterns in the two people are completely different).
  • in a real data task, training on other will most likely not be quite as bad as in the impossible task above, but also not as good as in the easy task.
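
The numeric summary below supports the same conclusions; it is a sketch which aggregates the score table computed above, giving the mean test error over all splits for each combination of task, algorithm, and train subset type.

## mean test MSE over all splits, by task, algorithm,
## and train subset type (same/other/all).
reg.bench.score[, .(
  mean.mse=mean(regr.mse)
), by=.(task_id, algorithm, train.subsets)]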

Interactive visualization of data, test error, and splits

The code below can be used to create an interactive data visualization which allows exploring how different functions are learned during different splits.

inst <- reg.bench.score$resampling[[1]]$instance
rect.expand <- 0.2
## regular grid of x values on which to show model predictions.
grid.dt <- data.table(x=seq(-abs.x, abs.x, length.out=101), y=0)
grid.task <- mlr3::TaskRegr$new("grid", grid.dt, target="y")
pred.dt.list <- list()
point.dt.list <- list()
for(score.i in 1:nrow(reg.bench.score)){
  reg.bench.row <- reg.bench.score[score.i]
  ## join the task data with the fold/subset assignment of each row.
  task.dt <- data.table(
    reg.bench.row$task[[1]]$data(),
    reg.bench.row$resampling[[1]]$instance$id.dt)
  names(task.dt)[1] <- "y" # target column comes first in $data().
  ## one row per (set.name, row_id): the train/test rows of this split.
  set.ids <- data.table(
    set.name=c("test","train")
  )[
  , data.table(row_id=reg.bench.row[[set.name]][[1]])
  , by=set.name]
  ## rows that are neither train nor test in this split are unused.
  i.points <- set.ids[
    task.dt, on="row_id"
  ][
    is.na(set.name), set.name := "unused"
  ]
  point.dt.list[[score.i]] <- data.table(
    reg.bench.row[, .(task_id, iteration)],
    i.points)
  i.learner <- reg.bench.row$learner[[1]]
  ## predict the learned function on the grid of x values, for plotting.
  pred.dt.list[[score.i]] <- data.table(
    reg.bench.row[, .(
      task_id, iteration, algorithm
    )],
    as.data.table(
      i.learner$predict(grid.task)
    )[, .(x=grid.dt$x, y=response)]
  )
}
(pred.dt <- rbindlist(pred.dt.list))
#>          task_id iteration   algorithm     x        y
#>           <char>     <int>      <char> <num>    <num>
#>    1:       easy         1       rpart -2.00 3.557968
#>    2:       easy         1       rpart -1.96 3.557968
#>    3:       easy         1       rpart -1.92 3.557968
#>    4:       easy         1       rpart -1.88 3.557968
#>    5:       easy         1       rpart -1.84 3.557968
#>   ---                                                
#> 7268: impossible        18 featureless  1.84 7.204232
#> 7269: impossible        18 featureless  1.88 7.204232
#> 7270: impossible        18 featureless  1.92 7.204232
#> 7271: impossible        18 featureless  1.96 7.204232
#> 7272: impossible        18 featureless  2.00 7.204232
(point.dt <- rbindlist(point.dt.list))
#>           task_id iteration set.name row_id           y          x  fold person
#>            <char>     <int>   <char>  <int>       <num>      <num> <int>  <int>
#>     1:       easy         1     test      1  1.32996609 -0.9379653     1      1
#>     2:       easy         1    train      2  0.24307692 -0.5115044     3      1
#>     3:       easy         1     test      3 -0.23314657  0.2914135     1      1
#>     4:       easy         1    train      4  1.73677545  1.6328312     2      1
#>     5:       easy         1     test      5 -0.06356159 -1.1932723     1      1
#>    ---                                                                         
#> 21596: impossible        18    train    296  5.18094849  0.7257701     1      2
#> 21597: impossible        18    train    297  9.60431191 -1.6033236     1      2
#> 21598: impossible        18     test    298  7.51198770 -1.5243898     3      2
#> 21599: impossible        18    train    299 11.03539747 -1.7982414     1      2
#> 21600: impossible        18     test    300 10.71968480  1.7170157     3      2
#>        subset display_row
#>         <int>       <int>
#>     1:      1           1
#>     2:      1         101
#>     3:      1           2
#>     4:      1          51
#>     5:      1           3
#>    ---                   
#> 21596:      2         198
#> 21597:      2         199
#> 21598:      2         299
#> 21599:      2         200
#> 21600:      2         300
set.colors <- c(
  train="#1B9E77",
  test="#D95F02",
  unused="white")
algo.colors <- c(
  featureless="blue",
  rpart="red")
make_person_subset <- function(DT){
  ## copy person to a column whose name makes a more descriptive facet title.
  DT[, "person/subset" := person]
}
make_person_subset(point.dt)
make_person_subset(reg.bench.score)

if(require(animint2)){
  viz <- animint(
    title="Train/predict on subsets, regression",
    pred=ggplot()+
      ggtitle("Predictions for selected train/test split")+
      theme_animint(height=400)+
      scale_fill_manual(values=set.colors)+
      geom_point(aes(
        x, y, fill=set.name),
        showSelected="iteration",
        size=3,
        shape=21,
        data=point.dt)+
      scale_color_manual(values=algo.colors)+
      geom_line(aes(
        x, y, color=algorithm, subset=paste(algorithm, iteration)),
        showSelected="iteration",
        data=pred.dt)+
      facet_grid(
        task_id ~ `person/subset`,
        labeller=label_both,
        space="free",
        scales="free")+
      scale_y_continuous(
        breaks=seq(-100, 100, by=2)),
    err=ggplot()+
      ggtitle("Test error for each split")+
      theme_animint(height=400)+
      scale_y_log10(
        "Mean squared error on test set")+
      scale_fill_manual(values=algo.colors)+
      scale_x_discrete(
        "People/subsets in train set")+
      geom_point(aes(
        train.subsets, regr.mse, fill=algorithm),
        shape=1,
        size=5,
        stroke=2,
        color="black",
        color_off=NA,
        clickSelects="iteration",
        data=reg.bench.score)+
      facet_grid(
        task_id ~ `person/subset`,
        labeller=label_both,
        scales="free"),
    diagram=ggplot()+
      ggtitle("Select train/test split")+
      theme_bw()+
      theme_animint(height=300)+
      facet_grid(
        . ~ train.subsets,
        scales="free",
        space="free")+
      scale_size_manual(values=c(subset=3, fold=1))+
      scale_color_manual(values=c(subset="orange", fold="grey50"))+
      geom_rect(aes(
        xmin=-Inf, xmax=Inf,
        color=rows,
        size=rows,
        ymin=display_row, ymax=display_end),
        fill=NA,
        data=inst$viz.rect.dt)+
      scale_fill_manual(values=set.colors)+
      geom_rect(aes(
        xmin=iteration-rect.expand, ymin=display_row,
        xmax=iteration+rect.expand, ymax=display_end,
        fill=set.name),
        clickSelects="iteration",
        data=inst$viz.set.dt)+
      geom_text(aes(
        ifelse(rows=="subset", Inf, -Inf),
        (display_row+display_end)/2,
        hjust=ifelse(rows=="subset", 1, 0),
        label=paste0(rows, "=", ifelse(rows=="subset", subset, fold))),
        data=data.table(train.name="same", inst$viz.rect.dt))+
      scale_x_continuous(
        "Split number / cross-validation iteration")+
      scale_y_continuous(
        "Row number"),
    source="https://github.com/tdhock/mlr3resampling/blob/main/vignettes/ResamplingSameOtherCV.Rmd")
  viz
}