diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala index f4414ab0f6..3c2d835a54 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala @@ -379,7 +379,8 @@ private object TrainUtils extends Serializable { val score = lightgbmlib.doubleArray_getitem(evalResults, index.toLong) log.info(s"Valid $evalName=$score") val cmp = - if (evalName.startsWith("auc") || evalName.startsWith("ndcg@") || evalName.startsWith("map@")) + if (evalName.startsWith("auc") || evalName.startsWith("ndcg@") || evalName.startsWith("map@") || + evalName.startsWith("average_precision")) (x: Double, y: Double, tol: Double) => x - y > tol else (x: Double, y: Double, tol: Double) => x - y < tol diff --git a/src/test/scala/com/microsoft/ml/spark/lightgbm/split1/VerifyLightGBMClassifier.scala b/src/test/scala/com/microsoft/ml/spark/lightgbm/split1/VerifyLightGBMClassifier.scala index 879d6c6fae..0bb33bb5cb 100644 --- a/src/test/scala/com/microsoft/ml/spark/lightgbm/split1/VerifyLightGBMClassifier.scala +++ b/src/test/scala/com/microsoft/ml/spark/lightgbm/split1/VerifyLightGBMClassifier.scala @@ -152,6 +152,7 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM lazy val pimaDF: DataFrame = loadBinary("PimaIndian.csv", "Diabetes mellitus").cache() lazy val taskDF: DataFrame = loadBinary("task.train.csv", "TaskFailed10").cache() lazy val breastTissueDF: DataFrame = loadMulticlass("BreastTissue.csv", "Class").cache() + lazy val au3DF: DataFrame = loadMulticlass("au3_25000.csv", "class").cache() lazy val unfeaturizedBankTrainDF: DataFrame = { val categoricalColumns = Array( "job", "marital", "education", "default", "housing", "loan", "contact", "y") @@ -375,35 +376,37 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM ) } - ignore("Verify LightGBM Classifier with validation dataset") { - tryWithRetries(Array(0, 0, 0, 0)) { () => // TODO fix flakiness - val df = taskDF.orderBy(rand()).withColumn(validationCol, lit(false)) - - val Array(train, validIntermediate, test) = df.randomSplit(Array(0.1, 0.6, 0.3), seed) - val valid = validIntermediate.withColumn(validationCol, lit(true)) - val trainAndValid = train.union(valid.orderBy(rand())) - - // model1 should overfit on the given dataset - val model1 = baseModel - .setNumLeaves(100) - .setNumIterations(200) - .setIsUnbalance(true) - - // model2 should terminate early before overfitting - val model2 = baseModel - .setNumLeaves(100) - .setNumIterations(200) - .setIsUnbalance(true) - .setValidationIndicatorCol(validationCol) - .setEarlyStoppingRound(5) - - // Assert evaluation metric improves - Array("auc", "binary_logloss", "binary_error").foreach { metric => - assertBinaryImprovement( - model1, train, test, - model2.setMetric(metric), trainAndValid, test - ) - } + test("Verify LightGBM Classifier with validation dataset") { + val df = au3DF.orderBy(rand()).withColumn(validationCol, lit(false)) + + val Array(train, validIntermediate, test) = df.randomSplit(Array(0.5, 0.2, 0.3), seed) + val valid = validIntermediate.withColumn(validationCol, lit(true)) + val trainAndValid = train.union(valid.orderBy(rand())) + + // model1 should overfit on the given dataset + val model1 = baseModel + .setNumLeaves(100) + .setNumIterations(100) + .setLearningRate(0.9) + .setMinDataInLeaf(2) + .setValidationIndicatorCol(validationCol) + .setEarlyStoppingRound(100) + + // model2 should terminate early before overfitting + val model2 = baseModel + .setNumLeaves(100) + .setNumIterations(100) + .setLearningRate(0.9) + .setMinDataInLeaf(2) + .setValidationIndicatorCol(validationCol) + .setEarlyStoppingRound(5) + + // Assert evaluation metric improves + Array("auc", "binary_logloss", "binary_error").foreach { metric => + assertBinaryImprovement( + model1.setMetric(metric), trainAndValid, test, + model2.setMetric(metric), trainAndValid, test + ) } }