Skip to content

Commit

Permalink
fix: early stopping test and average precision metric (#1034)
Browse files Browse the repository at this point in the history
Co-authored-by: Mark Hamilton <[email protected]>
  • Loading branch information
imatiach-msft and mhamilton723 authored May 3, 2021
1 parent 04a9876 commit aad223e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,8 @@ private object TrainUtils extends Serializable {
val score = lightgbmlib.doubleArray_getitem(evalResults, index.toLong)
log.info(s"Valid $evalName=$score")
val cmp =
if (evalName.startsWith("auc") || evalName.startsWith("ndcg@") || evalName.startsWith("map@"))
if (evalName.startsWith("auc") || evalName.startsWith("ndcg@") || evalName.startsWith("map@") ||
evalName.startsWith("average_precision"))
(x: Double, y: Double, tol: Double) => x - y > tol
else
(x: Double, y: Double, tol: Double) => x - y < tol
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
lazy val pimaDF: DataFrame = loadBinary("PimaIndian.csv", "Diabetes mellitus").cache()
lazy val taskDF: DataFrame = loadBinary("task.train.csv", "TaskFailed10").cache()
lazy val breastTissueDF: DataFrame = loadMulticlass("BreastTissue.csv", "Class").cache()
lazy val au3DF: DataFrame = loadMulticlass("au3_25000.csv", "class").cache()
lazy val unfeaturizedBankTrainDF: DataFrame = {
val categoricalColumns = Array(
"job", "marital", "education", "default", "housing", "loan", "contact", "y")
Expand Down Expand Up @@ -375,35 +376,37 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
)
}

ignore("Verify LightGBM Classifier with validation dataset") {
tryWithRetries(Array(0, 0, 0, 0)) { () => // TODO fix flakiness
val df = taskDF.orderBy(rand()).withColumn(validationCol, lit(false))

val Array(train, validIntermediate, test) = df.randomSplit(Array(0.1, 0.6, 0.3), seed)
val valid = validIntermediate.withColumn(validationCol, lit(true))
val trainAndValid = train.union(valid.orderBy(rand()))

// model1 should overfit on the given dataset
val model1 = baseModel
.setNumLeaves(100)
.setNumIterations(200)
.setIsUnbalance(true)

// model2 should terminate early before overfitting
val model2 = baseModel
.setNumLeaves(100)
.setNumIterations(200)
.setIsUnbalance(true)
.setValidationIndicatorCol(validationCol)
.setEarlyStoppingRound(5)

// Assert evaluation metric improves
Array("auc", "binary_logloss", "binary_error").foreach { metric =>
assertBinaryImprovement(
model1, train, test,
model2.setMetric(metric), trainAndValid, test
)
}
test("Verify LightGBM Classifier with validation dataset") {
val df = au3DF.orderBy(rand()).withColumn(validationCol, lit(false))

val Array(train, validIntermediate, test) = df.randomSplit(Array(0.5, 0.2, 0.3), seed)
val valid = validIntermediate.withColumn(validationCol, lit(true))
val trainAndValid = train.union(valid.orderBy(rand()))

// model1 should overfit on the given dataset
val model1 = baseModel
.setNumLeaves(100)
.setNumIterations(100)
.setLearningRate(0.9)
.setMinDataInLeaf(2)
.setValidationIndicatorCol(validationCol)
.setEarlyStoppingRound(100)

// model2 should terminate early before overfitting
val model2 = baseModel
.setNumLeaves(100)
.setNumIterations(100)
.setLearningRate(0.9)
.setMinDataInLeaf(2)
.setValidationIndicatorCol(validationCol)
.setEarlyStoppingRound(5)

// Assert evaluation metric improves
Array("auc", "binary_logloss", "binary_error").foreach { metric =>
assertBinaryImprovement(
model1.setMetric(metric), trainAndValid, test,
model2.setMetric(metric), trainAndValid, test
)
}
}

Expand Down

0 comments on commit aad223e

Please sign in to comment.