---
output:
pdf_document: default
html_document: default
---
# Super Learning {#sl3}
_Rachael Phillips_
Based on the [`sl3` `R` package](https://github.com/tlverse/sl3) by _Jeremy
Coyle, Nima Hejazi, Ivana Malenica, Rachael Phillips, and Oleg Sofrygin_.
### Learning Objectives {-}
By the end of this chapter you will be able to:
1. Select a performance metric that is optimized by the true prediction
function, or define the true prediction function of interest as the
optimizer of the performance metric.
2. Assemble a diverse set ("library") of learners to be considered in the super
learner. In particular, you should be able to:
a. Customize a learner by modifying its tuning parameters.
b. Create variations of the same base learner with different tuning
parameter specifications.
c. Couple screener(s) with learner(s) to create learners that consider only
a reduced, screener-selected subset of the covariates.
3. Specify a meta-learner that optimizes the objective function of interest.
4. Justify the library and the meta-learner in terms of the prediction problem
at hand, intended use of the analysis in the real world, statistical model,
sample size, number of covariates, and outcome prevalence for discrete
outcomes.
5. Interpret the fit for a super learner from the table of cross-validated risk
estimates and the super learner coefficients.
## Introduction
A common task in data analysis is prediction, or using the observed data to
learn a function that takes as input data on covariates/predictors and outputs a
predicted value. Occasionally, the scientific question of interest lends itself
to causal effect estimation. Even in these scenarios, where prediction is not in
the forefront, prediction tasks are embedded in the procedure. For instance, in
targeted minimum loss-based estimation (TMLE), predictive modeling is necessary
for estimating outcome regressions and propensity scores.
There are various strategies that can be employed to model relationships from
data, which we refer to interchangeably as "estimators", "algorithms", and
"learners". Some data can only be modeled adequately by algorithms that pick up
on complex relationships, whereas other data might be fit reasonably well by
parametric regression learners. It is generally
impossible to know in advance which approach will be the best for a given data
set and prediction problem.
The Super Learner (SL) solves the issue of selecting an algorithm, as it can
consider many of them, from the simplest parametric regressions to the most
complex machine learning algorithms (e.g., neural nets, support vector machines,
etc). Additionally, it is proven to perform as well as possible in large
samples, given the learners specified [@vdl2007super]. The SL represents an
entirely pre-specified, flexible, and theoretically grounded approach for
predictive modeling. It has been shown to be adaptive and robust in a variety of
applications, and even in very small samples. Detailed descriptions outlining
the SL procedure are widely available [@polley2010super; @naimi2018stacked].
Practical considerations for specifying the SL, including how to specify a rich
and diverse library of learners, choose a performance metric for the SL, and
specify a cross-validation (CV) scheme, are described in a pre-print article
[@rvp2022super]. Here, we focus on introducing `sl3`, the standard `tlverse`
software package for SL.
<!--
Add more about sl3: how it's different from SuperLearner, supported features,
GitHub, filing issues, support, documentation, etc.
-->
## Super Learning with `sl3`: {-}
## How to Fit the Super Learner
In this section, the core functionality for fitting any SL with `sl3` is
illustrated. In the sections that follow, additional `sl3` functionality is
presented.
Fitting any SL with `sl3` consists of the following three steps:
1. Define the prediction task with `make_sl3_Task`.
2. Instantiate the SL with `Lrnr_sl`.
3. Fit the SL to the task with `train`.
#### Running example with WASH Benefits dataset {-}
We will use the WASH Benefits Bangladesh study as an example to guide this
overview of `sl3`. In this study, suppose we are interested in predicting the
child development outcome, weight-for-height z-score, from covariates/predictors,
including socio-economic status variables, gestational age, and maternal
features. More information on this dataset is described in the ["Meet
the Data"](https://tlverse.org/tlverse-handbook/data.html#wash) chapter of the
`tlverse` handbook.
#### Preliminaries {-}
First, we need to load the data and relevant packages into the R session.
##### Load the data {-}
```{r setup-handbook-utils-noecho, echo = FALSE}
library(knitr)
library(kableExtra)
library(data.table)
```
We will use the `fread` function in the `data.table` R package to load the
WASH Benefits example dataset:
```{r load-data}
library(data.table)
washb_data <- fread(
  paste0(
    "https://raw.githubusercontent.com/tlverse/tlverse-data/master/",
    "wash-benefits/washb_data.csv"
  ),
  stringsAsFactors = TRUE
)
```
Next, we will take a peek at the first few rows of our dataset:
```{r show-data-normal-noeval, eval = FALSE}
head(washb_data)
```
```{r show-data-handbook, echo = FALSE}
if (knitr::is_latex_output()) {
  head(washb_data) %>%
    kable(format = "latex")
} else if (knitr::is_html_output()) {
  head(washb_data) %>%
    kable() %>%
    kableExtra::kable_styling(fixed_thead = TRUE) %>%
    scroll_box(width = "100%", height = "300px")
}
```
##### Install `sl3` software (as needed) {-}
To install any package, we recommend first clearing the R workspace and then
restarting the R session. In RStudio, this can be achieved by clicking the
tab "Session" then "Clear Workspace", and then clicking "Session" again then
"Restart R".
We can install `sl3` using the function `install_github` provided in the
`devtools` R package. We are using the development ("devel") version of `sl3`
in these materials, so we show how to install that version below.
```{r install-sl3, eval = FALSE}
library(devtools)
install_github("tlverse/sl3@devel")
```
Once the R package is installed, we recommend restarting the R session again.
<!--
If you would like to use newer `sl3` functionality that is available in the
devel branch of the `sl3` GitHub repository, you need to install that version
of the package (i.e., `devtools::install_github("tlverse/sl3@devel")`), re-start
your `R` session, and then re-load the `sl3` package.
-->
##### Load `sl3` software {-}
Once `sl3` is installed, we can load it like any other `R` package:
```{r load-sl3}
library(sl3)
```
### 1. Define the prediction task with `make_sl3_Task` {-}
The `sl3_Task` object defines the prediction task of interest. Recall that
our task in this illustrative example is to use the WASH Benefits Bangladesh
example dataset to learn a function of the covariates for predicting
weight-for-height Z-score `whz`.
```{r task}
# create the task (i.e., use washb_data to predict outcome using covariates)
task <- make_sl3_Task(
  data = washb_data,
  outcome = "whz",
  covariates = c("tr", "fracode", "month", "aged", "sex", "momage", "momedu",
                 "momheight", "hfiacat", "Nlt18", "Ncomp", "watmin", "elec",
                 "floor", "walls", "roof", "asset_wardrobe", "asset_table",
                 "asset_chair", "asset_khat", "asset_chouki", "asset_tv",
                 "asset_refrig", "asset_bike", "asset_moto", "asset_sewmach",
                 "asset_mobile")
)
# let's examine the task
task
```
The `sl3_Task` keeps track of the roles the variables play in the prediction
problem. Additional information relevant to the prediction task (such as
observation-level weights, offset, id, CV folds) can also be specified in
`make_sl3_Task`. The default CV fold structure in `sl3` is V-fold CV (VFCV)
with V=10 folds; if `id` is specified in the task then a clustered V=10 VFCV
scheme is considered, and if the outcome type is binary or categorical then
a stratified V=10 VFCV scheme is considered. Different CV schemes can be
specified by inputting an `origami` folds object, as generated by the
`make_folds` function in the `origami` R package. Refer to the documentation
on `origami`'s `make_folds` function for more information (e.g., in RStudio, by
loading the `origami` R package and then inputting "?make_folds" in the
Console). For more details on `sl3_Task`, refer to its documentation (e.g., by
inputting "?sl3_Task" in R).
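
As a minimal sketch of specifying a non-default CV scheme, the chunk below
builds a 5-fold `origami` folds object and passes it to `make_sl3_Task` through
its `folds` argument. The simulated data and the names `sim_data`, `folds_v5`,
and `sim_task` are hypothetical and used only for illustration, and the chunk
is not evaluated in this chapter.

```{r sketch-custom-folds, eval = FALSE}
library(data.table)
library(origami)
library(sl3)

# small simulated dataset (hypothetical; for illustration only)
set.seed(1)
n_sim <- 100
sim_data <- data.table(x1 = rnorm(n_sim), x2 = rbinom(n_sim, 1, 0.5))
sim_data[, y := 1 + 2 * x1 - x2 + rnorm(n_sim)]

# 5-fold CV in place of the default 10-fold scheme
folds_v5 <- make_folds(n = n_sim, fold_fun = folds_vfold, V = 5)

# supply the folds when defining the task
sim_task <- make_sl3_Task(
  data = sim_data, outcome = "y", covariates = c("x1", "x2"),
  folds = folds_v5
)

# number of folds tracked by the task
length(sim_task$folds)
```

Each element of the folds object carries a `training_set` and a
`validation_set` of row indices, which is what the "by hand" calculations later
in this chapter tap into.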
*Tip:* If you type `task$` and then press the tab key (press tab twice if not in
RStudio), you can view all of the active and public fields, and methods that
can be accessed from the `task$` object. This `$` is like the key to access
many internals of an object. In the next section, we will see how we can use `$`
to dig into SL fit objects as well, to obtain predictions from an SL fit or
candidate learners, examine an SL fit or its candidates, and summarize an SL
fit.
### 2. Instantiate the Super Learner with `Lrnr_sl` {-}
In order to create `Lrnr_sl` we need to specify, at the minimum, a set of
learners for the SL to consider as candidates. This set of algorithms is
also commonly referred to as the "library". We might also specify the
meta-learner, which is the algorithm that ensembles the learners, but this is
optional since there are already defaults set up in `sl3`. See "Practical
considerations for specifying a super learner" for step-by-step guidelines for
tailoring the SL specification, including the library and meta-learner(s), to
perform well for the prediction task at hand [@rvp2022super].
Learners have properties that indicate what features they support. We may use
the `sl3_list_properties()` function to get a list of all properties supported
by at least one learner:
```{r list-properties}
sl3_list_properties()
```
Since `whz` is a continuous outcome, we can identify the learners that support
this outcome type with `sl3_list_learners()`:
```{r list-learners}
sl3_list_learners(properties = "continuous")
```
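The same function can be used for other outcome types. As a small sketch,
assuming we had a binary outcome instead, we could list the learners that
support it by requesting the `"binomial"` property, which is one of the
properties returned by `sl3_list_properties()`. The object name
`binomial_learners` is hypothetical, and the chunk is not evaluated here.

```{r sketch-list-binomial, eval = FALSE}
library(sl3)

# learners that support binary outcomes
binomial_learners <- sl3_list_learners(properties = "binomial")
head(binomial_learners)
```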
Now that we have an idea of some learners, let's instantiate a few of them.
Below we instantiate `Lrnr_glm` and `Lrnr_mean`, a main terms generalized
linear model (GLM) and a mean model, respectively.
```{r learners}
lrn_glm <- Lrnr_glm$new()
lrn_mean <- Lrnr_mean$new()
```
For both of the learners created above, we just used the default tuning
parameters. We can also customize a learner's tuning parameters to incorporate
a diversity of different settings, and consider the same learner with different
tuning parameter specifications.
Below, we consider the same base learner, `Lrnr_glmnet` (i.e., GLMs
with elastic net regression), and create two different candidates from it:
an L2-penalized/ridge regression and an L1-penalized/lasso regression.
```{r more-learners}
# penalized regressions:
lrn_ridge <- Lrnr_glmnet$new(alpha = 0)
lrn_lasso <- Lrnr_glmnet$new(alpha = 1)
```
By setting `alpha` in `Lrnr_glmnet` above, we customized this learner's tuning
parameter. When we instantiate `Lrnr_hal9001` below we show how multiple tuning
parameters (specifically, `max_degree` and `num_knots`) can be modified at the
same time.
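
Variations of one base learner can also be generated programmatically. The
sketch below, which is not evaluated here, creates elastic net candidates over
a small grid of `alpha` values between the ridge (`alpha = 0`) and lasso
(`alpha = 1`) extremes. The grid values and the names `alpha_grid` and
`lrnrs_enet` are arbitrary choices made for illustration.

```{r sketch-enet-grid, eval = FALSE}
library(sl3)

# hypothetical grid of elastic net mixing parameters
alpha_grid <- c(0.25, 0.5, 0.75)

# one Lrnr_glmnet candidate per value of alpha
lrnrs_enet <- lapply(alpha_grid, function(a) Lrnr_glmnet$new(alpha = a))
names(lrnrs_enet) <- paste0("enet_alpha_", alpha_grid)
names(lrnrs_enet)
```

Each element of `lrnrs_enet` is a separate learner that could be considered as
its own candidate.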
Let's also instantiate some learners that do not force relationships to be
linear or monotonic. Up to this point, all of the learners we've instantiated
have been parametric, so adding nonparametric learners further diversifies the
set of candidates.
```{r more-learners-np}
# spline regressions:
lrn_polspline <- Lrnr_polspline$new()
lrn_earth <- Lrnr_earth$new()
# fast highly adaptive lasso (HAL) implementation
lrn_hal <- Lrnr_hal9001$new(max_degree = 2, num_knots = c(3,2), nfolds = 5)
# tree-based methods
lrn_ranger <- Lrnr_ranger$new()
lrn_xgb <- Lrnr_xgboost$new()
```
Let's also include a generalized additive model (GAM) and Bayesian GLM to
further diversify the pool that we will consider as candidates in the SL.
```{r more-learners-final}
lrn_gam <- Lrnr_gam$new()
lrn_bayesglm <- Lrnr_bayesglm$new()
```
Now that we've instantiated a set of learners, we need to put them together so
the SL can consider them as candidates. In `sl3`, we do this by creating a
so-called `Stack` of learners. A `Stack` is created in the same way we
created the learners. This is because `Stack` is a learner itself; it has the
same interface as all of the other learners. What makes a stack special is that
it considers multiple learners at once: it can train them simultaneously, so
that their predictions can be combined and/or compared.
```{r stack}
stack <- Stack$new(
  lrn_glm, lrn_mean, lrn_ridge, lrn_lasso, lrn_polspline, lrn_earth, lrn_hal,
  lrn_ranger, lrn_xgb, lrn_gam, lrn_bayesglm
)
stack
```
We can see that the names of the learners in the stack are long. This is
because the default naming of a learner in `sl3` is clunky: for each learner,
every tuning parameter in `sl3` is contained in the name. In the next section,
["Naming
Learners"](https://tlverse.org/tlverse-handbook/sl3.html#naming-learners),
we show a few different ways for the user to name learners as they wish.
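
Because a `Stack` has the same interface as any other learner, it can be
trained on a task directly, without a meta-learner. The following sketch, which
is not evaluated here, trains a two-learner stack on a small simulated task and
returns one column of predictions per learner. The simulated data and all
object names (`sim_data`, `sim_task`, `small_stack`, and so on) are
hypothetical.

```{r sketch-stack-train, eval = FALSE}
library(data.table)
library(sl3)

# small simulated dataset (hypothetical; for illustration only)
set.seed(2)
n_sim <- 200
sim_data <- data.table(x1 = rnorm(n_sim), x2 = rnorm(n_sim))
sim_data[, y := x1 + 0.5 * x2 + rnorm(n_sim)]
sim_task <- make_sl3_Task(
  data = sim_data, outcome = "y", covariates = c("x1", "x2")
)

# a stack is trained like any single learner
small_stack <- Stack$new(Lrnr_glm$new(), Lrnr_mean$new())
small_stack_fit <- small_stack$train(sim_task)

# predictions: one column per learner in the stack
small_stack_preds <- small_stack_fit$predict(sim_task)
head(small_stack_preds)
```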
Now that we have instantiated a set of learners and stacked them together, we
are ready to instantiate the SL. We will use non-negative least squares (NNLS)
regression (`Lrnr_nnls`) as the meta-learner. This is the default for continuous
outcomes, but we specify it explicitly for illustrative purposes.
```{r make-sl}
sl <- Lrnr_sl$new(learners = stack, metalearner = Lrnr_nnls$new())
```
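To build intuition for what an NNLS meta-learner does, the sketch below finds
non-negative weights that minimize the squared error of a weighted combination
of candidate predictions. It is a conceptual illustration only: the data are
simulated, all names are hypothetical, and base R's `optim` with box
constraints is swapped in for a dedicated NNLS solver, so it should not be read
as how `Lrnr_nnls` is implemented. The chunk is not evaluated here.

```{r sketch-nnls-weights, eval = FALSE}
set.seed(3)
n_sim <- 500
truth <- rnorm(n_sim, mean = 2)
y_sim <- truth + rnorm(n_sim, sd = 0.5)

# hypothetical matrix of candidate predictions (one column per candidate)
Z <- cbind(
  cand_a = truth + rnorm(n_sim, sd = 0.7),
  cand_b = truth + rnorm(n_sim, sd = 0.7),
  cand_mean = rep(mean(y_sim), n_sim)
)

# squared error risk of a weighted combination, and its gradient
risk <- function(w) mean((y_sim - Z %*% w)^2)
risk_grad <- function(w) {
  as.numeric(-2 * crossprod(Z, y_sim - Z %*% w) / n_sim)
}

# minimize the risk subject to non-negative weights
nnls_fit <- optim(
  par = rep(1 / ncol(Z), ncol(Z)), fn = risk, gr = risk_grad,
  method = "L-BFGS-B", lower = 0
)
weights_nn <- setNames(nnls_fit$par, colnames(Z))
round(weights_nn, 3)

# risk of the weighted combination versus each candidate on its own
candidate_risks <- apply(Z, 2, function(z) mean((y_sim - z)^2))
c(ensemble = risk(weights_nn), candidate_risks)
```

Since putting all of the weight on a single candidate is one of the allowed
solutions, the combination's risk on these data should be no larger than that
of the best individual candidate, up to numerical tolerance.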
### 3. Fit the Super Learner to the prediction task with `train` {-}
The last step for fitting the SL to the prediction task is to call `train` and
supply the task. Before we call `train`, we will set the seed of the random
number generator so the results are reproducible, and we will also time the fit.
```{r train-sl}
start_time <- proc.time() # start time
set.seed(4197)
sl_fit <- sl$train(task = task)
runtime_sl_fit <- proc.time() - start_time # end time - start time = run time
runtime_sl_fit
```
It took `r round(as.numeric(runtime_sl_fit["elapsed"]),1)` seconds
(`r round(as.numeric(runtime_sl_fit["elapsed"])/60,1)` minutes) to fit the SL.
#### Summary {-}
In this section, the core functionality for fitting any SL with `sl3` was
illustrated. This consists of the following three steps:
1. Define the prediction task with `make_sl3_Task`.
2. Instantiate the SL with `Lrnr_sl`.
3. Fit the SL to the task with `train`.
This example was for demonstrative purposes only. See @rvp2022super for
step-by-step guidelines for constructing an SL that is well-specified for the
prediction task at hand.
## Additional `sl3` Topics: {-}
## Obtaining Predictions
### Super learner and candidate learner predictions
We will draw on the fitted SL object from above, `sl_fit`, to obtain the
SL's predicted `whz` value for each subject.
```{r sl-predictions}
sl_preds <- sl_fit$predict(task = task)
head(sl_preds)
```
We can also obtain predicted values from a candidate learner in the SL. Below
we obtain predictions for the GLM learner.
```{r glm-predictions}
glm_preds <- sl_fit$learner_fits$Lrnr_glm_TRUE$predict(task = task)
head(glm_preds)
```
Note that the predicted values for the SL correspond to so-called "full fits"
of the candidate learners, in which the candidates are fit to the entire
analytic dataset, i.e., all of the data supplied as `data` to `make_sl3_Task`.
Figure 2 in @rvp2022super provides a visual overview of the SL fitting
procedure.
```{r glm-predictions-fullfit}
# we can also access the candidate learner full fits directly and obtain
# the same "full fit" candidate predictions from there
# (we split this into two lines to avoid overflow)
stack_full_fits <- sl_fit$fit_object$full_fit$learner_fits$Stack$learner_fits
glm_preds_full_fit <- stack_full_fits$Lrnr_glm_TRUE$predict(task)
# check that they are identical
identical(glm_preds, glm_preds_full_fit)
```
Below we visualize the observed values for `whz` and predicted `whz` values for
SL, GLM and the mean.
```{r predvobs}
# table of observed and predicted outcome values and arrange by observed values
df_plot <- data.table(
  Obs = washb_data[["whz"]], SL_Pred = sl_preds, GLM_Pred = glm_preds,
  Mean_Pred = sl_fit$learner_fits$Lrnr_mean$predict(task)
)
df_plot <- df_plot[order(df_plot$Obs), ]
```
```{r predvobs-head, eval = FALSE}
head(df_plot)
```
```{r predvobs-head-handbook, echo = FALSE}
if (knitr::is_latex_output()) {
  head(df_plot) %>%
    kable(format = "latex")
} else if (knitr::is_html_output()) {
  head(df_plot) %>%
    kable() %>%
    kableExtra::kable_styling(fixed_thead = TRUE) %>%
    scroll_box(width = "100%", height = "300px")
}
```
```{r predobs-plot, fig.asp = .55, fig.cap = "Observed and predicted values for weight-for-height z-score (whz)"}
# melt the table so we can plot observed and predicted values
df_plot$id <- seq_len(nrow(df_plot))
df_plot_melted <- melt(
  df_plot, id.vars = "id",
  measure.vars = c("Obs", "SL_Pred", "GLM_Pred", "Mean_Pred")
)
library(ggplot2)
ggplot(df_plot_melted, aes(id, value, color = variable)) +
  geom_point(size = 0.1) +
  labs(x = "Subjects (ordered by increasing whz)",
       y = "whz") +
  theme(legend.position = "bottom", legend.title = element_blank(),
        axis.text.x = element_blank(), axis.ticks.x = element_blank()) +
  guides(color = guide_legend(override.aes = list(size = 1)))
```
### Cross-validated predictions
We can also obtain the cross-validated (CV) predictions for the candidate
learners. We can do this in a few different ways.
```{r cv-predictions}
# one way to obtain the CV predictions for the candidate learners
cv_preds_option1 <- sl_fit$fit_object$cv_fit$predict_fold(
  task = task, fold_number = "validation"
)
# another way to obtain the CV predictions for the candidate learners
cv_preds_option2 <- sl_fit$fit_object$cv_fit$predict(task = task)
# we can check that they are identical
identical(cv_preds_option1, cv_preds_option2)
```
```{r cv-predictions-head, eval = FALSE}
head(cv_preds_option1)
```
```{r cv-predictions-head-handbook, echo = FALSE}
if (knitr::is_latex_output()) {
  head(cv_preds_option1) %>%
    kable(format = "latex")
} else if (knitr::is_html_output()) {
  head(cv_preds_option1) %>%
    kable() %>%
    kableExtra::kable_styling(fixed_thead = TRUE) %>%
    scroll_box(width = "100%", height = "300px")
}
```
##### `predict_fold` {-}
Our first option to get CV predictions, `cv_preds_option1`, used the
`predict_fold` function to obtain validation set predictions across all folds.
This function only exists for learner fits that are cross-validated in `sl3`,
like those in `Lrnr_sl`. In addition to supplying `fold_number = "validation"`
in `predict_fold`, we can set `fold_number = "full"` to obtain predictions from
learners fit to the entire analytic dataset (i.e., all of the data supplied to
`make_sl3_Task`). For instance, below we show that `glm_preds` we calculated
above can also be obtained by setting `fold_number = "full"`.
```{r glm-predict-fold}
full_fit_preds <- sl_fit$fit_object$cv_fit$predict_fold(
  task = task, fold_number = "full"
)
glm_full_fit_preds <- full_fit_preds$Lrnr_glm_TRUE
# check that they are identical
identical(glm_preds, glm_full_fit_preds)
```
We can also supply a specific integer between 1 and the number of CV folds
to the `fold_number` argument in `predict_fold`, and an example of this
functionality is shown in the next part.
##### Cross-validated predictions by hand {-}
We can get the CV predictions "by hand", by tapping into each of the folds, and
then using the fitted candidate learners (which were trained to the training
set for each fold) to predict validation set outcomes (which were not seen in
training).
```{r cv-predictions-long}
##### CV predictions "by hand" #####
# for each fold, i, we obtain validation set predictions:
cv_preds_list <- lapply(seq_along(task$folds), function(i){
  # get validation dataset for fold i:
  v_data <- task$data[task$folds[[i]]$validation_set, ]
  # get observed outcomes in fold i's validation dataset:
  v_outcomes <- v_data[["whz"]]
  # make task (for prediction) using fold i's validation dataset as data,
  # and keeping all else the same:
  v_task <- make_sl3_Task(covariates = task$nodes$covariates, data = v_data)
  # get predicted outcomes for fold i's validation dataset, using candidates
  # trained to fold i's training dataset
  v_preds <- sl_fit$fit_object$cv_fit$predict_fold(
    task = v_task, fold_number = i
  )
  # note: v_preds is a matrix of candidate learner predictions, where the
  # number of rows is the number of observations in fold i's validation dataset
  # and the number of columns is the number of candidate learners (excluding
  # any that might have failed)
  # an identical way to get v_preds, which is used when we calculate the
  # cv risk by hand in a later part of this chapter:
  # v_preds <- sl_fit$fit_object$cv_fit$fit_object$fold_fits[[i]]$predict(
  #   task = v_task
  # )
  # we will also return the row indices for fold i's validation set, so we
  # can later reorder the CV predictions and make sure they are equal to what
  # we obtained above
  return(list(
    "v_preds" = v_preds,
    "v_index" = task$folds[[i]]$validation_set
  ))
})
# extract the validation set predictions across all folds
cv_preds_byhand <- do.call(rbind, lapply(cv_preds_list, "[[", "v_preds"))
# extract the indices of validation set observations across all folds
# then reorder cv_preds_byhand to correspond to the ordering in the data
row_index_in_data <- unlist(lapply(cv_preds_list, "[[", "v_index"))
cv_preds_byhand_ordered <- cv_preds_byhand[order(row_index_in_data), ]
# now we can check that they are identical
identical(cv_preds_option1, cv_preds_byhand_ordered)
```
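The bookkeeping above (predict each validation set from a fit to the
corresponding training set, then restore the original row order) does not
depend on `sl3`. A minimal base R sketch with a single linear model is given
below. The simulated data, the random fold assignment, and all object names are
hypothetical, and the chunk is not evaluated here.

```{r sketch-cv-preds-base, eval = FALSE}
set.seed(4)
n_sim <- 100
sim_df <- data.frame(x = rnorm(n_sim))
sim_df$y <- 1 + 2 * sim_df$x + rnorm(n_sim)

# randomly assign each row to one of V validation sets
V <- 5
fold_id <- sample(rep(seq_len(V), length.out = n_sim))

cv_list <- lapply(seq_len(V), function(v) {
  valid_rows <- which(fold_id == v)
  # fit on the training set (all rows outside of fold v)
  fit_v <- lm(y ~ x, data = sim_df[-valid_rows, ])
  # predict the held-out validation rows
  list(
    v_preds = unname(predict(fit_v, newdata = sim_df[valid_rows, ])),
    v_index = valid_rows
  )
})

# stack predictions across folds, then reorder to match the rows of sim_df
cv_preds_stacked <- unlist(lapply(cv_list, "[[", "v_preds"))
row_index <- unlist(lapply(cv_list, "[[", "v_index"))
cv_preds_sim <- cv_preds_stacked[order(row_index)]
head(cv_preds_sim)
```

Every row appears in exactly one validation set, so `cv_preds_sim` has one
prediction per observation, each made by a model that did not see that
observation during training.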
### Predictions with new data
If we wanted to obtain predicted values for new data then we would need to
create a new `sl3_Task` from the new data. Also, the covariates in this new
`sl3_Task` must be identical to the covariates in the `sl3_Task` for training.
As an example, let's assume we have new covariate data `washb_data_new` for
which we want to use the fitted SL to obtain predicted weight-for-height
z-score values.
```{r predictions-new-task, eval = FALSE}
# we do not evaluate this code chunk, as `washb_data_new` does not exist
prediction_task <- make_sl3_Task(
  data = washb_data_new, # assuming we have some new data for predictions
  covariates = c("tr", "fracode", "month", "aged", "sex", "momage", "momedu",
                 "momheight", "hfiacat", "Nlt18", "Ncomp", "watmin", "elec",
                 "floor", "walls", "roof", "asset_wardrobe", "asset_table",
                 "asset_chair", "asset_khat", "asset_chouki", "asset_tv",
                 "asset_refrig", "asset_bike", "asset_moto", "asset_sewmach",
                 "asset_mobile")
)
sl_preds_new_task <- sl_fit$predict(task = prediction_task)
```
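Before creating a prediction task, it can help to confirm that the new data
contain every covariate used in training. The helper below is a hypothetical
base R convenience, not part of `sl3`, and the toy column names and values are
made up for illustration. The chunk is not evaluated here.

```{r sketch-check-covariates, eval = FALSE}
# hypothetical helper: training covariates that are absent from the new data
missing_covariates <- function(new_data, covariates) {
  setdiff(covariates, names(new_data))
}

# toy example with made-up values
train_covariates <- c("aged", "sex", "momage")
new_data_ok <- data.frame(aged = 250, sex = "female", momage = 24)
new_data_bad <- data.frame(aged = 250, sex = "female")

# nothing missing: returns an empty character vector
missing_covariates(new_data_ok, train_covariates)

# "momage" is missing from the new data
missing_covariates(new_data_bad, train_covariates)
```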
### Counterfactual predictions
Counterfactual predictions are predicted values under an intervention of
interest. Recall from above that we can obtain predicted values for new data by
creating a `sl3_Task` with the new data whose covariates match the set
considered for training. As an example that draws on the WASH Benefits
Bangladesh study, suppose we would like to obtain predictions for every
subject's weight-for-height z-score (`whz`) outcome under an intervention on
treatment (`tr`) that sets it to the nutrition, water, sanitation, and
handwashing regime.
First we need to create a copy of the dataset, and then we can intervene on
`tr` in the copied dataset, create a new `sl3_Task` using the copied data and
the same covariates as the training task, and finally obtain predictions
from the fitted SL (which we named `sl_fit` in the previous section).
```{r cf-predictions-static}
### 1. Copy data
tr_intervention_data <- data.table::copy(washb_data)
### 2. Define intervention in copied dataset
tr_intervention <- rep("Nutrition + WSH", nrow(washb_data))
# NOTE: When we intervene on a categorical variable (such as "tr"), we need to
# define the intervention as a categorical variable (i.e., a factor).
# Also, even though not all levels of the factor will be represented in
# the intervention, we still need this factor to reflect all of the
# levels that are present in the observed data
tr_levels <- levels(washb_data[["tr"]])
tr_levels
tr_intervention <- factor(tr_intervention, levels = tr_levels)
tr_intervention_data[,"tr" := tr_intervention, ]
### 3. Create a new sl3_Task
# note that we do not need to specify the outcome in this new task since we are
# only using it to obtain predictions
tr_intervention_task <- make_sl3_Task(
  data = tr_intervention_data,
  covariates = c("tr", "fracode", "month", "aged", "sex", "momage", "momedu",
                 "momheight", "hfiacat", "Nlt18", "Ncomp", "watmin", "elec",
                 "floor", "walls", "roof", "asset_wardrobe", "asset_table",
                 "asset_chair", "asset_khat", "asset_chouki", "asset_tv",
                 "asset_refrig", "asset_bike", "asset_moto", "asset_sewmach",
                 "asset_mobile")
)
### 4. Get predicted values under intervention of interest
# SL predictions of what "whz" would have been had everyone received "tr"
# equal to "Nutrition + WSH"
counterfactual_pred <- sl_fit$predict(tr_intervention_task)
```
Note that this type of intervention, where every subject receives the same
intervention, is referred to as "static". Interventions that vary depending on
the characteristics of the subject are referred to as "dynamic". For instance,
we might consider an intervention that sets the treatment to the desired
(nutrition, water, sanitation, and handwashing) regime if the subject has
a refrigerator, and a nutrition-omitted (water, sanitation, and handwashing)
regime otherwise.
```{r cf-predictions-dynamic}
dynamic_tr_intervention_data <- data.table::copy(washb_data)
dynamic_tr_intervention <- ifelse(
  washb_data[["asset_refrig"]] == 1, "Nutrition + WSH", "WSH"
)
dynamic_tr_intervention <- factor(dynamic_tr_intervention, levels = tr_levels)
dynamic_tr_intervention_data[,"tr" := dynamic_tr_intervention, ]
dynamic_tr_intervention_task <- make_sl3_Task(
  data = dynamic_tr_intervention_data,
  covariates = c("tr", "fracode", "month", "aged", "sex", "momage", "momedu",
                 "momheight", "hfiacat", "Nlt18", "Ncomp", "watmin", "elec",
                 "floor", "walls", "roof", "asset_wardrobe", "asset_table",
                 "asset_chair", "asset_khat", "asset_chouki", "asset_tv",
                 "asset_refrig", "asset_bike", "asset_moto", "asset_sewmach",
                 "asset_mobile")
)
### 4. Get predicted values under intervention of interest
# SL predictions of what "whz" would have been had every subject received "tr"
# equal to "Nutrition + WSH" if they had a fridge and "WSH" if they didn't have
# a fridge
counterfactual_pred <- sl_fit$predict(dynamic_tr_intervention_task)
```
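Both interventions above define the intervened `tr` as a factor that carries
all of the levels present in the observed data. The base R sketch below, which
uses toy values rather than the WASH Benefits data and hypothetical object
names, shows what happens to the levels when the `levels` argument is omitted.
The chunk is not evaluated here.

```{r sketch-factor-levels, eval = FALSE}
# observed treatment with three levels (toy values)
observed_tr <- factor(c("Control", "WSH", "Nutrition + WSH", "Control"))
levels(observed_tr)

# without `levels`, the intervened variable has only one level
tr_one_level <- factor(rep("Nutrition + WSH", length(observed_tr)))
levels(tr_one_level)

# with `levels`, all of the observed levels are retained
tr_all_levels <- factor(
  rep("Nutrition + WSH", length(observed_tr)),
  levels = levels(observed_tr)
)
levels(tr_all_levels)

# the intervened factor now has the same levels as the observed factor
identical(levels(tr_all_levels), levels(observed_tr))
```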
## Summarizing Super Learner Fits
### Super Learner coefficients / fitted meta-learner summary
We can see how the meta-learner created a function of the learners in a few
ways. In our illustrative example, we considered the default, NNLS meta-learner
for continuous outcomes. For meta-learners that simply learn a weighted
combination, we can examine their coefficients.
```{r sl-coefs-simple}
round(sl_fit$coefficients, 3)
```
We can also examine the coefficients by directly accessing the meta-learner's
fit object.
```{r metalearner-fit}
metalrnr_fit <- sl_fit$fit_object$cv_meta_fit$fit_object
round(metalrnr_fit$coefficients, 3)
```
Direct access to the meta-learner fit object is also handy for more
complex meta-learners (e.g., non-parametric meta-learners) that are not defined
by a simple set of main terms regression coefficients.
### Cross-validated predictive performance
We can obtain a table of the cross-validated (CV) predictive performance, i.e.,
the CV risk, for each learner included in the SL. Below, we use the squared
error loss as the evaluation function, so predictive performance is summarized
by the mean squared error (MSE). We use the MSE because it is a valid metric
for estimating the conditional mean, which is the target of the prediction
function in the WASH Benefits example. For more information on selecting an
appropriate
performance metric, see @rvp2022super.
```{r sl-summary}
cv_risk_table <- sl_fit$cv_risk(eval_fun = loss_squared_error)
```
```{r cv-risk-summary, eval = FALSE}
cv_risk_table[,c(1:3)]
```
```{r cv-risk-summary-handbook, echo = FALSE}
if (knitr::is_latex_output()) {
  cv_risk_table[,c(1:3)] %>%
    kable(format = "latex")
} else if (knitr::is_html_output()) {
  cv_risk_table[,c(1:3)] %>%
    kable() %>%
    kableExtra::kable_styling(fixed_thead = TRUE) %>%
    scroll_box(width = "100%", height = "300px")
}
```
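The link between squared error loss and the mean can be checked numerically. The sketch below uses simulated outcomes and a hypothetical helper, `squared_error`, that mirrors the `(pred, observed)` argument order used with `loss_squared_error` in this chapter. It evaluates the empirical risk of every constant prediction on a grid and finds the minimizer.

```r
set.seed(1)
y <- rnorm(500, mean = 2, sd = 1) # simulated outcomes (illustration only)

# hypothetical helper: squared error loss for each observation
squared_error <- function(pred, observed) (pred - observed)^2

# empirical risk (mean loss) of each constant prediction on a grid
grid <- seq(0, 4, by = 0.01)
risks <- vapply(grid, function(g) mean(squared_error(g, y)), numeric(1))

# the risk-minimizing constant sits at the grid point nearest the sample mean
best_constant <- grid[which.min(risks)]
c(best_constant = best_constant, sample_mean = mean(y))
```

The same reasoning applies conditionally on covariates, which is why minimizing the MSE targets the conditional mean.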
##### Cross-validated predictive performance by hand {-}
Similar to how we got the CV predictions "by hand", we can also calculate the CV
performance/risk in a way that exposes the procedure. Specifically, for each
fold, we use the fitted candidate learners (which were trained on that fold's
training set) to predict the validation set outcomes (which were not seen in
training), and then measure the predictive performance (i.e., risk). Each
candidate learner's fold-specific risk is then averaged across all folds to
obtain the CV risk. The function `cv_risk` does all of this internally; we show
how to do it by hand below, which can be helpful for understanding the CV risk
and how it is calculated.
```{r cv-risk-byhand}
##### CV risk "by hand" #####
# for each fold, i, we obtain predictive performance/risk for each candidate:
cv_risks_list <- lapply(seq_along(task$folds), function(i){
  # get validation dataset for fold i:
  v_data <- task$data[task$folds[[i]]$validation_set, ]
  # get observed outcomes in fold i's validation dataset:
  v_outcomes <- v_data[["whz"]]
  # make task (for prediction) using fold i's validation dataset as data,
  # and keeping all else the same:
  v_task <- make_sl3_Task(covariates = task$nodes$covariates, data = v_data)
  # get predicted outcomes for fold i's validation dataset, using candidates
  # trained to fold i's training dataset
  v_preds <- sl_fit$fit_object$cv_fit$fit_object$fold_fits[[i]]$predict(v_task)
  # note: v_preds is a matrix of candidate learner predictions, where the
  # number of rows is the number of observations in fold i's validation dataset
  # and the number of columns is the number of candidate learners (excluding
  # any that might have failed)
  # calculate predictive performance for fold i for each candidate
  eval_function <- loss_squared_error # valid for estimation of conditional mean
  v_losses <- apply(v_preds, 2, eval_function, v_outcomes)
  cv_risks <- colMeans(v_losses)
  return(cv_risks)
})
# average the predictive performance across all folds for each candidate
cv_risks_byhand <- colMeans(do.call(rbind, cv_risks_list))
cv_risk_table_byhand <- data.table(
learner = names(cv_risks_byhand), MSE = cv_risks_byhand
)
# check that the CV risks are identical when calculated by hand and function
# (ignoring small differences by rounding to the fourth decimal place)
identical(
round(cv_risk_table_byhand$MSE,4), round(as.numeric(cv_risk_table$MSE),4)
)
```
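The same procedure can be sketched without `sl3`, which may help isolate the logic from the package internals. The example below is a minimal illustration on simulated data, with two hypothetical candidates (an intercept-only model and a main-terms linear regression) and 5 folds; none of it comes from the WASH Benefits analysis.

```r
set.seed(2)
n <- 200
x <- rnorm(n)
dat <- data.frame(x = x, y = 1 + 2 * x + rnorm(n))
V <- 5
fold_id <- sample(rep(seq_len(V), length.out = n)) # random fold assignment

fold_risks <- lapply(seq_len(V), function(v) {
  train <- dat[fold_id != v, ] # training set for fold v
  valid <- dat[fold_id == v, ] # validation set for fold v
  # candidate 1: intercept-only model (training set mean)
  pred_mean <- rep(mean(train$y), nrow(valid))
  # candidate 2: main-terms linear regression fit on the training set
  pred_lm <- predict(lm(y ~ x, data = train), newdata = valid)
  # fold-specific risk: mean squared error on the validation set
  c(mean = mean((pred_mean - valid$y)^2), lm = mean((pred_lm - valid$y)^2))
})
# CV risk: average of the fold-specific risks for each candidate
cv_risks <- colMeans(do.call(rbind, fold_risks))
cv_risks
```

Because the folds here are of equal size, averaging the fold-specific risks gives the same value as averaging the losses over all validation observations.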
<!--
We can plot the CV risks as well.
```{r sl-summary-plot, eval = F}
# Column "se" in the CV risk table is the standard error across all losses for
# a learner, i.e., se = sd(loss)/sqrt(n), where loss is an n length vector of
# validation set predictions across all folds, and n is the number of
# validation set observations across all folds. We can use this to
cv_risk_table[, "lower" := MSE - qnorm(.975)*se]
cv_risk_table[, "upper" := MSE + qnorm(.975)*se]
ggplot(cv_risk_table,
aes_string(x = "learner", y = "MSE", ymin = "lower", ymax = "upper")) +
geom_pointrange() +
coord_flip() +
ylab("V-fold CV Risk Estimate") +
xlab("Learner")
```
-->
<!--
Column "se" in the CV risk table is the standard error across all losses for a learner, i.e., se = sd(loss)/sqrt(n), where loss is an n length vector of validation set predictions across all folds, and n is the number of validation set observations across all folds.
Column "fold_sd" in the CV risk table is the standard deviation of the V-fold-specific risks, i.e., fold_sd = sd(risk), where risk is a V length vector of the mean loss across the folds.
-->
### Cross-validated Super Learner
We can see from the CV risk table above that the SL is not listed. This is
because we do not have a CV risk for the SL unless we cross-validate it or
include it as a candidate in another SL; the latter is shown in [the next
subsection](https://tlverse.org/tlverse-handbook/sl3.html#discrete-super-learner).
Below, we show how to obtain a CV risk estimate for the SL using function
`cv_sl`. Like before when we called `sl$train`, we will set a seed for the
random number generator so the results are reproducible, and we will also time
this.
<!--
(Note: we did not evaluate this part of the code, as it was
taking too long).
-->
```{r cvsl, eval = FALSE}
start_time <- proc.time()
set.seed(569)
cv_sl_fit <- cv_sl(lrnr_sl = sl_fit, task = task, eval_fun = loss_squared_error)
runtime_cv_sl_fit <- proc.time() - start_time
runtime_cv_sl_fit
```
```{r cvsl-save, eval = FALSE, echo = FALSE}
library(here)
save(cv_sl_fit, file=here("data", "fit_objects", "cv_sl_fit.Rdata"), compress=T)
save(runtime_cv_sl_fit, file=here("data", "fit_objects", "runtime_cv_sl_fit.Rdata"))
```
```{r cvsl-load, eval = TRUE, echo = FALSE}
library(here)
load(here("data", "fit_objects", "cv_sl_fit.Rdata"))
load(here("data", "fit_objects", "runtime_cv_sl_fit.Rdata"))
runtime_cv_sl_fit
```
It took `r round(as.numeric(runtime_cv_sl_fit["elapsed"]),1)` seconds (`r round(as.numeric(runtime_cv_sl_fit["elapsed"])/60,1)` minutes) to fit the CV SL.
```{r cvsl-risk-summary, eval = FALSE}
cv_sl_fit$cv_risk[,c(1:3)]
```
```{r cvsl-risk-summary-handbook, echo = FALSE}
if (knitr::is_latex_output()) {
  cv_sl_fit$cv_risk[,c(1:3)] %>%
    kable(format = "latex")
} else if (knitr::is_html_output()) {
  cv_sl_fit$cv_risk[,c(1:3)] %>%
    kable() %>%
    kableExtra::kable_styling(fixed_thead = TRUE) %>%
    scroll_box(width = "100%", height = "300px")
}
```
The CV risk of the SL is
`r round(as.numeric(cv_sl_fit$cv_risk[nrow(cv_sl_fit$cv_risk), "MSE"]), 4)`,
which is lower than all of the candidates' CV risks.
We can see how the SL fits varied across the folds by examining the
meta-learner coefficients of the SL that was fit on each fold.
```{r cvsl-risk-summary-coefs, eval = FALSE}
round(cv_sl_fit$coef, 3)
```
```{r cvsl-risk-summary-coefs-handbook, echo = FALSE}
if (knitr::is_latex_output()) {
  round(cv_sl_fit$coef, 3) %>%
    kable(format = "latex")
} else if (knitr::is_html_output()) {
  round(cv_sl_fit$coef, 3) %>%
    kable() %>%
    kableExtra::kable_styling(fixed_thead = TRUE) %>%
    scroll_box(width = "100%", height = "300px")
}
```
### Revere-cross-validated predictive performance of Super Learner
We can also use so-called "revere" to obtain a partial CV risk for the SL,
where the SL candidate learner fits are cross-validated but the meta-learner fit
is not. It takes essentially no extra time to calculate a revere-CV
performance/risk estimate of the SL, since we already have the CV fits of the
candidates. This isn't to say that revere-CV SL performance can replace that
obtained from actual CV SL. Rather, revere can be used to very quickly examine
an approximate lower bound on the SL's CV risk *when the meta-learner is a
simple model*, like NNLS. We can output the revere-based CV risk estimate by
setting `get_sl_revere_risk = TRUE` in `cv_risk`.
```{r sl-revere-risk}
cv_risk_w_sl_revere <- sl_fit$cv_risk(
eval_fun = loss_squared_error, get_sl_revere_risk = TRUE
)
```
```{r sl-revere-risk-summary, eval = FALSE}
cv_risk_w_sl_revere[,c(1:3)]
```
```{r sl-revere-risk-handbook, echo = FALSE}
if (knitr::is_latex_output()) {
  cv_risk_w_sl_revere[,c(1:3)] %>%
    kable(format = "latex")
} else if (knitr::is_html_output()) {
  cv_risk_w_sl_revere[,c(1:3)] %>%
    kable() %>%
    kableExtra::kable_styling(fixed_thead = TRUE) %>%
    scroll_box(width = "100%", height = "300px")
}
```
##### Revere-cross-validated predictive performance of Super Learner by hand {-}
We show how to calculate the revere-CV predictive performance/risk of
the SL by hand below, as this might be helpful for understanding revere and
how it can be used to obtain a partial CV performance/risk estimate for the
SL.
```{r sl-revere-risk-byhand}
##### revere-based risk "by hand" #####
# for each fold, i, we obtain predictive performance/risk for the SL
sl_revere_risk_list <- lapply(seq_along(task$folds), function(i){
  # get validation dataset for fold i:
  v_data <- task$data[task$folds[[i]]$validation_set, ]
  # get observed outcomes in fold i's validation dataset:
  v_outcomes <- v_data[["whz"]]
  # make task (for prediction) using fold i's validation dataset as data,
  # and keeping all else the same:
  v_task <- make_sl3_Task(
    covariates = task$nodes$covariates, data = v_data
  )
  # get predicted outcomes for fold i's validation dataset, using candidates
  # trained to fold i's training dataset
  v_preds <- sl_fit$fit_object$cv_fit$fit_object$fold_fits[[i]]$predict(v_task)
  # make a metalevel task (for prediction with sl):
  v_meta_task <- make_sl3_Task(
    covariates = sl_fit$fit_object$cv_meta_task$nodes$covariates,
    data = v_preds
  )
  # get predicted outcomes for fold i's metalevel dataset, using the fitted
  # metalearner, cv_meta_fit
  sl_revere_v_preds <- sl_fit$fit_object$cv_meta_fit$predict(task=v_meta_task)
  # note: cv_meta_fit was trained on the metalevel dataset, which contains the
  # candidates' cv predictions and validation dataset outcomes across ALL folds,
  # so cv_meta_fit has already seen fold i's validation dataset outcomes.
  # calculate predictive performance for fold i for the SL
  eval_function <- loss_squared_error # valid for estimation of conditional mean
  # note: by evaluating the predictive performance of the SL using outcomes
  # that were already seen by the metalearner, this is not a cross-validated
  # measure of predictive performance for the SL.
  sl_revere_v_loss <- eval_function(
    pred = sl_revere_v_preds, observed = v_outcomes
  )
  sl_revere_v_risk <- mean(sl_revere_v_loss)
  return(sl_revere_v_risk)
})
# average the predictive performance across all folds for the SL
sl_revere_risk_byhand <- mean(unlist(sl_revere_risk_list))
sl_revere_risk_byhand
# check that our calculation by hand equals what is output in
# cv_risk_w_sl_revere
sl_revere_risk <- as.numeric(cv_risk_w_sl_revere[learner=="SuperLearner","MSE"])
sl_revere_risk
```
This is not a fully cross-validated risk estimate because the `cv_meta_fit`
object above (which is the trained meta-learner) was previously fit to the
*entire* matrix of CV predictions from *every* fold (i.e., the meta-level
dataset; see Figure 2 in @rvp2022super for more detail). This is why
revere-based risks are not a true CV risk. If the meta-learner is not a simple
regression function, and a more flexible learner (e.g., random forest) is used
instead, then the revere-CV risk estimate of the resulting SL will be a worse
approximation of the CV risk estimate, because more flexible learners are more
likely to overfit. When simple parametric regressions are used as the
meta-learner, like what we considered in our SL (NNLS with `Lrnr_nnls`), and
like all of the default meta-learners in `sl3`, the revere-CV risk is a quick
way to examine an approximation of the CV risk estimate of the SL, and it can
be thought of as a ballpark lower bound
on it. This idea holds in our example; that is, with the simple NNLS
meta-learner the revere risk estimate of the SL (`r round(sl_revere_risk, 4)`)
is very close to, and slightly lower than, the CV risk estimate for the SL
(`r round(as.numeric(cv_sl_fit$cv_risk[nrow(cv_sl_fit$cv_risk), "MSE"]), 4)`).
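
The distinction between the revere risk and the fully cross-validated risk can be sketched in base R. Everything below is an illustrative assumption rather than `sl3` code: the data are simulated, the two candidates are hypothetical, and ordinary least squares without an intercept stands in for NNLS as the meta-learner. The revere risk evaluates the meta-learner on the same CV predictions it was fit to, whereas the CV SL risk refits the whole ensemble inside each outer training set and evaluates it on held-out data.

```r
set.seed(3)
n <- 300
dat <- data.frame(x = runif(n, -2, 2))
dat$y <- sin(2 * dat$x) + rnorm(n, sd = 0.5)

# two hypothetical candidates; each is trained on `train`, predicts on `new`
candidates <- list(
  linear = function(train, new) predict(lm(y ~ x, data = train), new),
  cubic = function(train, new) predict(lm(y ~ poly(x, 3), data = train), new)
)

# matrix of V-fold cross-validated candidate predictions for the rows of d
cv_preds <- function(d, V = 5) {
  fold_id <- sample(rep(seq_len(V), length.out = nrow(d)))
  z <- matrix(NA_real_, nrow(d), length(candidates),
              dimnames = list(NULL, names(candidates)))
  for (v in seq_len(V)) {
    for (k in names(candidates)) {
      z[fold_id == v, k] <- candidates[[k]](d[fold_id != v, ], d[fold_id == v, ])
    }
  }
  z
}

# meta-learner: least squares without intercept (a stand-in for NNLS)
fit_meta <- function(z, y) lm.fit(x = z, y = y)$coefficients

# revere: the meta-learner is fit to ALL CV predictions, then evaluated on
# those same CV predictions
z_all <- cv_preds(dat)
w_all <- fit_meta(z_all, dat$y)
revere_risk <- mean((z_all %*% w_all - dat$y)^2)

# CV SL: repeat the entire SL procedure within each outer training set and
# evaluate on the outer validation set, which the meta-learner never saw
outer_id <- sample(rep(seq_len(5), length.out = n))
outer_risks <- vapply(seq_len(5), function(v) {
  train <- dat[outer_id != v, ]
  valid <- dat[outer_id == v, ]
  w <- fit_meta(cv_preds(train), train$y)
  z_valid <- sapply(candidates, function(f) f(train, valid))
  mean((z_valid %*% w - valid$y)^2)
}, numeric(1))
cv_sl_risk <- mean(outer_risks)

c(revere = revere_risk, cv_sl = cv_sl_risk)
```

With a simple meta-learner like this one, the two estimates would typically be close, with the revere risk tending to be the smaller of the two, although this ordering is not guaranteed in any single run.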
## Discrete Super Learner
From the glossary (Table 1) entry for discrete SL (dSL) in @rvp2022super,
the dSL is "a SL that uses a winner-take-all meta-learner called
the cross-validated selector. The dSL is therefore identical to the candidate
with the best cross-validated performance; its predictions will be the same as
this candidate’s predictions". The cross-validated selector is
`Lrnr_cv_selector` in `sl3` (see `Lrnr_cv_selector` documentation for more
detail) and a dSL is instantiated in `sl3` by using `Lrnr_cv_selector` as the
meta-learner in `Lrnr_sl`.
```{r make-dSL}
cv_selector <- Lrnr_cv_selector$new(eval_function = loss_squared_error)
dSL <- Lrnr_sl$new(learners = stack, metalearner = cv_selector)
```
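The winner-take-all behavior described above can be sketched in a few lines of base R. The CV risks and candidate predictions below are made-up values for illustration, not output from the WASH Benefits fit, and the code is a conceptual stand-in rather than the internals of `Lrnr_cv_selector`.

```r
# Hypothetical CV risks (MSE) for three candidates
cv_risks <- c(glm = 1.021, ranger = 1.034, mean = 1.065)

# winner-take-all: coefficient 1 for the lowest CV risk, 0 for the rest
selector_coefs <- setNames(numeric(length(cv_risks)), names(cv_risks))
selector_coefs[which.min(cv_risks)] <- 1
selector_coefs

# Hypothetical candidate predictions for two observations
cand_preds <- cbind(
  glm = c(-0.50, 0.10), ranger = c(-0.70, 0.00), mean = c(-0.15, -0.15)
)
# the dSL predictions coincide with the selected candidate's predictions
dsl_preds <- as.numeric(cand_preds %*% selector_coefs)
all.equal(dsl_preds, unname(cand_preds[, "glm"]))
```

Using `which.min` returns the first minimizer, so exactly one candidate is selected even in the event of a tie.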
Just like before, we use the learner's `train` method to fit it to the
prediction task.
```{r fit-dSL}
set.seed(4197)
dSL_fit <- dSL$train(task)
```
Following from subsection ["Summarizing Super Learner
Fits"](https://tlverse.org/tlverse-handbook/sl3.html#summarizing-super-learner-fits)
above, we can see how the `Lrnr_cv_selector` meta-learner created a function of
the candidates.
```{r summarize-dSL-coefs}
round(dSL_fit$coefficients, 3)