From aad223e045512f5c59249e838cfff2fd5d279e2d Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Mon, 3 May 2021 01:01:40 -0400 Subject: [PATCH] fix: early stopping test and average precision metric (#1034) Co-authored-by: Mark Hamilton --- .../ml/spark/lightgbm/TrainUtils.scala | 3 +- .../split1/VerifyLightGBMClassifier.scala | 61 ++++++++++--------- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala index f4414ab0f6..3c2d835a54 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala @@ -379,7 +379,8 @@ private object TrainUtils extends Serializable { val score = lightgbmlib.doubleArray_getitem(evalResults, index.toLong) log.info(s"Valid $evalName=$score") val cmp = - if (evalName.startsWith("auc") || evalName.startsWith("ndcg@") || evalName.startsWith("map@")) + if (evalName.startsWith("auc") || evalName.startsWith("ndcg@") || evalName.startsWith("map@") || + evalName.startsWith("average_precision")) (x: Double, y: Double, tol: Double) => x - y > tol else (x: Double, y: Double, tol: Double) => x - y < tol diff --git a/src/test/scala/com/microsoft/ml/spark/lightgbm/split1/VerifyLightGBMClassifier.scala b/src/test/scala/com/microsoft/ml/spark/lightgbm/split1/VerifyLightGBMClassifier.scala index 879d6c6fae..0bb33bb5cb 100644 --- a/src/test/scala/com/microsoft/ml/spark/lightgbm/split1/VerifyLightGBMClassifier.scala +++ b/src/test/scala/com/microsoft/ml/spark/lightgbm/split1/VerifyLightGBMClassifier.scala @@ -152,6 +152,7 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM lazy val pimaDF: DataFrame = loadBinary("PimaIndian.csv", "Diabetes mellitus").cache() lazy val taskDF: DataFrame = loadBinary("task.train.csv", "TaskFailed10").cache() lazy val breastTissueDF: DataFrame = loadMulticlass("BreastTissue.csv", "Class").cache() + lazy val au3DF: DataFrame = loadMulticlass("au3_25000.csv", "class").cache() lazy val unfeaturizedBankTrainDF: DataFrame = { val categoricalColumns = Array( "job", "marital", "education", "default", "housing", "loan", "contact", "y") @@ -375,35 +376,37 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM ) } - ignore("Verify LightGBM Classifier with validation dataset") { - tryWithRetries(Array(0, 0, 0, 0)) { () => // TODO fix flakiness - val df = taskDF.orderBy(rand()).withColumn(validationCol, lit(false)) - - val Array(train, validIntermediate, test) = df.randomSplit(Array(0.1, 0.6, 0.3), seed) - val valid = validIntermediate.withColumn(validationCol, lit(true)) - val trainAndValid = train.union(valid.orderBy(rand())) - - // model1 should overfit on the given dataset - val model1 = baseModel - .setNumLeaves(100) - .setNumIterations(200) - .setIsUnbalance(true) - - // model2 should terminate early before overfitting - val model2 = baseModel - .setNumLeaves(100) - .setNumIterations(200) - .setIsUnbalance(true) - .setValidationIndicatorCol(validationCol) - .setEarlyStoppingRound(5) - - // Assert evaluation metric improves - Array("auc", "binary_logloss", "binary_error").foreach { metric => - assertBinaryImprovement( - model1, train, test, - model2.setMetric(metric), trainAndValid, test - ) - } + test("Verify LightGBM Classifier with validation dataset") { + val df = au3DF.orderBy(rand()).withColumn(validationCol, lit(false)) + + val Array(train, validIntermediate, test) = df.randomSplit(Array(0.5, 0.2, 0.3), seed) + val valid = validIntermediate.withColumn(validationCol, lit(true)) + val trainAndValid = train.union(valid.orderBy(rand())) + + // model1 should overfit on the given dataset + val model1 = baseModel + .setNumLeaves(100) + .setNumIterations(100) + .setLearningRate(0.9) + .setMinDataInLeaf(2) + .setValidationIndicatorCol(validationCol) + .setEarlyStoppingRound(100) + + // model2 should terminate early before overfitting + val model2 = baseModel + .setNumLeaves(100) + .setNumIterations(100) + .setLearningRate(0.9) + .setMinDataInLeaf(2) + .setValidationIndicatorCol(validationCol) + .setEarlyStoppingRound(5) + + // Assert evaluation metric improves + Array("auc", "binary_logloss", "binary_error").foreach { metric => + assertBinaryImprovement( + model1.setMetric(metric), trainAndValid, test, + model2.setMetric(metric), trainAndValid, test + ) } }