feat: add custom objective function to lightgbm learners
imatiach-msft committed May 24, 2021
1 parent d8bb51f commit 392f4a5
Showing 18 changed files with 549 additions and 217 deletions.
10 changes: 10 additions & 0 deletions src/main/scala/com/microsoft/lightgbm/SWIG.scala
@@ -0,0 +1,10 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.lightgbm

import com.microsoft.ml.lightgbm.SWIGTYPE_p_void

class SwigPtrWrapper(val value: SWIGTYPE_p_void) extends SWIGTYPE_p_void {
  def getCPtrValue(): Long = SWIGTYPE_p_void.getCPtr(value)
}
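The wrapper exists to surface the raw pointer address that the SWIG-generated SWIGTYPE_p_void class keeps protected. A minimal usage sketch (ptrPtr is a hypothetical SWIGTYPE_p_p_void obtained from the native API; voidpp_value is the dereferencing call that appears later in this diff):

import com.microsoft.lightgbm.SwigPtrWrapper
import com.microsoft.ml.lightgbm.lightgbmlib

// Dereference the (hypothetical) pointer-to-pointer into a void pointer.
val voidPtr = lightgbmlib.voidpp_value(ptrPtr)
val wrapped = new SwigPtrWrapper(voidPtr)
// getCPtrValue exposes the otherwise-protected native address, e.g. for logging.
println(f"native handle: 0x${wrapped.getCPtrValue()}%x")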
19 changes: 19 additions & 0 deletions src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBase.scala
@@ -4,6 +4,9 @@
package com.microsoft.ml.spark.lightgbm

import com.microsoft.ml.spark.core.utils.ClusterUtil
+import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster
+import com.microsoft.ml.spark.lightgbm.params.{DartModeParams, ExecutionParams, LightGBMParams,
+  ObjectiveParams, TrainParams}
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
@@ -180,14 +183,30 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
}
}

+/**
+ * Constructs the DartModeParams.
+ * @return DartModeParams object containing parameters related to dart mode.
+ */
protected def getDartParams(): DartModeParams = {
DartModeParams(getDropRate, getMaxDrop, getSkipDrop, getXGBoostDartMode, getUniformDrop)
}

+/**
+ * Constructs the ExecutionParams.
+ * @return ExecutionParams object containing parameters related to LightGBM execution.
+ */
protected def getExecutionParams(): ExecutionParams = {
ExecutionParams(getChunkSize, getMatrixType)
}

+/**
+ * Constructs the ObjectiveParams.
+ * @return ObjectiveParams object containing parameters related to the objective function.
+ */
+protected def getObjectiveParams(): ObjectiveParams = {
+  ObjectiveParams(getObjective, if (isDefined(fobj)) Some(getFObj) else None)
+}

/**
* Inner train method for LightGBM learners. Calculates the number of workers,
* creates a driver thread, and runs mapPartitions on the dataset.
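This is the heart of the feature: ObjectiveParams now carries an optional user-defined objective (fobj) into training. A minimal sketch of supplying one, assuming the FObjTrait interface this commit adds in the params package (it is among the 18 changed files but not shown in this excerpt) and a getLabel accessor on LightGBMDataset; treat both names and the exact signature as assumptions from the wider changeset:

import com.microsoft.ml.spark.lightgbm.LightGBMRegressor
import com.microsoft.ml.spark.lightgbm.dataset.LightGBMDataset
import com.microsoft.ml.spark.lightgbm.params.FObjTrait

// Gradient and hessian of squared error 0.5 * (pred - label)^2.
class SquaredErrorObjective extends FObjTrait {
  override def getGradient(predictions: Array[Array[Double]],
                           trainingData: LightGBMDataset): (Array[Float], Array[Float]) = {
    val labels = trainingData.getLabel()  // assumed accessor for the label column
    val grad = predictions.indices.map(i => (predictions(i)(0) - labels(i)).toFloat).toArray
    val hess = Array.fill(predictions.length)(1.0f)
    (grad, hess)
  }
}

// "custom" tells native LightGBM to skip its built-in objective; setFObj is
// the setter implied by getFObj in getObjectiveParams above.
val regressor = new LightGBMRegressor()
  .setObjective("custom")
  .setFObj(new SquaredErrorObjective())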
src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMClassifier.scala
Expand Up @@ -3,6 +3,9 @@

package com.microsoft.ml.spark.lightgbm

+import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster
+import com.microsoft.ml.spark.lightgbm.params.{ClassifierTrainParams, LightGBMModelParams,
+  LightGBMPredictionParams, TrainParams}
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable}
import org.apache.spark.ml.param._
@@ -46,11 +49,11 @@ class LightGBMClassifier(override val uid: String)
ClassifierTrainParams(getParallelism, getTopK, getNumIterations, getLearningRate, getNumLeaves, getMaxBin,
getBinSampleCount, getBaggingFraction, getPosBaggingFraction, getNegBaggingFraction,
getBaggingFreq, getBaggingSeed, getEarlyStoppingRound, getImprovementTolerance,
-getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numTasks, getObjective, modelStr,
+getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numTasks, modelStr,
getIsUnbalance, getVerbosity, categoricalIndexes, actualNumClasses, getBoostFromAverage,
getBoostingType, getLambdaL1, getLambdaL2, getIsProvideTrainingMetric,
getMetric, getMinGainToSplit, getMaxDeltaStep, getMaxBinByFeature, getMinDataInLeaf, getSlotNames,
-getDelegate, getDartParams(), getExecutionParams())
+getDelegate, getDartParams(), getExecutionParams(), getObjectiveParams())
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMClassificationModel = {
@@ -187,7 +190,7 @@ class LightGBMClassificationModel(override val uid: String)

object LightGBMClassificationModel extends ComplexParamsReadable[LightGBMClassificationModel] {
def loadNativeModelFromFile(filename: String): LightGBMClassificationModel = {
-val uid = Identifiable.randomUID("LightGBMClassifier")
+val uid = Identifiable.randomUID("LightGBMClassificationModel")
val session = SparkSession.builder().getOrCreate()
val textRdd = session.read.text(filename)
val text = textRdd.collect().map { row => row.getString(0) }.mkString("\n")
@@ -197,7 +200,7 @@ object LightGBMClassificationModel extends ComplexParamsReadable[LightGBMClassif
}

def loadNativeModelFromString(model: String): LightGBMClassificationModel = {
-val uid = Identifiable.randomUID("LightGBMClassifier")
+val uid = Identifiable.randomUID("LightGBMClassificationModel")
val lightGBMBooster = new LightGBMBooster(model)
val actualNumClasses = lightGBMBooster.numClasses
new LightGBMClassificationModel(uid).setLightGBMBooster(lightGBMBooster).setActualNumClasses(actualNumClasses)
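The randomUID fixes above mean models loaded from native files now get a uid prefixed with the model class name rather than the estimator's. Usage of the loaders is unchanged; a short sketch (path and DataFrame are illustrative):

// Load a native LightGBM text model and score a DataFrame.
val model = LightGBMClassificationModel.loadNativeModelFromFile("/tmp/lgbm_model.txt")
val scored = model.transform(testDf)  // testDf must provide the expected features column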
src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMDelegate.scala
@@ -3,7 +3,8 @@

package com.microsoft.ml.spark.lightgbm

-import com.microsoft.ml.lightgbm.SWIGTYPE_p_void
+import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster
+import com.microsoft.ml.spark.lightgbm.params.TrainParams
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType
import org.slf4j.Logger
@@ -40,12 +41,12 @@ trait LightGBMDelegate extends Serializable {
}

def beforeTrainIteration(batchIndex: Int, partitionId: Int, curIters: Int, log: Logger,
-trainParams: TrainParams, boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean): Unit = {
+trainParams: TrainParams, booster: LightGBMBooster, hasValid: Boolean): Unit = {
// override this function and write code
}

def afterTrainIteration(batchIndex: Int, partitionId: Int, curIters: Int, log: Logger,
-trainParams: TrainParams, boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean,
+trainParams: TrainParams, booster: LightGBMBooster, hasValid: Boolean,
isFinished: Boolean,
trainEvalResults: Option[Map[String, Double]],
validEvalResults: Option[Map[String, Double]]): Unit = {
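Since the hooks now receive a LightGBMBooster instead of a raw Option[SWIGTYPE_p_void], delegate implementations no longer touch SWIG types directly. A minimal sketch, assuming the trait's other callbacks keep their no-op defaults and that learners accept the delegate through the setDelegate setter implied by getDelegate elsewhere in this diff:

import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster
import com.microsoft.ml.spark.lightgbm.params.TrainParams
import org.slf4j.Logger

// Logs training progress before every boosting iteration.
class LoggingDelegate extends LightGBMDelegate {
  override def beforeTrainIteration(batchIndex: Int, partitionId: Int, curIters: Int, log: Logger,
                                    trainParams: TrainParams, booster: LightGBMBooster,
                                    hasValid: Boolean): Unit = {
    log.info(s"batch $batchIndex, partition $partitionId: starting iteration $curIters")
  }
}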
@@ -3,6 +3,7 @@

package com.microsoft.ml.spark.lightgbm

+import com.microsoft.ml.spark.lightgbm.params.LightGBMModelParams
import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg.{Vector, Vectors}

src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMRanker.scala
@@ -3,6 +3,9 @@

package com.microsoft.ml.spark.lightgbm

+import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster
+import com.microsoft.ml.spark.lightgbm.params.{LightGBMModelParams, LightGBMPredictionParams,
+  RankerTrainParams, TrainParams}
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Ranker, RankerModel}
import org.apache.spark.ml.param._
@@ -51,12 +54,13 @@ class LightGBMRanker(override val uid: String)
def getTrainParams(numTasks: Int, categoricalIndexes: Array[Int], dataset: Dataset[_]): TrainParams = {
val modelStr = if (getModelString == null || getModelString.isEmpty) None else get(modelString)
RankerTrainParams(getParallelism, getTopK, getNumIterations, getLearningRate, getNumLeaves,
-getObjective, getMaxBin, getBinSampleCount, getBaggingFraction, getPosBaggingFraction, getNegBaggingFraction,
+getMaxBin, getBinSampleCount, getBaggingFraction, getPosBaggingFraction, getNegBaggingFraction,
getBaggingFreq, getBaggingSeed, getEarlyStoppingRound, getImprovementTolerance,
getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numTasks, modelStr,
getVerbosity, categoricalIndexes, getBoostingType, getLambdaL1, getLambdaL2, getMaxPosition, getLabelGain,
getIsProvideTrainingMetric, getMetric, getEvalAt, getMinGainToSplit, getMaxDeltaStep,
-getMaxBinByFeature, getMinDataInLeaf, getSlotNames, getDelegate, getDartParams(), getExecutionParams())
+getMaxBinByFeature, getMinDataInLeaf, getSlotNames, getDelegate, getDartParams(), getExecutionParams(),
+getObjectiveParams())
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRankerModel = {
@@ -157,7 +161,7 @@ class LightGBMRankerModel(override val uid: String)

object LightGBMRankerModel extends ComplexParamsReadable[LightGBMRankerModel] {
def loadNativeModelFromFile(filename: String): LightGBMRankerModel = {
-val uid = Identifiable.randomUID("LightGBMRanker")
+val uid = Identifiable.randomUID("LightGBMRankerModel")
val session = SparkSession.builder().getOrCreate()
val textRdd = session.read.text(filename)
val text = textRdd.collect().map { row => row.getString(0) }.mkString("\n")
@@ -166,7 +170,7 @@ object LightGBMRankerModel extends ComplexParamsReadable[LightGBMRankerModel] {
}

def loadNativeModelFromString(model: String): LightGBMRankerModel = {
-val uid = Identifiable.randomUID("LightGBMRanker")
+val uid = Identifiable.randomUID("LightGBMRankerModel")
val lightGBMBooster = new LightGBMBooster(model)
new LightGBMRankerModel(uid).setLightGBMBooster(lightGBMBooster)
}
src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMRegressor.scala
@@ -3,6 +3,9 @@

package com.microsoft.ml.spark.lightgbm

+import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster
+import com.microsoft.ml.spark.lightgbm.params.{LightGBMModelParams, LightGBMPredictionParams,
+  RegressorTrainParams, TrainParams}
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.ml.{BaseRegressor, ComplexParamsReadable, ComplexParamsWritable}
import org.apache.spark.ml.param._
@@ -58,12 +61,12 @@ class LightGBMRegressor(override val uid: String)
def getTrainParams(numTasks: Int, categoricalIndexes: Array[Int], dataset: Dataset[_]): TrainParams = {
val modelStr = if (getModelString == null || getModelString.isEmpty) None else get(modelString)
RegressorTrainParams(getParallelism, getTopK, getNumIterations, getLearningRate, getNumLeaves,
-getObjective, getAlpha, getTweedieVariancePower, getMaxBin, getBinSampleCount,
-getBaggingFraction, getPosBaggingFraction, getNegBaggingFraction, getBaggingFreq, getBaggingSeed,
-getEarlyStoppingRound, getImprovementTolerance, getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf,
-numTasks, modelStr, getVerbosity, categoricalIndexes, getBoostFromAverage, getBoostingType, getLambdaL1,
-getLambdaL2, getIsProvideTrainingMetric, getMetric, getMinGainToSplit, getMaxDeltaStep,
-getMaxBinByFeature, getMinDataInLeaf, getSlotNames, getDelegate, getDartParams(), getExecutionParams())
+getAlpha, getTweedieVariancePower, getMaxBin, getBinSampleCount, getBaggingFraction, getPosBaggingFraction,
+getNegBaggingFraction, getBaggingFreq, getBaggingSeed, getEarlyStoppingRound, getImprovementTolerance,
+getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numTasks, modelStr, getVerbosity, categoricalIndexes,
+getBoostFromAverage, getBoostingType, getLambdaL1, getLambdaL2, getIsProvideTrainingMetric, getMetric,
+getMinGainToSplit, getMaxDeltaStep, getMaxBinByFeature, getMinDataInLeaf, getSlotNames, getDelegate,
+getDartParams(), getExecutionParams(), getObjectiveParams())
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRegressionModel = {
@@ -134,7 +137,7 @@ class LightGBMRegressionModel(override val uid: String)

object LightGBMRegressionModel extends ComplexParamsReadable[LightGBMRegressionModel] {
def loadNativeModelFromFile(filename: String): LightGBMRegressionModel = {
-val uid = Identifiable.randomUID("LightGBMRegressor")
+val uid = Identifiable.randomUID("LightGBMRegressionModel")
val session = SparkSession.builder().getOrCreate()
val textRdd = session.read.text(filename)
val text = textRdd.collect().map { row => row.getString(0) }.mkString("\n")
@@ -143,7 +146,7 @@ object LightGBMRegressionModel extends ComplexParamsReadable[LightGBMRegressionM
}

def loadNativeModelFromString(model: String): LightGBMRegressionModel = {
-val uid = Identifiable.randomUID("LightGBMRegressor")
+val uid = Identifiable.randomUID("LightGBMRegressionModel")
val lightGBMBooster = new LightGBMBooster(model)
new LightGBMRegressionModel(uid).setLightGBMBooster(lightGBMBooster)
}
src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMUtils.scala
@@ -11,6 +11,8 @@ import com.microsoft.ml.lightgbm._
import com.microsoft.ml.spark.core.env.NativeLoader
import com.microsoft.ml.spark.core.utils.ClusterUtil
import com.microsoft.ml.spark.featurize.{Featurize, FeaturizeUtilities}
+import com.microsoft.ml.spark.lightgbm.dataset.LightGBMDataset
+import com.microsoft.ml.spark.lightgbm.params.TrainParams
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.attribute._
@@ -245,7 +247,7 @@ object LightGBMUtils {
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromMats(featuresArray.get_chunks_count().toInt,
featuresArray.data_as_void(), data64bitType,
numRowsForChunks, numCols,
-isRowMajor, datasetParams, referenceDataset.map(_.dataset).orNull, datasetOutPtr),
+isRowMajor, datasetParams, referenceDataset.map(_.datasetPtr).orNull, datasetOutPtr),
"Dataset create")
} finally {
featuresArray.release()
@@ -275,7 +277,7 @@ object LightGBMUtils {
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromCSRSpark(
sparseRows.asInstanceOf[Array[Object]],
sparseRows.length,
-numCols, datasetParams, referenceDataset.map(_.dataset).orNull,
+numCols, datasetParams, referenceDataset.map(_.datasetPtr).orNull,
datasetOutPtr),
"Dataset create")
val dataset = new LightGBMDataset(lightgbmlib.voidpp_value(datasetOutPtr))
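Both creation paths above funnel native return codes through LightGBMUtils.validate, which throws when the status code is nonzero. The same pattern applies to any other call on the SWIG binding; a sketch, assuming the binding exposes LGBM_DatasetFree just as it does the creation calls above (the handle name is illustrative):

// Wrap every native call in validate so a nonzero status becomes an exception.
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetFree(datasetPtr), "Dataset free")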
(diff for the remaining 10 changed files not shown)
