diff --git a/src/main/scala/com/microsoft/lightgbm/SWIG.scala b/src/main/scala/com/microsoft/lightgbm/SWIG.scala new file mode 100644 index 00000000000..efeb042dfa3 --- /dev/null +++ b/src/main/scala/com/microsoft/lightgbm/SWIG.scala @@ -0,0 +1,10 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.lightgbm + +import com.microsoft.ml.lightgbm.SWIGTYPE_p_void + +class SwigPtrWrapper(val value: SWIGTYPE_p_void) extends SWIGTYPE_p_void { + def getCPtrValue(): Long = SWIGTYPE_p_void.getCPtr(value) +} diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBase.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBase.scala index 8f55c954fa9..2bdcc75e962 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBase.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBase.scala @@ -4,6 +4,9 @@ package com.microsoft.ml.spark.lightgbm import com.microsoft.ml.spark.core.utils.ClusterUtil +import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster +import com.microsoft.ml.spark.lightgbm.params.{DartModeParams, ExecutionParams, LightGBMParams, + ObjectiveParams, TrainParams} import com.microsoft.ml.spark.logging.BasicLogging import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.SQLDataTypes.VectorType @@ -180,14 +183,30 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine } } + /** + * Constructs the DartModeParams + * @return DartModeParams object containing parameters related to dart mode. + */ protected def getDartParams(): DartModeParams = { DartModeParams(getDropRate, getMaxDrop, getSkipDrop, getXGBoostDartMode, getUniformDrop) } + /** + * Constructs the ExecutionParams. + * @return ExecutionParams object containing parameters related to LightGBM execution. + */ protected def getExecutionParams(): ExecutionParams = { ExecutionParams(getChunkSize, getMatrixType) } + /** + * Constructs the ObjectiveParams. + * @return ObjectiveParams object containing parameters related to the objective function. + */ + protected def getObjectiveParams(): ObjectiveParams = { + ObjectiveParams(getObjective, if (isDefined(fobj)) Some(getFObj) else None) + } + /** * Inner train method for LightGBM learners. Calculates the number of workers, * creates a driver thread, and runs mapPartitions on the dataset. diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMClassifier.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMClassifier.scala index c87daab7fff..65290aafd1b 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMClassifier.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMClassifier.scala @@ -3,6 +3,9 @@ package com.microsoft.ml.spark.lightgbm +import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster +import com.microsoft.ml.spark.lightgbm.params.{ClassifierTrainParams, LightGBMModelParams, + LightGBMPredictionParams, TrainParams} import com.microsoft.ml.spark.logging.BasicLogging import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable} import org.apache.spark.ml.param._ @@ -46,11 +49,11 @@ class LightGBMClassifier(override val uid: String) ClassifierTrainParams(getParallelism, getTopK, getNumIterations, getLearningRate, getNumLeaves, getMaxBin, getBinSampleCount, getBaggingFraction, getPosBaggingFraction, getNegBaggingFraction, getBaggingFreq, getBaggingSeed, getEarlyStoppingRound, getImprovementTolerance, - getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numTasks, getObjective, modelStr, + getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numTasks, modelStr, getIsUnbalance, getVerbosity, categoricalIndexes, actualNumClasses, getBoostFromAverage, getBoostingType, getLambdaL1, getLambdaL2, getIsProvideTrainingMetric, getMetric, getMinGainToSplit, getMaxDeltaStep, getMaxBinByFeature, getMinDataInLeaf, getSlotNames, - getDelegate, getDartParams(), getExecutionParams()) + getDelegate, getDartParams(), getExecutionParams(), getObjectiveParams()) } def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMClassificationModel = { @@ -187,7 +190,7 @@ class LightGBMClassificationModel(override val uid: String) object LightGBMClassificationModel extends ComplexParamsReadable[LightGBMClassificationModel] { def loadNativeModelFromFile(filename: String): LightGBMClassificationModel = { - val uid = Identifiable.randomUID("LightGBMClassifier") + val uid = Identifiable.randomUID("LightGBMClassificationModel") val session = SparkSession.builder().getOrCreate() val textRdd = session.read.text(filename) val text = textRdd.collect().map { row => row.getString(0) }.mkString("\n") @@ -197,7 +200,7 @@ object LightGBMClassificationModel extends ComplexParamsReadable[LightGBMClassif } def loadNativeModelFromString(model: String): LightGBMClassificationModel = { - val uid = Identifiable.randomUID("LightGBMClassifier") + val uid = Identifiable.randomUID("LightGBMClassificationModel") val lightGBMBooster = new LightGBMBooster(model) val actualNumClasses = lightGBMBooster.numClasses new LightGBMClassificationModel(uid).setLightGBMBooster(lightGBMBooster).setActualNumClasses(actualNumClasses) diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMDelegate.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMDelegate.scala index 640c4fb5af0..956de2c348f 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMDelegate.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMDelegate.scala @@ -3,7 +3,8 @@ package com.microsoft.ml.spark.lightgbm -import com.microsoft.ml.lightgbm.SWIGTYPE_p_void +import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster +import com.microsoft.ml.spark.lightgbm.params.TrainParams import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType import org.slf4j.Logger @@ -40,12 +41,12 @@ trait LightGBMDelegate extends Serializable { } def beforeTrainIteration(batchIndex: Int, partitionId: Int, curIters: Int, log: Logger, - trainParams: TrainParams, boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean): Unit = { + trainParams: TrainParams, booster: LightGBMBooster, hasValid: Boolean): Unit = { // override this function and write code } def afterTrainIteration(batchIndex: Int, partitionId: Int, curIters: Int, log: Logger, - trainParams: TrainParams, boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean, + trainParams: TrainParams, booster: LightGBMBooster, hasValid: Boolean, isFinished: Boolean, trainEvalResults: Option[Map[String, Double]], validEvalResults: Option[Map[String, Double]]): Unit = { diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMModelMethods.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMModelMethods.scala index e4d9e0e9a39..d65cfb8b771 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMModelMethods.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMModelMethods.scala @@ -3,6 +3,7 @@ package com.microsoft.ml.spark.lightgbm +import com.microsoft.ml.spark.lightgbm.params.LightGBMModelParams import org.apache.spark.internal.Logging import org.apache.spark.ml.linalg.{Vector, Vectors} diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMRanker.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMRanker.scala index f020b2a357f..af1cc3bc906 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMRanker.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMRanker.scala @@ -3,6 +3,9 @@ package com.microsoft.ml.spark.lightgbm +import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster +import com.microsoft.ml.spark.lightgbm.params.{LightGBMModelParams, LightGBMPredictionParams, + RankerTrainParams, TrainParams} import com.microsoft.ml.spark.logging.BasicLogging import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Ranker, RankerModel} import org.apache.spark.ml.param._ @@ -51,12 +54,13 @@ class LightGBMRanker(override val uid: String) def getTrainParams(numTasks: Int, categoricalIndexes: Array[Int], dataset: Dataset[_]): TrainParams = { val modelStr = if (getModelString == null || getModelString.isEmpty) None else get(modelString) RankerTrainParams(getParallelism, getTopK, getNumIterations, getLearningRate, getNumLeaves, - getObjective, getMaxBin, getBinSampleCount, getBaggingFraction, getPosBaggingFraction, getNegBaggingFraction, + getMaxBin, getBinSampleCount, getBaggingFraction, getPosBaggingFraction, getNegBaggingFraction, getBaggingFreq, getBaggingSeed, getEarlyStoppingRound, getImprovementTolerance, getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numTasks, modelStr, getVerbosity, categoricalIndexes, getBoostingType, getLambdaL1, getLambdaL2, getMaxPosition, getLabelGain, getIsProvideTrainingMetric, getMetric, getEvalAt, getMinGainToSplit, getMaxDeltaStep, - getMaxBinByFeature, getMinDataInLeaf, getSlotNames, getDelegate, getDartParams(), getExecutionParams()) + getMaxBinByFeature, getMinDataInLeaf, getSlotNames, getDelegate, getDartParams(), getExecutionParams(), + getObjectiveParams()) } def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRankerModel = { @@ -157,7 +161,7 @@ class LightGBMRankerModel(override val uid: String) object LightGBMRankerModel extends ComplexParamsReadable[LightGBMRankerModel] { def loadNativeModelFromFile(filename: String): LightGBMRankerModel = { - val uid = Identifiable.randomUID("LightGBMRanker") + val uid = Identifiable.randomUID("LightGBMRankerModel") val session = SparkSession.builder().getOrCreate() val textRdd = session.read.text(filename) val text = textRdd.collect().map { row => row.getString(0) }.mkString("\n") @@ -166,7 +170,7 @@ object LightGBMRankerModel extends ComplexParamsReadable[LightGBMRankerModel] { } def loadNativeModelFromString(model: String): LightGBMRankerModel = { - val uid = Identifiable.randomUID("LightGBMRanker") + val uid = Identifiable.randomUID("LightGBMRankerModel") val lightGBMBooster = new LightGBMBooster(model) new LightGBMRankerModel(uid).setLightGBMBooster(lightGBMBooster) } diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMRegressor.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMRegressor.scala index 03813add450..f240c1f26d8 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMRegressor.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMRegressor.scala @@ -3,6 +3,9 @@ package com.microsoft.ml.spark.lightgbm +import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster +import com.microsoft.ml.spark.lightgbm.params.{LightGBMModelParams, LightGBMPredictionParams, + RegressorTrainParams, TrainParams} import com.microsoft.ml.spark.logging.BasicLogging import org.apache.spark.ml.{BaseRegressor, ComplexParamsReadable, ComplexParamsWritable} import org.apache.spark.ml.param._ @@ -58,12 +61,12 @@ class LightGBMRegressor(override val uid: String) def getTrainParams(numTasks: Int, categoricalIndexes: Array[Int], dataset: Dataset[_]): TrainParams = { val modelStr = if (getModelString == null || getModelString.isEmpty) None else get(modelString) RegressorTrainParams(getParallelism, getTopK, getNumIterations, getLearningRate, getNumLeaves, - getObjective, getAlpha, getTweedieVariancePower, getMaxBin, getBinSampleCount, - getBaggingFraction, getPosBaggingFraction, getNegBaggingFraction, getBaggingFreq, getBaggingSeed, - getEarlyStoppingRound, getImprovementTolerance, getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, - numTasks, modelStr, getVerbosity, categoricalIndexes, getBoostFromAverage, getBoostingType, getLambdaL1, - getLambdaL2, getIsProvideTrainingMetric, getMetric, getMinGainToSplit, getMaxDeltaStep, - getMaxBinByFeature, getMinDataInLeaf, getSlotNames, getDelegate, getDartParams(), getExecutionParams()) + getAlpha, getTweedieVariancePower, getMaxBin, getBinSampleCount, getBaggingFraction, getPosBaggingFraction, + getNegBaggingFraction, getBaggingFreq, getBaggingSeed, getEarlyStoppingRound, getImprovementTolerance, + getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numTasks, modelStr, getVerbosity, categoricalIndexes, + getBoostFromAverage, getBoostingType, getLambdaL1, getLambdaL2, getIsProvideTrainingMetric, getMetric, + getMinGainToSplit, getMaxDeltaStep, getMaxBinByFeature, getMinDataInLeaf, getSlotNames, getDelegate, + getDartParams(), getExecutionParams(), getObjectiveParams()) } def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRegressionModel = { @@ -134,7 +137,7 @@ class LightGBMRegressionModel(override val uid: String) object LightGBMRegressionModel extends ComplexParamsReadable[LightGBMRegressionModel] { def loadNativeModelFromFile(filename: String): LightGBMRegressionModel = { - val uid = Identifiable.randomUID("LightGBMRegressor") + val uid = Identifiable.randomUID("LightGBMRegressionModel") val session = SparkSession.builder().getOrCreate() val textRdd = session.read.text(filename) val text = textRdd.collect().map { row => row.getString(0) }.mkString("\n") @@ -143,7 +146,7 @@ object LightGBMRegressionModel extends ComplexParamsReadable[LightGBMRegressionM } def loadNativeModelFromString(model: String): LightGBMRegressionModel = { - val uid = Identifiable.randomUID("LightGBMRegressor") + val uid = Identifiable.randomUID("LightGBMRegressionModel") val lightGBMBooster = new LightGBMBooster(model) new LightGBMRegressionModel(uid).setLightGBMBooster(lightGBMBooster) } diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMUtils.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMUtils.scala index 79a6f3b8022..188c31813cf 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMUtils.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMUtils.scala @@ -11,6 +11,8 @@ import com.microsoft.ml.lightgbm._ import com.microsoft.ml.spark.core.env.NativeLoader import com.microsoft.ml.spark.core.utils.ClusterUtil import com.microsoft.ml.spark.featurize.{Featurize, FeaturizeUtilities} +import com.microsoft.ml.spark.lightgbm.dataset.LightGBMDataset +import com.microsoft.ml.spark.lightgbm.params.TrainParams import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.ml.PipelineModel import org.apache.spark.ml.attribute._ @@ -245,7 +247,7 @@ object LightGBMUtils { LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromMats(featuresArray.get_chunks_count().toInt, featuresArray.data_as_void(), data64bitType, numRowsForChunks, numCols, - isRowMajor, datasetParams, referenceDataset.map(_.dataset).orNull, datasetOutPtr), + isRowMajor, datasetParams, referenceDataset.map(_.datasetPtr).orNull, datasetOutPtr), "Dataset create") } finally { featuresArray.release() @@ -275,7 +277,7 @@ object LightGBMUtils { LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromCSRSpark( sparseRows.asInstanceOf[Array[Object]], sparseRows.length, - numCols, datasetParams, referenceDataset.map(_.dataset).orNull, + numCols, datasetParams, referenceDataset.map(_.datasetPtr).orNull, datasetOutPtr), "Dataset create") val dataset = new LightGBMDataset(lightgbmlib.voidpp_value(datasetOutPtr)) diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala index 9540acee32b..9dea854d6b6 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala @@ -9,6 +9,10 @@ import java.net._ import com.microsoft.ml.lightgbm._ import com.microsoft.ml.spark.core.env.StreamUtilities._ import com.microsoft.ml.spark.downloader.FaultToleranceUtils +import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster +import com.microsoft.ml.spark.lightgbm.dataset.LightGBMDataset +import com.microsoft.ml.spark.lightgbm.params.{ClassifierTrainParams, TrainParams} +import com.microsoft.ml.spark.lightgbm.swig.SwigUtils import org.apache.spark.{BarrierTaskContext, TaskContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.attribute._ @@ -277,56 +281,35 @@ private object TrainUtils extends Serializable { groupCardinality } - def createBooster(trainParams: TrainParams, trainDatasetPtr: Option[LightGBMDataset], - validDatasetPtr: Option[LightGBMDataset]): Option[SWIGTYPE_p_void] = { + def createBooster(trainParams: TrainParams, trainDatasetPtr: LightGBMDataset, + validDatasetPtr: Option[LightGBMDataset]): LightGBMBooster = { // Create the booster - val boosterOutPtr = lightgbmlib.voidpp_handle() val parameters = trainParams.toString() - LightGBMUtils.validate(lightgbmlib.LGBM_BoosterCreate(trainDatasetPtr.map(_.dataset).get, - parameters, boosterOutPtr), "Booster") - val boosterPtr = Some(lightgbmlib.voidpp_value(boosterOutPtr)) + val booster = new LightGBMBooster(trainDatasetPtr, parameters) trainParams.modelString.foreach { modelStr => - val booster = LightGBMUtils.getBoosterPtrFromModelString(modelStr) - LightGBMUtils.validate(lightgbmlib.LGBM_BoosterMerge(boosterPtr.get, booster), "Booster Merge") + booster.mergeBooster(modelStr) } validDatasetPtr.foreach { lgbmdataset => - LightGBMUtils.validate(lightgbmlib.LGBM_BoosterAddValidData(boosterPtr.get, - lgbmdataset.dataset), "Add Validation Dataset") + booster.addValidationDataset(lgbmdataset) } - boosterPtr - } - - def saveBoosterToString(boosterPtr: Option[SWIGTYPE_p_void], log: Logger): String = { - val bufferLength = LightGBMConstants.DefaultBufferLength - val bufferOutLengthPtr = lightgbmlib.new_int64_tp() - lightgbmlib.LGBM_BoosterSaveModelToStringSWIG(boosterPtr.get, 0, -1, 0, bufferLength, bufferOutLengthPtr) - } - - def getEvalNames(boosterPtr: Option[SWIGTYPE_p_void]): Array[String] = { - // Need to keep track of best scores for each metric, see callback.py in lightgbm for reference - // For debugging, can get metric names - val stringArrayHandle = lightgbmlib.LGBM_BoosterGetEvalNamesSWIG(boosterPtr.get) - LightGBMUtils.validateArray(stringArrayHandle, "Booster Get Eval Names") - val evalNames = lightgbmlib.StringArrayHandle_get_strings(stringArrayHandle) - lightgbmlib.StringArrayHandle_free(stringArrayHandle) - evalNames + booster } def beforeTrainIteration(batchIndex: Int, partitionId: Int, curIters: Int, log: Logger, - trainParams: TrainParams, boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean): Unit = { + trainParams: TrainParams, booster: LightGBMBooster, hasValid: Boolean): Unit = { if (trainParams.delegate.isDefined) { - trainParams.delegate.get.beforeTrainIteration(batchIndex, partitionId, curIters, log, trainParams, boosterPtr, + trainParams.delegate.get.beforeTrainIteration(batchIndex, partitionId, curIters, log, trainParams, booster, hasValid) } } def afterTrainIteration(batchIndex: Int, partitionId: Int, curIters: Int, log: Logger, - trainParams: TrainParams, boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean, + trainParams: TrainParams, booster: LightGBMBooster, hasValid: Boolean, isFinished: Boolean, trainEvalResults: Option[Map[String, Double]], validEvalResults: Option[Map[String, Double]]): Unit = { if (trainParams.delegate.isDefined) { - trainParams.delegate.get.afterTrainIteration(batchIndex, partitionId, curIters, log, trainParams, boosterPtr, + trainParams.delegate.get.afterTrainIteration(batchIndex, partitionId, curIters, log, trainParams, booster, hasValid, isFinished, trainEvalResults, validEvalResults) } } @@ -340,12 +323,44 @@ private object TrainUtils extends Serializable { } } - def trainCore(batchIndex: Int, trainParams: TrainParams, boosterPtr: Option[SWIGTYPE_p_void], - log: Logger, hasValid: Boolean): Option[Int] = { + def updateOneIteration(trainParams: TrainParams, + booster: LightGBMBooster, + log: Logger, + iters: Int): Boolean = { + var isFinished = false val isFinishedPtr = lightgbmlib.new_intp() + try { + val result = + if (trainParams.objectiveParams.fobj.isDefined) { + val classification = trainParams.isInstanceOf[ClassifierTrainParams] + val gradient = trainParams.objectiveParams.fobj.get.getGradient( + booster.innerPredict(0, classification), booster.trainDataset.get) + val grad = SwigUtils.floatArrayToNative(gradient(0)) + val hess = SwigUtils.floatArrayToNative(gradient(1)) + lightgbmlib.LGBM_BoosterUpdateOneIterCustom(booster.boosterHandler.boosterPtr, grad, hess, isFinishedPtr) + } else { + lightgbmlib.LGBM_BoosterUpdateOneIter(booster.boosterHandler.boosterPtr, isFinishedPtr) + } + LightGBMUtils.validate(result, "Booster Update One Iter") + isFinished = lightgbmlib.intp_value(isFinishedPtr) == 1 + log.info("LightGBM running iteration: " + iters + " with is finished: " + isFinished) + } catch { + case e: java.lang.Exception => + log.warn("LightGBM reached early termination on one task," + + " stopping training on task. This message should rarely occur." + + " Inner exception: " + e.toString) + isFinished = true + } finally { + lightgbmlib.delete_intp(isFinishedPtr) + } + isFinished + } + + def trainCore(batchIndex: Int, trainParams: TrainParams, booster: LightGBMBooster, + log: Logger, hasValid: Boolean): Option[Int] = { var isFinished = false var iters = 0 - val evalNames = getEvalNames(boosterPtr) + val evalNames = booster.getEvalNames() val evalCounts = evalNames.length val bestScore = new Array[Double](evalCounts) val bestScores = new Array[Array[Double]](evalCounts) @@ -354,85 +369,55 @@ private object TrainUtils extends Serializable { var learningRate: Double = trainParams.learningRate var bestIterResult: Option[Int] = None while (!isFinished && iters < trainParams.numIterations) { - beforeTrainIteration(batchIndex, partitionId, iters, log, trainParams, boosterPtr, hasValid) + beforeTrainIteration(batchIndex, partitionId, iters, log, trainParams, booster, hasValid) val newLearningRate = getLearningRate(batchIndex, partitionId, iters, log, trainParams, learningRate) if (newLearningRate != learningRate) { log.info(s"LightGBM task calling LGBM_BoosterResetParameter to reset learningRate" + s" (newLearningRate: $newLearningRate)") - LightGBMUtils.validate(lightgbmlib.LGBM_BoosterResetParameter(boosterPtr.get, - s"learning_rate=$newLearningRate"), "Booster Reset learning_rate Param") + booster.resetParameter(s"learning_rate=$newLearningRate") learningRate = newLearningRate } - try { - val result = lightgbmlib.LGBM_BoosterUpdateOneIter(boosterPtr.get, isFinishedPtr) - LightGBMUtils.validate(result, "Booster Update One Iter") - isFinished = lightgbmlib.intp_value(isFinishedPtr) == 1 - log.info("LightGBM running iteration: " + iters + " with result: " + - result + " and is finished: " + isFinished) - } catch { - case _: java.lang.Exception => - isFinished = true - log.warn("LightGBM reached early termination on one task," + - " stopping training on task. This message should rarely occur") - } + isFinished = updateOneIteration(trainParams, booster, log, iters) val trainEvalResults: Option[Map[String, Double]] = if (trainParams.isProvideTrainingMetric && !isFinished) { - val trainResults = lightgbmlib.new_doubleArray(evalNames.length.toLong) - val dummyEvalCountsPtr = lightgbmlib.new_intp() - val resultEval = lightgbmlib.LGBM_BoosterGetEval(boosterPtr.get, 0, dummyEvalCountsPtr, trainResults) - lightgbmlib.delete_intp(dummyEvalCountsPtr) - LightGBMUtils.validate(resultEval, "Booster Get Train Eval") - - val results: Array[(String, Double)] = evalNames.zipWithIndex.map { case (evalName, index) => - val score = lightgbmlib.doubleArray_getitem(trainResults, index.toLong) - log.info(s"Train $evalName=$score") - (evalName, score) - } - - Option(Map(results:_*)) + val evalResults: Array[(String, Double)] = booster.getEvalResults(evalNames, 0) + evalResults.foreach { case (evalName: String, score: Double) => log.info(s"Train $evalName=$score") } + Option(Map(evalResults:_*)) } else { None } val validEvalResults: Option[Map[String, Double]] = if (hasValid && !isFinished) { - val evalResults = lightgbmlib.new_doubleArray(evalNames.length.toLong) - val dummyEvalCountsPtr = lightgbmlib.new_intp() - val resultEval = lightgbmlib.LGBM_BoosterGetEval(boosterPtr.get, 1, dummyEvalCountsPtr, evalResults) - lightgbmlib.delete_intp(dummyEvalCountsPtr) - LightGBMUtils.validate(resultEval, "Booster Get Valid Eval") - val results: Array[(String, Double)] = evalNames.zipWithIndex.map { case (evalName, index) => - val score = lightgbmlib.doubleArray_getitem(evalResults, index.toLong) - log.info(s"Valid $evalName=$score") + val evalResults: Array[(String, Double)] = booster.getEvalResults(evalNames, 1) + val results: Array[(String, Double)] = evalResults.zipWithIndex.map { case ((evalName, evalScore), index) => + log.info(s"Valid $evalName=$evalScore") val cmp = if (evalName.startsWith("auc") || evalName.startsWith("ndcg@") || evalName.startsWith("map@") || evalName.startsWith("average_precision")) (x: Double, y: Double, tol: Double) => x - y > tol else (x: Double, y: Double, tol: Double) => x - y < tol - if (bestScores(index) == null || cmp(score, bestScore(index), trainParams.improvementTolerance)) { - bestScore(index) = score + if (bestScores(index) == null || cmp(evalScore, bestScore(index), trainParams.improvementTolerance)) { + bestScore(index) = evalScore bestIter(index) = iters - bestScores(index) = evalNames.indices - .map(j => lightgbmlib.doubleArray_getitem(evalResults, j.toLong)).toArray + bestScores(index) = evalResults.map(_._2) } else if (iters - bestIter(index) >= trainParams.earlyStoppingRound) { isFinished = true log.info("Early stopping, best iteration is " + bestIter(index)) bestIterResult = Some(bestIter(index)) } - (evalName, score) + (evalName, evalScore) } - lightgbmlib.delete_doubleArray(evalResults) - Option(Map(results:_*)) } else { None } - afterTrainIteration(batchIndex, partitionId, iters, log, trainParams, boosterPtr, hasValid, isFinished, + afterTrainIteration(batchIndex, partitionId, iters, log, trainParams, booster, hasValid, isFinished, trainEvalResults, validEvalResults) iters = iters + 1 @@ -495,43 +480,42 @@ private object TrainUtils extends Serializable { def translate(batchIndex: Int, columnParams: ColumnParams, validationData: Option[Broadcast[Array[Row]]], log: Logger, trainParams: TrainParams, returnBooster: Boolean, schema: StructType, inputRows: Iterator[Row]): Iterator[LightGBMBooster] = { - var trainDatasetPtr: Option[LightGBMDataset] = None - var validDatasetPtr: Option[LightGBMDataset] = None + var trainDatasetOpt: Option[LightGBMDataset] = None + var validDatasetOpt: Option[LightGBMDataset] = None try { beforeGenerateTrainDataset(batchIndex, columnParams, schema, log, trainParams) - trainDatasetPtr = generateDataset(inputRows, columnParams, None, schema, log, trainParams) + trainDatasetOpt = generateDataset(inputRows, columnParams, None, schema, log, trainParams) afterGenerateTrainDataset(batchIndex, columnParams, schema, log, trainParams) if (validationData.isDefined) { beforeGenerateValidDataset(batchIndex, columnParams, schema, log, trainParams) - validDatasetPtr = generateDataset(validationData.get.value.toIterator, columnParams, - trainDatasetPtr, schema, log, trainParams) + validDatasetOpt = generateDataset(validationData.get.value.toIterator, columnParams, + trainDatasetOpt, schema, log, trainParams) afterGenerateValidDataset(batchIndex, columnParams, schema, log, trainParams) } - var boosterPtr: Option[SWIGTYPE_p_void] = None + var boosterOpt: Option[LightGBMBooster] = None try { - boosterPtr = createBooster(trainParams, trainDatasetPtr, validDatasetPtr) - val bestIterResult = trainCore(batchIndex, trainParams, boosterPtr, log, validDatasetPtr.isDefined) + val booster = createBooster(trainParams, trainDatasetOpt.get, validDatasetOpt) + boosterOpt = Some(booster) + val bestIterResult = trainCore(batchIndex, trainParams, booster, log, validDatasetOpt.isDefined) if (returnBooster) { - val model = saveBoosterToString(boosterPtr, log) - val booster = new LightGBMBooster(model) + val model = booster.saveToString() + val modelBooster = new LightGBMBooster(model) // Set best iteration on booster if hit early stopping criteria in trainCore - bestIterResult.foreach(booster.setBestIteration(_)) - Iterator.single(booster) + bestIterResult.foreach(modelBooster.setBestIteration(_)) + Iterator.single(modelBooster) } else { Iterator.empty } } finally { // Free booster - boosterPtr.foreach { booster => - LightGBMUtils.validate(lightgbmlib.LGBM_BoosterFree(booster), "Finalize Booster") - } + boosterOpt.foreach(_.freeNativeMemory()) } } finally { // Free datasets - trainDatasetPtr.foreach(_.close()) - validDatasetPtr.foreach(_.close()) + trainDatasetOpt.foreach(_.close()) + validDatasetOpt.foreach(_.close()) } } diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBooster.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/booster/LightGBMBooster.scala similarity index 67% rename from src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBooster.scala rename to src/main/scala/com/microsoft/ml/spark/lightgbm/booster/LightGBMBooster.scala index e9597f6bbc7..e92ee99379b 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBooster.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/booster/LightGBMBooster.scala @@ -1,12 +1,15 @@ // Copyright (C) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. See LICENSE in project root for information. -package com.microsoft.ml.spark.lightgbm +package com.microsoft.ml.spark.lightgbm.booster import com.microsoft.ml.lightgbm._ +import com.microsoft.ml.spark.lightgbm.{LightGBMConstants, LightGBMUtils} import com.microsoft.ml.spark.lightgbm.LightGBMUtils.getBoosterPtrFromModelString +import com.microsoft.ml.spark.lightgbm.dataset.LightGBMDataset import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.sql.{SaveMode, SparkSession} +import org.slf4j.Logger //scalastyle:off protected abstract class NativePtrHandler[T](val ptr: T) { @@ -32,14 +35,19 @@ protected class LongLongNativePtrHandler(ptr: SWIGTYPE_p_long_long) extends Nati /** Wraps the boosterPtr and guarantees that Native library is initialized * everytime it is needed - * @param model The string serialized representation of the learner + * @param boosterPtr The pointer to the native lightgbm booster */ -protected class BoosterHandler(model: String) { - LightGBMUtils.initializeNativeLibrary() +protected class BoosterHandler(var boosterPtr: SWIGTYPE_p_void) { - var boosterPtr: SWIGTYPE_p_void = { - getBoosterPtrFromModelString(model) + /** Wraps the boosterPtr and guarantees that Native library is initialized + * everytime it is needed + * + * @param model The string serialized representation of the learner + */ + def this(model: String) = { + this(getBoosterPtrFromModelString(model)) } + LightGBMUtils.initializeNativeLibrary() val scoredDataOutPtr: ThreadLocal[DoubleNativePtrHandler] = { new ThreadLocal[DoubleNativePtrHandler] { @@ -127,6 +135,13 @@ protected class BoosterHandler(model: String) { lazy val dataInt32bitType: Int = lightgbmlibConstants.C_API_DTYPE_INT32 lazy val data64bitType: Int = lightgbmlibConstants.C_API_DTYPE_FLOAT64 + def freeNativeMemory(): Unit = { + if (boosterPtr != null) { + LightGBMUtils.validate(lightgbmlib.LGBM_BoosterFree(boosterPtr), "Finalize Booster") + boosterPtr = null + } + } + private def getNumClasses: Int = { val numClassesOut = lightgbmlib.new_intp() LightGBMUtils.validate( @@ -167,13 +182,6 @@ protected class BoosterHandler(model: String) { out } - private def freeNativeMemory(): Unit = { - if (boosterPtr != null) { - LightGBMUtils.validate(lightgbmlib.LGBM_BoosterFree(boosterPtr), "Finalize Booster") - boosterPtr = null - } - } - override protected def finalize(): Unit = { freeNativeMemory() super.finalize() @@ -181,21 +189,134 @@ protected class BoosterHandler(model: String) { } /** Represents a LightGBM Booster learner - * @param model The string serialized representation of the learner + * @param trainDataset The training dataset + * @param parameters The booster initialization parameters + * @param modelStr Optional parameter with the string serialized representation of the learner */ @SerialVersionUID(777L) -class LightGBMBooster(val model: String) extends Serializable { - /** Transient variable containing local machine's pointer to native booster +class LightGBMBooster(val trainDataset: Option[LightGBMDataset] = None, val parameters: Option[String] = None, + modelStr: Option[String] = None) extends Serializable { + + /** Represents a LightGBM Booster learner + * @param trainDataset The training dataset + * @param parameters The booster initialization parameters + */ + def this(trainDataset: LightGBMDataset, parameters: String) = { + this(Some(trainDataset), Some(parameters)) + } + + /** Represents a LightGBM Booster learner + * @param model The string serialized representation of the learner */ + def this(model: String) = { + this(modelStr = Some(model)) + } + @transient lazy val boosterHandler: BoosterHandler = { - new BoosterHandler(model) + val boosterOutPtr = lightgbmlib.voidpp_handle() + if (trainDataset.isEmpty && model.isEmpty) { + throw new IllegalArgumentException("One of training dataset or serialized model parameters must be specified") + } + if (trainDataset.isEmpty) { + new BoosterHandler(model) + } else { + LightGBMUtils.validate(lightgbmlib.LGBM_BoosterCreate(trainDataset.map(_.datasetPtr).get, + parameters.get, boosterOutPtr), "Booster") + new BoosterHandler(lightgbmlib.voidpp_value(boosterOutPtr)) + } } + val model = modelStr.getOrElse("") var bestIteration: Int = -1 private var startIteration: Int = 0 private var numIterations: Int = -1 + /** Merges this Booster with the specified model. + * @param model The string serialized representation of the learner to merge. + */ + def mergeBooster(model: String): Unit = { + val boosterPtr = LightGBMUtils.getBoosterPtrFromModelString(model) + LightGBMUtils.validate(lightgbmlib.LGBM_BoosterMerge(boosterHandler.boosterPtr, boosterPtr), + "Booster Merge") + } + + /** Adds the specified LightGBMDataset to be the validation dataset. + * @param dataset The LightGBMDataset to add as the validation dataset. + */ + def addValidationDataset(dataset: LightGBMDataset): Unit = { + LightGBMUtils.validate(lightgbmlib.LGBM_BoosterAddValidData(boosterHandler.boosterPtr, + dataset.datasetPtr), "Add Validation Dataset") + } + + /** Saves the booster to string representation. + * @return The serialized string representation of the Booster. + */ + def saveToString(): String = { + val bufferLength = LightGBMConstants.DefaultBufferLength + val bufferOutLengthPtr = lightgbmlib.new_int64_tp() + lightgbmlib.LGBM_BoosterSaveModelToStringSWIG(boosterHandler.boosterPtr, + 0, -1, 0, bufferLength, bufferOutLengthPtr) + } + + /** Get the evaluation dataset column names from the native booster. + * @return The evaluation dataset column names. + */ + def getEvalNames(): Array[String] = { + // Need to keep track of best scores for each metric, see callback.py in lightgbm for reference + // For debugging, can get metric names + val stringArrayHandle = lightgbmlib.LGBM_BoosterGetEvalNamesSWIG(boosterHandler.boosterPtr) + LightGBMUtils.validateArray(stringArrayHandle, "Booster Get Eval Names") + val evalNames = lightgbmlib.StringArrayHandle_get_strings(stringArrayHandle) + lightgbmlib.StringArrayHandle_free(stringArrayHandle) + evalNames + } + + def getEvalResults(evalNames: Array[String], dataIndex: Int): Array[(String, Double)] = { + val evalResults = lightgbmlib.new_doubleArray(evalNames.length.toLong) + val dummyEvalCountsPtr = lightgbmlib.new_intp() + val resultEval = lightgbmlib.LGBM_BoosterGetEval(boosterHandler.boosterPtr, dataIndex, + dummyEvalCountsPtr, evalResults) + lightgbmlib.delete_intp(dummyEvalCountsPtr) + LightGBMUtils.validate(resultEval, s"Booster Get Eval Results for data index: ${dataIndex}") + + val results: Array[(String, Double)] = evalNames.zipWithIndex.map { case (evalName, index) => + val score = lightgbmlib.doubleArray_getitem(evalResults, index.toLong) + (evalName, score) + } + lightgbmlib.delete_doubleArray(evalResults) + results + } + + /** Reset the specified parameters on the native booster. + * @param newParameters The new parameters to set. + */ + def resetParameter(newParameters: String) = { + LightGBMUtils.validate(lightgbmlib.LGBM_BoosterResetParameter(boosterHandler.boosterPtr, + newParameters), "Booster Reset learning_rate Param") + } + + def innerPredict(dataIndex: Int, classification: Boolean): Array[Array[Double]] = { + val numRows = this.trainDataset.get.numData() + val scoredDataOutPtr = lightgbmlib.new_doubleArray(numClasses.toLong * numRows) + val scoredDataLengthPtr = lightgbmlib.new_int64_tp() + lightgbmlib.int64_tp_assign(scoredDataLengthPtr, 1) + lightgbmlib.LGBM_BoosterGetPredict(boosterHandler.boosterPtr, dataIndex, + scoredDataLengthPtr, scoredDataOutPtr) + val scoredDataLength = lightgbmlib.int64_tp_value(scoredDataLengthPtr) + if (classification && numClasses == 1) { + (0L until scoredDataLength).map(index => + Array(lightgbmlib.doubleArray_getitem(scoredDataOutPtr, index))).toArray + } else { + val numRows = scoredDataLength / numClasses + (0L until numRows).map(rowIndex => { + val startIndex = rowIndex * numClasses + (0 until numClasses).map(classIndex => + lightgbmlib.doubleArray_getitem(scoredDataOutPtr, startIndex + classIndex)).toArray + }).toArray + } + } + def score(features: Vector, raw: Boolean, classification: Boolean): Array[Double] = { val kind = if (raw) boosterHandler.rawScoreConstant @@ -255,6 +376,57 @@ class LightGBMBooster(val model: String) extends Serializable { this.numIterations = bestIteration } + /** Saves the native model serialized representation to file. + * @param session The spark session + * @param filename The name of the file to save the model to + * @param overwrite Whether to overwrite if the file already exists + */ + def saveNativeModel(session: SparkSession, filename: String, overwrite: Boolean): Unit = { + if (filename == null || filename.isEmpty) { + throw new IllegalArgumentException("filename should not be empty or null.") + } + val rdd = session.sparkContext.parallelize(Seq(model)) + import session.sqlContext.implicits._ + val dataset = session.sqlContext.createDataset(rdd) + val mode = if (overwrite) SaveMode.Overwrite else SaveMode.ErrorIfExists + dataset.coalesce(1).write.mode(mode).text(filename) + } + + /** Dumps the native model pointer to file. + * @param session The spark session + * @param filename The name of the file to save the model to + * @param overwrite Whether to overwrite if the file already exists + */ + def dumpModel(session: SparkSession, filename: String, overwrite: Boolean): Unit = { + val json = lightgbmlib.LGBM_BoosterDumpModelSWIG(boosterHandler.boosterPtr, 0, -1, 0, 1, + boosterHandler.dumpModelOutPtr.get().ptr) + val rdd = session.sparkContext.parallelize(Seq(json)) + import session.sqlContext.implicits._ + val dataset = session.sqlContext.createDataset(rdd) + val mode = if (overwrite) SaveMode.Overwrite else SaveMode.ErrorIfExists + dataset.coalesce(1).write.mode(mode).text(filename) + } + + /** Frees any native memory held by the underlying booster pointer. + */ + def freeNativeMemory(): Unit = { + boosterHandler.freeNativeMemory() + } + + /** + * Calls into LightGBM to retrieve the feature importances. + * @param importanceType Can be "split" or "gain" + * @return The feature importance values as an array. + */ + def getFeatureImportances(importanceType: String): Array[Double] = { + val importanceTypeNum = if (importanceType.toLowerCase.trim == "gain") 1 else 0 + LightGBMUtils.validate( + lightgbmlib.LGBM_BoosterFeatureImportance(boosterHandler.boosterPtr, -1, + importanceTypeNum, boosterHandler.featureImportanceOutPtr.get().ptr), + "Booster FeatureImportance") + (0L until numFeatures.toLong).map(lightgbmlib.doubleArray_getitem(boosterHandler.featureImportanceOutPtr.get().ptr, _)).toArray + } + lazy val numClasses: Int = boosterHandler.numClasses lazy val numFeatures: Int = boosterHandler.numFeatures @@ -302,41 +474,6 @@ class LightGBMBooster(val model: String) extends Serializable { "Booster Predict") } - def saveNativeModel(session: SparkSession, filename: String, overwrite: Boolean): Unit = { - if (filename == null || filename.isEmpty) { - throw new IllegalArgumentException("filename should not be empty or null.") - } - val rdd = session.sparkContext.parallelize(Seq(model)) - import session.sqlContext.implicits._ - val dataset = session.sqlContext.createDataset(rdd) - val mode = if (overwrite) SaveMode.Overwrite else SaveMode.ErrorIfExists - dataset.coalesce(1).write.mode(mode).text(filename) - } - - def dumpModel(session: SparkSession, filename: String, overwrite: Boolean): Unit = { - val json = lightgbmlib.LGBM_BoosterDumpModelSWIG(boosterHandler.boosterPtr, 0, -1, 0, 1, - boosterHandler.dumpModelOutPtr.get().ptr) - val rdd = session.sparkContext.parallelize(Seq(json)) - import session.sqlContext.implicits._ - val dataset = session.sqlContext.createDataset(rdd) - val mode = if (overwrite) SaveMode.Overwrite else SaveMode.ErrorIfExists - dataset.coalesce(1).write.mode(mode).text(filename) - } - - /** - * Calls into LightGBM to retrieve the feature importances. - * @param importanceType Can be "split" or "gain" - * @return The feature importance values as an array. - */ - def getFeatureImportances(importanceType: String): Array[Double] = { - val importanceTypeNum = if (importanceType.toLowerCase.trim == "gain") 1 else 0 - LightGBMUtils.validate( - lightgbmlib.LGBM_BoosterFeatureImportance(boosterHandler.boosterPtr, -1, - importanceTypeNum, boosterHandler.featureImportanceOutPtr.get().ptr), - "Booster FeatureImportance") - (0L until numFeatures.toLong).map(lightgbmlib.doubleArray_getitem(boosterHandler.featureImportanceOutPtr.get().ptr, _)).toArray - } - private def predScoreToArray(classification: Boolean, scoredDataOutPtr: SWIGTYPE_p_double, kind: Int): Array[Double] = { if (classification && numClasses == 1) { diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMDataset.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/dataset/LightGBMDataset.scala similarity index 60% rename from src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMDataset.scala rename to src/main/scala/com/microsoft/ml/spark/lightgbm/dataset/LightGBMDataset.scala index 1cd222419f8..60fd33623e8 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMDataset.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/dataset/LightGBMDataset.scala @@ -1,30 +1,81 @@ // Copyright (C) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. See LICENSE in project root for information. -package com.microsoft.ml.spark.lightgbm +package com.microsoft.ml.spark.lightgbm.dataset -import com.microsoft.ml.lightgbm.{floatChunkedArray, _} +import com.microsoft.lightgbm.SwigPtrWrapper +import com.microsoft.ml.lightgbm._ +import com.microsoft.ml.spark.lightgbm.LightGBMUtils + +import scala.reflect.ClassTag /** Represents a LightGBM dataset. * Wraps the native implementation. - * @param dataset The native representation of the dataset. + * @param datasetPtr The native representation of the dataset. */ -class LightGBMDataset(val dataset: SWIGTYPE_p_void) extends AutoCloseable { - def validateDataset(): Unit = { - // Validate num rows +class LightGBMDataset(val datasetPtr: SWIGTYPE_p_void) extends AutoCloseable { + def getLabel(): Array[Float] = { + getField[Float]("label") + } + + def getField[T: ClassTag](fieldName: String): Array[T] = { + // The result length + val tmpOutLenPtr = lightgbmlib.new_int32_tp() + // The type of the result array + val outTypePtr = lightgbmlib.new_int32_tp() + // The pointer to the result + val outArray = lightgbmlib.new_voidpp() + lightgbmlib.LGBM_DatasetGetField(datasetPtr, fieldName, tmpOutLenPtr, outArray, outTypePtr) + val outType = lightgbmlib.int32_tp_value(outTypePtr) + val outLength = lightgbmlib.int32_tp_value(tmpOutLenPtr) + // Note: hacky workaround for now until new pointer manipulation functions are added + val voidptr = lightgbmlib.voidpp_value(outArray) + val address = new SwigPtrWrapper(voidptr).getCPtrValue() + if (outType == lightgbmlibConstants.C_API_DTYPE_INT32) { + (0 until outLength).map(index => + lightgbmlibJNI.intArray_getitem(address, index).asInstanceOf[T]).toArray + } else if (outType == lightgbmlibConstants.C_API_DTYPE_FLOAT32) { + (0 until outLength).map(index => + lightgbmlibJNI.floatArray_getitem(address, index).asInstanceOf[T]).toArray + } else if (outType == lightgbmlibConstants.C_API_DTYPE_FLOAT64) { + (0 until outLength).map(index => + lightgbmlibJNI.doubleArray_getitem(address, index).asInstanceOf[T]).toArray + } else { + throw new Exception("Unknown type returned from native lightgbm in LightGBMDataset getField") + } + } + + /** Get the number of rows in the Dataset. + * @return The number of rows. + */ + def numData(): Int = { val numDataPtr = lightgbmlib.new_intp() - LightGBMUtils.validate(lightgbmlib.LGBM_DatasetGetNumData(dataset, numDataPtr), "DatasetGetNumData") + LightGBMUtils.validate(lightgbmlib.LGBM_DatasetGetNumData(datasetPtr, numDataPtr), "DatasetGetNumData") val numData = lightgbmlib.intp_value(numDataPtr) + lightgbmlib.delete_intp(numDataPtr) + numData + } + + /** Get the number of features in the Dataset. + * @return The number of features. + */ + def numFeature(): Int = { + val numFeaturePtr = lightgbmlib.new_intp() + LightGBMUtils.validate(lightgbmlib.LGBM_DatasetGetNumFeature(datasetPtr, numFeaturePtr), "DatasetGetNumFeature") + val numFeature = lightgbmlib.intp_value(numFeaturePtr) + lightgbmlib.delete_intp(numFeaturePtr) + numFeature + } + + def validateDataset(): Unit = { + // Validate num rows + val numData = this.numData() if (numData <= 0) { throw new Exception("Unexpected num data: " + numData) } // Validate num cols - val numFeaturePtr = lightgbmlib.new_intp() - LightGBMUtils.validate( - lightgbmlib.LGBM_DatasetGetNumFeature(dataset, numFeaturePtr), - "DatasetGetNumFeature") - val numFeature = lightgbmlib.intp_value(numFeaturePtr) + val numFeature = this.numFeature() if (numFeature <= 0) { throw new Exception("Unexpected num feature: " + numFeature) } @@ -60,7 +111,7 @@ class LightGBMDataset(val dataset: SWIGTYPE_p_void) extends AutoCloseable { val colAsVoidPtr = lightgbmlib.float_to_voidp_ptr(field) val data32bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT32 LightGBMUtils.validate( - lightgbmlib.LGBM_DatasetSetField(dataset, fieldName, colAsVoidPtr, numRows, data32bitType), + lightgbmlib.LGBM_DatasetSetField(datasetPtr, fieldName, colAsVoidPtr, numRows, data32bitType), "DatasetSetField") } @@ -94,7 +145,7 @@ class LightGBMDataset(val dataset: SWIGTYPE_p_void) extends AutoCloseable { val colAsVoidPtr = lightgbmlib.double_to_voidp_ptr(field) val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64 LightGBMUtils.validate( - lightgbmlib.LGBM_DatasetSetField(dataset, fieldName, colAsVoidPtr, numRows, data64bitType), + lightgbmlib.LGBM_DatasetSetField(datasetPtr, fieldName, colAsVoidPtr, numRows, data64bitType), "DatasetSetField") } @@ -108,7 +159,7 @@ class LightGBMDataset(val dataset: SWIGTYPE_p_void) extends AutoCloseable { val colAsVoidPtr = lightgbmlib.int_to_voidp_ptr(colArray.get) val data32bitType = lightgbmlibConstants.C_API_DTYPE_INT32 LightGBMUtils.validate( - lightgbmlib.LGBM_DatasetSetField(dataset, fieldName, colAsVoidPtr, numRows, data32bitType), + lightgbmlib.LGBM_DatasetSetField(datasetPtr, fieldName, colAsVoidPtr, numRows, data32bitType), "DatasetSetField") } finally { // Free column @@ -120,7 +171,7 @@ class LightGBMDataset(val dataset: SWIGTYPE_p_void) extends AutoCloseable { // Add in slot names if they exist featureNamesOpt.foreach { featureNamesVal => if (featureNamesVal.nonEmpty) { - LightGBMUtils.validate(lightgbmlib.LGBM_DatasetSetFeatureNames(dataset, featureNamesVal, numCols), + LightGBMUtils.validate(lightgbmlib.LGBM_DatasetSetFeatureNames(datasetPtr, featureNamesVal, numCols), "Dataset set feature names") } } @@ -128,6 +179,6 @@ class LightGBMDataset(val dataset: SWIGTYPE_p_void) extends AutoCloseable { override def close(): Unit = { // Free dataset - LightGBMUtils.validate(lightgbmlib.LGBM_DatasetFree(dataset), "Finalize Dataset") + LightGBMUtils.validate(lightgbmlib.LGBM_DatasetFree(datasetPtr), "Finalize Dataset") } } diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/params/FObjParam.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/params/FObjParam.scala new file mode 100644 index 00000000000..ff166a9d388 --- /dev/null +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/params/FObjParam.scala @@ -0,0 +1,19 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.ml.spark.lightgbm.params + +import com.microsoft.ml.spark.core.serialize.ComplexParam +import org.apache.spark.ml.param.Params + +/** Param for FObjTrait. Needed as spark has explicit params for many different + * types but not FObjTrait. + */ +class FObjParam(parent: Params, name: String, doc: String, + isValid: FObjTrait => Boolean) + + extends ComplexParam[FObjTrait](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, {_ => true}) +} diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/params/FObjTrait.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/params/FObjTrait.scala new file mode 100644 index 00000000000..fae0070e2cb --- /dev/null +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/params/FObjTrait.scala @@ -0,0 +1,17 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.ml.spark.lightgbm.params + +import com.microsoft.ml.spark.lightgbm.dataset.LightGBMDataset + +trait FObjTrait extends Serializable { + /** + * User defined objective function, returns gradient and second order gradient + * + * @param predictions untransformed margin predicts + * @param trainingData training data + * @return List with two float array, correspond to grad and hess + */ + def getGradient(predictions: Array[Array[Double]], trainingData: LightGBMDataset): List[Array[Float]] +} diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBoosterParam.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/params/LightGBMBoosterParam.scala similarity index 85% rename from src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBoosterParam.scala rename to src/main/scala/com/microsoft/ml/spark/lightgbm/params/LightGBMBoosterParam.scala index ed1429f84e4..50afdec45b0 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBoosterParam.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/params/LightGBMBoosterParam.scala @@ -1,9 +1,10 @@ // Copyright (C) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. See LICENSE in project root for information. -package com.microsoft.ml.spark.lightgbm +package com.microsoft.ml.spark.lightgbm.params import com.microsoft.ml.spark.core.serialize.ComplexParam +import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster import org.apache.spark.ml.param.Params /** Custom ComplexParam for LightGBMBooster, to make it settable on the LightGBM models. diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMParams.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/params/LightGBMParams.scala similarity index 96% rename from src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMParams.scala rename to src/main/scala/com/microsoft/ml/spark/lightgbm/params/LightGBMParams.scala index a37d8b94267..d38b291fc1f 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMParams.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/params/LightGBMParams.scala @@ -1,10 +1,12 @@ // Copyright (C) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. See LICENSE in project root for information. -package com.microsoft.ml.spark.lightgbm +package com.microsoft.ml.spark.lightgbm.params import com.microsoft.ml.spark.codegen.Wrappable import com.microsoft.ml.spark.core.contracts.{HasInitScoreCol, HasValidationIndicatorCol, HasWeightCol} +import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster +import com.microsoft.ml.spark.lightgbm.{LightGBMConstants, LightGBMDelegate} import org.apache.spark.ml.param._ import org.apache.spark.ml.util.DefaultParamsWritable @@ -273,12 +275,31 @@ trait LightGBMModelParams extends Wrappable { def setNumIterations(value: Int): this.type = set(numIterations, value) } +/** Defines common objective parameters + */ +trait LightGBMObjectiveParams extends Wrappable { + val objective = new Param[String](this, "objective", + "The Objective. For regression applications, this can be: " + + "regression_l2, regression_l1, huber, fair, poisson, quantile, mape, gamma or tweedie. " + + "For classification applications, this can be: binary, multiclass, or multiclassova. ") + setDefault(objective -> "regression") + + def getObjective: String = $(objective) + def setObjective(value: String): this.type = set(objective, value) + + val fobj = new FObjParam(this, "fobj", "Customized objective function. " + + "Should accept two parameters: preds, train_data, and return (grad, hess).") + + def getFObj: FObjTrait = $(fobj) + def setFObj(value: FObjTrait): this.type = set(fobj, value) +} + /** Defines common parameters across all LightGBM learners. */ trait LightGBMParams extends Wrappable with DefaultParamsWritable with HasWeightCol with HasValidationIndicatorCol with HasInitScoreCol with LightGBMExecutionParams with LightGBMSlotParams with LightGBMFractionParams with LightGBMBinParams with LightGBMLearnerParams - with LightGBMDartParams with LightGBMPredictionParams { + with LightGBMDartParams with LightGBMPredictionParams with LightGBMObjectiveParams { val numIterations = new IntParam(this, "numIterations", "Number of iterations, LightGBM constructs num_class * num_iterations trees") setDefault(numIterations->100) @@ -298,15 +319,6 @@ trait LightGBMParams extends Wrappable with DefaultParamsWritable with HasWeight def getNumLeaves: Int = $(numLeaves) def setNumLeaves(value: Int): this.type = set(numLeaves, value) - val objective = new Param[String](this, "objective", - "The Objective. For regression applications, this can be: " + - "regression_l2, regression_l1, huber, fair, poisson, quantile, mape, gamma or tweedie. " + - "For classification applications, this can be: binary, multiclass, or multiclassova. ") - setDefault(objective -> "regression") - - def getObjective: String = $(objective) - def setObjective(value: String): this.type = set(objective, value) - val baggingFreq = new IntParam(this, "baggingFreq", "Bagging frequency") setDefault(baggingFreq->0) diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainParams.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/params/TrainParams.scala similarity index 70% rename from src/main/scala/com/microsoft/ml/spark/lightgbm/TrainParams.scala rename to src/main/scala/com/microsoft/ml/spark/lightgbm/params/TrainParams.scala index dfd787d8617..2f74805d638 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainParams.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/params/TrainParams.scala @@ -1,7 +1,9 @@ // Copyright (C) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. See LICENSE in project root for information. -package com.microsoft.ml.spark.lightgbm +package com.microsoft.ml.spark.lightgbm.params + +import com.microsoft.ml.spark.lightgbm.{LightGBMConstants, LightGBMDelegate} /** Defines the common Booster parameters passed to the LightGBM learners. */ @@ -24,7 +26,6 @@ abstract class TrainParams extends Serializable { def maxDepth: Int def minSumHessianInLeaf: Double def numMachines: Int - def objective: String def modelString: Option[String] def verbosity: Int def categoricalFeatures: Array[Int] @@ -41,6 +42,7 @@ abstract class TrainParams extends Serializable { def delegate: Option[LightGBMDelegate] def dartModeParams: DartModeParams def executionParams: ExecutionParams + def objectiveParams: ObjectiveParams override def toString: String = { // Since passing `isProvideTrainingMetric` to LightGBM as a config parameter won't work, @@ -51,9 +53,9 @@ abstract class TrainParams extends Serializable { s"neg_bagging_fraction=$negBaggingFraction bagging_freq=$baggingFreq " + s"bagging_seed=$baggingSeed early_stopping_round=$earlyStoppingRound " + s"feature_fraction=$featureFraction max_depth=$maxDepth min_sum_hessian_in_leaf=$minSumHessianInLeaf " + - s"num_machines=$numMachines objective=$objective verbosity=$verbosity " + + s"num_machines=$numMachines verbosity=$verbosity " + s"lambda_l1=$lambdaL1 lambda_l2=$lambdaL2 metric=$metric min_gain_to_split=$minGainToSplit " + - s"max_delta_step=$maxDeltaStep min_data_in_leaf=$minDataInLeaf " + + s"max_delta_step=$maxDeltaStep min_data_in_leaf=$minDataInLeaf ${objectiveParams.toString()} " + (if (categoricalFeatures.isEmpty) "" else s"categorical_feature=${categoricalFeatures.mkString(",")} ") + (if (maxBinByFeature.isEmpty) "" else s"max_bin_by_feature=${maxBinByFeature.mkString(",")} ") + (if (boostingType == "dart") s"${dartModeParams.toString()}" else "") @@ -68,18 +70,18 @@ case class ClassifierTrainParams(parallelism: String, topK: Int, numIterations: baggingFreq: Int, baggingSeed: Int, earlyStoppingRound: Int, improvementTolerance: Double, featureFraction: Double, maxDepth: Int, minSumHessianInLeaf: Double, - numMachines: Int, objective: String, modelString: Option[String], - isUnbalance: Boolean, verbosity: Int, categoricalFeatures: Array[Int], - numClass: Int, boostFromAverage: Boolean, - boostingType: String, lambdaL1: Double, lambdaL2: Double, + numMachines: Int, modelString: Option[String], isUnbalance: Boolean, + verbosity: Int, categoricalFeatures: Array[Int], numClass: Int, + boostFromAverage: Boolean, boostingType: String, lambdaL1: Double, lambdaL2: Double, isProvideTrainingMetric: Boolean, metric: String, minGainToSplit: Double, maxDeltaStep: Double, maxBinByFeature: Array[Int], minDataInLeaf: Int, featureNames: Array[String], delegate: Option[LightGBMDelegate], - dartModeParams: DartModeParams, executionParams: ExecutionParams) + dartModeParams: DartModeParams, executionParams: ExecutionParams, + objectiveParams: ObjectiveParams) extends TrainParams { override def toString(): String = { val extraStr = - if (objective != LightGBMConstants.BinaryObjective) s"num_class=$numClass" + if (objectiveParams.objective != LightGBMConstants.BinaryObjective) s"num_class=$numClass" else s"is_unbalance=${isUnbalance.toString}" s"metric=$metric boost_from_average=${boostFromAverage.toString} ${super.toString()} $extraStr" } @@ -88,11 +90,10 @@ case class ClassifierTrainParams(parallelism: String, topK: Int, numIterations: /** Defines the Booster parameters passed to the LightGBM regressor. */ case class RegressorTrainParams(parallelism: String, topK: Int, numIterations: Int, learningRate: Double, - numLeaves: Int, objective: String, alpha: Double, - tweedieVariancePower: Double, maxBin: Int, binSampleCount: Int, - baggingFraction: Double, posBaggingFraction: Double, negBaggingFraction: Double, - baggingFreq: Int, baggingSeed: Int, earlyStoppingRound: Int, - improvementTolerance: Double, featureFraction: Double, + numLeaves: Int, alpha: Double, tweedieVariancePower: Double, maxBin: Int, + binSampleCount: Int, baggingFraction: Double, posBaggingFraction: Double, + negBaggingFraction: Double, baggingFreq: Int, baggingSeed: Int, + earlyStoppingRound: Int, improvementTolerance: Double, featureFraction: Double, maxDepth: Int, minSumHessianInLeaf: Double, numMachines: Int, modelString: Option[String], verbosity: Int, categoricalFeatures: Array[Int], boostFromAverage: Boolean, @@ -100,7 +101,8 @@ case class RegressorTrainParams(parallelism: String, topK: Int, numIterations: I isProvideTrainingMetric: Boolean, metric: String, minGainToSplit: Double, maxDeltaStep: Double, maxBinByFeature: Array[Int], minDataInLeaf: Int, featureNames: Array[String], delegate: Option[LightGBMDelegate], - dartModeParams: DartModeParams, executionParams: ExecutionParams) + dartModeParams: DartModeParams, executionParams: ExecutionParams, + objectiveParams: ObjectiveParams) extends TrainParams { override def toString(): String = { s"alpha=$alpha tweedie_variance_power=$tweedieVariancePower boost_from_average=${boostFromAverage.toString} " + @@ -111,9 +113,9 @@ case class RegressorTrainParams(parallelism: String, topK: Int, numIterations: I /** Defines the Booster parameters passed to the LightGBM ranker. */ case class RankerTrainParams(parallelism: String, topK: Int, numIterations: Int, learningRate: Double, - numLeaves: Int, objective: String, maxBin: Int, binSampleCount: Int, - baggingFraction: Double, posBaggingFraction: Double, negBaggingFraction: Double, - baggingFreq: Int, baggingSeed: Int, earlyStoppingRound: Int, improvementTolerance: Double, + numLeaves: Int, maxBin: Int, binSampleCount: Int, baggingFraction: Double, + posBaggingFraction: Double, negBaggingFraction: Double, baggingFreq: Int, + baggingSeed: Int, earlyStoppingRound: Int, improvementTolerance: Double, featureFraction: Double, maxDepth: Int, minSumHessianInLeaf: Double, numMachines: Int, modelString: Option[String], verbosity: Int, categoricalFeatures: Array[Int], boostingType: String, @@ -122,7 +124,8 @@ case class RankerTrainParams(parallelism: String, topK: Int, numIterations: Int, metric: String, evalAt: Array[Int], minGainToSplit: Double, maxDeltaStep: Double, maxBinByFeature: Array[Int], minDataInLeaf: Int, featureNames: Array[String], delegate: Option[LightGBMDelegate], - dartModeParams: DartModeParams, executionParams: ExecutionParams) + dartModeParams: DartModeParams, executionParams: ExecutionParams, + objectiveParams: ObjectiveParams) extends TrainParams { override def toString(): String = { val labelGainStr = @@ -143,4 +146,28 @@ case class DartModeParams(dropRate: Double, maxDrop: Int, skipDrop: Double, } } +/** Defines parameters related to lightgbm execution in spark. + * + * @param chunkSize Advanced parameter to specify the chunk size for copying Java data to native. + * @param matrixType Advanced parameter to specify whether the native lightgbm matrix + * constructed should be sparse or dense. + */ case class ExecutionParams(chunkSize: Int, matrixType: String) extends Serializable + +/** Defines parameters related to the lightgbm objective function. + * + * @param objective The Objective. For regression applications, this can be: + * regression_l2, regression_l1, huber, fair, poisson, quantile, mape, gamma or tweedie. + * For classification applications, this can be: binary, multiclass, or multiclassova. + * @param fobj Customized objective function. + * Should accept two parameters: preds, train_data, and return (grad, hess). + */ +case class ObjectiveParams(objective: String, fobj: Option[FObjTrait]) extends Serializable { + override def toString(): String = { + if (fobj.isEmpty) { + s"objective=$objective " + } else { + "" + } + } +} diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/swig/SwigUtils.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/swig/SwigUtils.scala new file mode 100644 index 00000000000..c9b192553e6 --- /dev/null +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/swig/SwigUtils.scala @@ -0,0 +1,15 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.ml.spark.lightgbm.swig + +import com.microsoft.ml.lightgbm.{SWIGTYPE_p_float, lightgbmlib} + +object SwigUtils extends Serializable { + def floatArrayToNative(array: Array[Float]): SWIGTYPE_p_float = { + val colArray = lightgbmlib.new_floatArray(array.length) + array.zipWithIndex.foreach(ri => + lightgbmlib.floatArray_setitem(colArray, ri._2.toLong, ri._1.toFloat)) + colArray + } +} diff --git a/src/test/scala/com/microsoft/ml/spark/lightgbm/split1/VerifyLightGBMClassifier.scala b/src/test/scala/com/microsoft/ml/spark/lightgbm/split1/VerifyLightGBMClassifier.scala index 3efbba9765c..e7ba5b31e0e 100644 --- a/src/test/scala/com/microsoft/ml/spark/lightgbm/split1/VerifyLightGBMClassifier.scala +++ b/src/test/scala/com/microsoft/ml/spark/lightgbm/split1/VerifyLightGBMClassifier.scala @@ -12,6 +12,8 @@ import com.microsoft.ml.spark.core.test.benchmarks.{Benchmarks, DatasetUtils} import com.microsoft.ml.spark.core.test.fuzzing.{EstimatorFuzzing, TestObject} import com.microsoft.ml.spark.featurize.ValueIndexer import com.microsoft.ml.spark.lightgbm._ +import com.microsoft.ml.spark.lightgbm.dataset.LightGBMDataset +import com.microsoft.ml.spark.lightgbm.params.{FObjTrait, TrainParams} import com.microsoft.ml.spark.stages.{MultiColumnAdapter, SPConstants, StratifiedRepartition} import org.apache.commons.io.FileUtils import org.apache.spark.TaskContext @@ -26,6 +28,8 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.functions._ import org.slf4j.Logger +import scala.math.exp + @SerialVersionUID(100L) class TrainDelegate extends LightGBMDelegate { @@ -310,6 +314,28 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM assertMulticlassImprovement(scoredDF1, scoredDF2) } + test("Verify LightGBM Classifier with custom loss function") { + class LogLikelihood extends FObjTrait { + override def getGradient(predictions: Array[Array[Double]], + trainingData: LightGBMDataset): List[Array[Float]] = { + // Get the labels + val labels = trainingData.getLabel() + val probabilities = predictions.map(rowPrediction => + rowPrediction.map(prediction => 1.0 / (1.0 + exp(-prediction)))) + // Compute gradient and hessian + val grad = probabilities.zip(labels).map { + case (prob: Array[Double], label: Float) => (prob(0) - label).toFloat + } + val hess = probabilities.map(probabilityArray => (probabilityArray(0) * (1 - probabilityArray(0))).toFloat) + List(grad, hess) + } + } + val Array(train, test) = pimaDF.randomSplit(Array(0.8, 0.2), seed) + val scoredDF1 = baseModel.fit(train).transform(test) + val scoredDF2 = baseModel.setFObj(new LogLikelihood()).fit(train).transform(test) + assertBinaryImprovement(scoredDF1, scoredDF2) + } + test("Verify LightGBM Classifier with min gain to split parameter") { // If the min gain to split is too high, assert AUC lower for training data (assert parameter works) val scoredDF1 = baseModel.setMinGainToSplit(99999).fit(pimaDF).transform(pimaDF)