Skip to content

Commit

Permalink
Merge pull request #188 from imbs-hl/predict_se
Browse files Browse the repository at this point in the history
Standard errors of predictions using Jackknife
  • Loading branch information
mnwright authored May 15, 2017
2 parents d72de4a + 95a3e29 commit 46e2b28
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 9 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.7.2
Date: 2017-04-11
Date: 2017-05-12
Author: Marvin N. Wright
Maintainer: Marvin N. Wright <[email protected]>
Description: A fast implementation of Random Forests, particularly suited for high
Expand Down
1 change: 1 addition & 0 deletions NEWS
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
##### Version 0.7.2
* Handle sparse data of class Matrix::dgCMatrix
* Add prediction of standard errors to predict()

##### Version 0.7.1
* Allow devtools::install_github() without subdir and on Windows
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
##### Version 0.7.2
* Handle sparse data of class Matrix::dgCMatrix
* Add prediction of standard errors to predict()

##### Version 0.7.1
* Allow devtools::install_github() without subdir and on Windows
Expand Down
70 changes: 65 additions & 5 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
##' Prediction with new data and a saved forest from Ranger.
##'
##' For \code{type = 'response'} (the default), the predicted classes (classification), predicted numeric values (regression), predicted probabilities (probability estimation) or survival probabilities (survival) are returned.
##' For \code{type = 'se'}, the standard errors of the predictions are returned (regression only). The jackknife-after-bootstrap is used to estimate the standard errors based on out-of-bag predictions. See Wager et al. (2014) for details.
##' For \code{type = 'terminalNodes'}, the IDs of the terminal node in each tree for each observation in the given dataset are returned.
##'
##' For classification and \code{predict.all = TRUE}, the factor levels are returned as numerics.
Expand All @@ -39,10 +40,11 @@
##' @param data New test data of class \code{data.frame} or \code{gwaa.data} (GenABEL).
##' @param predict.all Return individual predictions for each tree instead of aggregated predictions for all trees. Return a matrix (sample x tree) for classification and regression, a 3d array for probability estimation (sample x class x tree) and survival (sample x time x tree).
##' @param num.trees Number of trees used for prediction. The first \code{num.trees} in the forest are used.
##' @param type Type of prediction. One of 'response' or 'terminalNodes' with default 'response'. See below for details.
##' @param type Type of prediction. One of 'response', 'se', 'terminalNodes' with default 'response'. See below for details.
##' @param seed Random seed used in Ranger.
##' @param num.threads Number of threads. Default is number of CPUs available.
##' @param verbose Verbose output on or off.
##' @param inbag.counts Number of times the observations are in-bag in the trees.
##' @param ... further arguments passed to or from other methods.
##' @return Object of class \code{ranger.prediction} with elements
##' \tabular{ll}{
Expand All @@ -55,6 +57,11 @@
##' \code{treetype} \tab Type of forest/tree. Classification, regression or survival. \cr
##' \code{num.samples} \tab Number of samples.
##' }
##' @references
##' \itemize{
##' \item Wright, M. N. & Ziegler, A. (2017). ranger: A Fast Implementation of Random Forests for High Dimensional Data in C++ and R. J Stat Softw 77:1-17. \url{http://dx.doi.org/10.18637/jss.v077.i01}.
##' \item Wager, S., Hastie T., & Efron, B. (2014). Confidence Intervals for Random Forests: The Jackknife and the Infinitesimal Jackknife. J Mach Learn Res 15:1625-1651. \url{http://jmlr.org/papers/v15/wager14a.html}.
##' }
##' @seealso \code{\link{ranger}}
##' @author Marvin N. Wright
##' @importFrom Matrix Matrix
Expand All @@ -63,7 +70,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
num.trees = object$num.trees,
type = "response",
seed = NULL, num.threads = NULL,
verbose = TRUE, ...) {
verbose = TRUE, inbag.counts = NULL,...) {

## GenABEL GWA data
if ("gwaa.data" %in% class(data)) {
Expand Down Expand Up @@ -101,13 +108,28 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
}

## Prediction type
## Translate the requested 'type' into the numeric prediction.type code.
## 'se' shares the code of 'response': the standard errors are computed
## later in this function from the per-tree predictions.
if (type == "response" || type == "se") {
  prediction.type <- 1
} else if (type == "terminalNodes") {
  prediction.type <- 2
} else {
  ## List every accepted value, including 'se'
  stop("Error: Invalid value for 'type'. Use 'response', 'se' or 'terminalNodes'.")
}

## Type "se" only for regression
if (type == "se" && forest$treetype != "Regression") {
  stop("Error: Standard error prediction currently only available for regression.")
}

## Type "se" requires keep.inbag=TRUE
## (inbag.counts is NULL unless the caller supplied the saved in-bag counts)
if (type == "se" && is.null(inbag.counts)) {
  stop("Error: No saved inbag counts in ranger object. Please set keep.inbag=TRUE when calling ranger.")
}

## Set predict.all if type is "se"
## The jackknife needs the individual tree predictions, so any user-supplied
## value of predict.all is overridden here.
if (type == "se") {
  predict.all <- TRUE
}

## Create final data
if (forest$treetype == "Survival") {
Expand Down Expand Up @@ -333,13 +355,46 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
}
}

## Compute Jackknife
## For type = "se": estimate the standard errors of the predictions with the
## jackknife-after-bootstrap, store them in result$se and replace the per-tree
## predictions in result$predictions by their row means.
if (type == "se") {
## Aggregated predictions
## Row means of the per-tree prediction matrix (documented as sample x tree
## when predict.all is set).
yhat <- rowMeans(result$predictions)

## Get inbag counts, keep only observations that are OOB at least once
## An in-bag count of 0 is treated as "out-of-bag" for that tree.
inbag.counts <- simplify2array(inbag.counts)
## simplify2array() can return a plain vector instead of a matrix
## (presumably when there is a single training observation - TODO confirm);
## turn it into a one-row matrix so the matrix indexing below works.
if (is.vector(inbag.counts)) {
inbag.counts <- t(as.matrix(inbag.counts))
}
inbag.counts <- inbag.counts[rowSums(inbag.counts == 0) > 0, , drop = FALSE]
## n: number of retained observations; oob: logical matrix of zero counts
n <- nrow(inbag.counts)
oob <- inbag.counts == 0

## Every retained row has at least one zero count, so this is only TRUE when
## no row was retained at all (all() of an empty logical is TRUE).
if (all(!oob)) {
stop("Error: No OOB observations found, consider increasing num.trees or reducing sample.fraction.")
}

## Compute Jackknife
## For each retained observation, average the predictions over the columns
## in which that observation is OOB; the result has one column per retained
## observation and one row per row of result$predictions.
## NOTE(review): oob has one column per element of inbag.counts, whereas
## result$predictions has one column per tree used for prediction. If
## num.trees is smaller than the forest size these may differ - confirm.
jack.n <- apply(oob, 1, function(x) rowMeans(result$predictions[, x, drop = FALSE]))
## apply() returns a plain vector when every call yields a single value
## (result$predictions has one row); restore the one-row matrix layout.
if (is.vector(jack.n)) {
jack.n <- t(as.matrix(jack.n))
}
## yhat is recycled down the columns, i.e. subtracted per prediction row
jack <- (n - 1) / n * rowSums((jack.n - yhat)^2)
## Correction term subtracted from the jackknife estimate below, built from
## the spread of the per-tree predictions around yhat
## (presumably the Monte Carlo bias correction of Wager et al. 2014 - confirm)
bias <- (exp(1) - 1) * n / result$num.trees^2 * rowSums((result$predictions - yhat)^2)
## Clamp at zero so the square root below is always defined
jab <- pmax(jack - bias, 0)
result$se <- sqrt(jab)

## Response as predictions
result$predictions <- yhat
}

class(result) <- "ranger.prediction"
return(result)
}

##' Prediction with new data and a saved forest from Ranger.
##'
##' For \code{type = 'response'} (the default), the predicted classes (classification), predicted numeric values (regression), predicted probabilities (probability estimation) or survival probabilities (survival) are returned.
##' For \code{type = 'se'}, the standard errors of the predictions are returned (regression only). The jackknife-after-bootstrap is used to estimate the standard errors based on out-of-bag predictions. See Wager et al. (2014) for details.
##' For \code{type = 'terminalNodes'}, the IDs of the terminal node in each tree for each observation in the given dataset are returned.
##'
##' For classification and \code{predict.all = TRUE}, the factor levels are returned as numerics.
Expand All @@ -350,7 +405,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
##' @param data New test data of class \code{data.frame} or \code{gwaa.data} (GenABEL).
##' @param predict.all Return individual predictions for each tree instead of aggregated predictions for all trees. Return a matrix (sample x tree) for classification and regression, a 3d array for probability estimation (sample x class x tree) and survival (sample x time x tree).
##' @param num.trees Number of trees used for prediction. The first \code{num.trees} in the forest are used.
##' @param type Type of prediction. One of 'response' or 'terminalNodes' with default 'response'. See below for details.
##' @param type Type of prediction. One of 'response', 'se', 'terminalNodes' with default 'response'. See below for details.
##' @param seed Random seed used in Ranger.
##' @param num.threads Number of threads. Default is number of CPUs available.
##' @param verbose Verbose output on or off.
Expand All @@ -366,6 +421,11 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
##' \code{treetype} \tab Type of forest/tree. Classification, regression or survival. \cr
##' \code{num.samples} \tab Number of samples.
##' }
##' @references
##' \itemize{
##' \item Wright, M. N. & Ziegler, A. (2017). ranger: A Fast Implementation of Random Forests for High Dimensional Data in C++ and R. J Stat Softw 77:1-17. \url{http://dx.doi.org/10.18637/jss.v077.i01}.
##' \item Wager, S., Hastie T., & Efron, B. (2014). Confidence Intervals for Random Forests: The Jackknife and the Infinitesimal Jackknife. J Mach Learn Res 15:1625-1651. \url{http://jmlr.org/papers/v15/wager14a.html}.
##' }
##' @seealso \code{\link{ranger}}
##' @author Marvin N. Wright
##' @export
Expand All @@ -378,5 +438,5 @@ predict.ranger <- function(object, data, predict.all = FALSE,
if (is.null(forest)) {
stop("Error: No saved forest in ranger object. Please set write.forest to TRUE when calling ranger.")
}
predict(forest, data, predict.all, num.trees, type, seed, num.threads, verbose)
predict(forest, data, predict.all, num.trees, type, seed, num.threads, verbose, object$inbag.counts, ...)
}
9 changes: 8 additions & 1 deletion man/predict.ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions man/predict.ranger.forest.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

80 changes: 80 additions & 0 deletions tests/testthat/test_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,83 @@ test_that("predict.all works for single observation", {

expect_equal(dim(pred$predictions), c(1, rf$num.trees))
})

test_that("standard error prediction working for regression", {
  ## Hold out 10 random rows of iris as test data
  idx <- sample(nrow(iris), 10)
  test <- iris[idx, ]
  train <- iris[-idx, ]

  rf <- ranger(Petal.Length ~ ., train, num.trees = 5, keep.inbag = TRUE)
  pred <- predict(rf, test, type = "se")

  ## One aggregated prediction and one standard error per test observation
  expect_equal(length(pred$predictions), nrow(test))
  expect_equal(length(pred$se), nrow(test))
})

test_that("standard error prediction not working for other tree types", {
  ## Message expected whenever type = "se" is requested from a forest that is
  ## not a regression forest
  only.regression <- "Error: Standard error prediction currently only available for regression."

  ## Classification forest
  rf.class <- ranger(Species ~ ., iris, num.trees = 5, keep.inbag = TRUE)
  expect_error(predict(rf.class, iris, type = "se"), only.regression)

  ## Probability forest
  rf.prob <- ranger(Species ~ ., iris, num.trees = 5, keep.inbag = TRUE, probability = TRUE)
  expect_error(predict(rf.prob, iris, type = "se"), only.regression)

  ## Survival forest
  rf.surv <- ranger(Surv(time, status) ~ ., veteran, num.trees = 5, keep.inbag = TRUE)
  expect_error(predict(rf.surv, veteran, type = "se"), only.regression)
})

test_that("standard error prediction not working if keep.inbag = FALSE", {
  ## Forest grown without keep.inbag, so no in-bag counts are stored
  forest.no.inbag <- ranger(Petal.Length ~ ., iris, num.trees = 5)
  missing.inbag <- "Error: No saved inbag counts in ranger object. Please set keep.inbag=TRUE when calling ranger."

  expect_error(predict(forest.no.inbag, iris, type = "se"), missing.inbag)
})

test_that("standard error prediction not working if no OOB observations", {
  ## Train on a single observation and predict on the remaining ones
  test <- iris[-1, ]
  train <- iris[1, ]
  rf <- ranger(Petal.Length ~ ., train, num.trees = 5, keep.inbag = TRUE)

  ## Predict on the held-out data ('test' was previously assigned but unused)
  expect_error(predict(rf, test, type = "se"),
               "Error: No OOB observations found, consider increasing num.trees or reducing sample.fraction.")
})

test_that("standard error prediction working for single testing observation", {
  ## A single test row exercises the one-row handling of the jackknife code
  test <- iris[1, ]
  train <- iris[-1, ]

  rf <- ranger(Petal.Length ~ ., train, num.trees = 5, keep.inbag = TRUE)
  pred <- predict(rf, test, type = "se")

  ## Exactly one prediction and one standard error are expected
  expect_equal(length(pred$predictions), nrow(test))
  expect_equal(length(pred$se), nrow(test))
})

test_that("standard error response prediction is the same as response prediction", {
  ## Random hold-out split shared by both forests
  holdout <- sample(nrow(iris), 10)
  dat.train <- iris[-holdout, ]
  dat.test <- iris[holdout, ]

  ## Forest with in-bag counts, predicted with type = "se"
  set.seed(100)
  forest.se <- ranger(Petal.Length ~ ., dat.train, num.trees = 5, keep.inbag = TRUE)
  from.se <- predict(forest.se, dat.test, type = "se")

  ## Same seed, plain response prediction
  set.seed(100)
  forest.resp <- ranger(Petal.Length ~ ., dat.train, num.trees = 5)
  from.resp <- predict(forest.resp, dat.test, type = "response")

  ## The aggregated predictions must agree between both prediction types
  expect_equal(from.se$predictions, from.resp$predictions)
})

test_that("standard error is larger for fewer trees", {
  ## Random hold-out split used for both forest sizes
  holdout <- sample(nrow(iris), 10)
  dat.train <- iris[-holdout, ]
  dat.test <- iris[holdout, ]

  ## Small forest
  forest.small <- ranger(Petal.Length ~ ., dat.train, num.trees = 5, keep.inbag = TRUE)
  se.small <- predict(forest.small, dat.test, type = "se")$se

  ## Larger forest
  forest.large <- ranger(Petal.Length ~ ., dat.train, num.trees = 50, keep.inbag = TRUE)
  se.large <- predict(forest.large, dat.test, type = "se")$se

  ## On average the larger forest should give the smaller standard error
  expect_lt(mean(se.large), mean(se.small))
})

0 comments on commit 46e2b28

Please sign in to comment.