diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMBase.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMBase.scala index 0ad6ac64ded..e0a549579f4 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMBase.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMBase.scala @@ -255,6 +255,11 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine ExecutionParams(getChunkSize, getMatrixType, execNumThreads, getUseSingleDatasetMode) } + /** + * Constructs the ColumnParams. + * + * @return ColumnParams object containing the parameters related to LightGBM columns. + */ protected def getColumnParams: ColumnParams = { ColumnParams(getLabelCol, getFeaturesCol, get(weightCol), get(initScoreCol), getOptGroupCol) } @@ -268,13 +273,25 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine ObjectiveParams(getObjective, if (isDefined(fobj)) Some(getFObj) else None) } + /** + * Constructs the SeedParams. + * + * @return SeedParams object containing the parameters related to LightGBM seeds and determinism. + */ + protected def getSeedParams: SeedParams = { + SeedParams(get(seed), get(deterministic), get(baggingSeed), get(featureFractionSeed), + get(extraSeed), get(dropSeed), get(dataRandomSeed), get(objectiveSeed), getBoostingType, getObjective) + } + def getDatasetParams(categoricalIndexes: Array[Int], numThreads: Int): String = { + val seedParam = get(dataRandomSeed).orElse(get(seed)) val datasetParams = s"max_bin=$getMaxBin is_pre_partition=True " + s"bin_construct_sample_cnt=$getBinSampleCount " + s"min_data_in_leaf=$getMinDataInLeaf " + s"num_threads=$numThreads " + (if (categoricalIndexes.isEmpty) "" - else s"categorical_feature=${categoricalIndexes.mkString(",")}") + else s"categorical_feature=${categoricalIndexes.mkString(",")} ") + + seedParam.map(dataRandomSeedOpt => s"data_random_seed=$dataRandomSeedOpt ").getOrElse("") datasetParams } @@ -424,7 +441,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine } } // Concatenate with commas, eg: host1:port1,host2:port2, ... etc - val allConnections = hostAndPorts.map(_._2).mkString(",") + val allConnections = hostAndPorts.map(_._2).sorted.mkString(",") log.info(s"driver writing back to all connections: $allConnections") // Send data back to all tasks and helper tasks on executors sendDataToExecutors(hostAndPorts, allConnections) diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMClassifier.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMClassifier.scala index 1bbc397849f..638cacede00 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMClassifier.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMClassifier.scala @@ -54,7 +54,7 @@ class LightGBMClassifier(override val uid: String) getIsUnbalance, getVerbosity, categoricalIndexes, actualNumClasses, getBoostFromAverage, getBoostingType, get(lambdaL1), get(lambdaL2), get(isProvideTrainingMetric), get(metric), get(minGainToSplit), get(maxDeltaStep), getMaxBinByFeature, get(minDataInLeaf), getSlotNames, - getDelegate, getDartParams, getExecutionParams(numTasksPerExec), getObjectiveParams) + getDelegate, getDartParams, getExecutionParams(numTasksPerExec), getObjectiveParams, getSeedParams) } def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMClassificationModel = { diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRanker.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRanker.scala index 708e281f536..d37a5aab84b 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRanker.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRanker.scala @@ -60,7 +60,7 @@ class LightGBMRanker(override val uid: String) getVerbosity, categoricalIndexes, getBoostingType, get(lambdaL1), get(lambdaL2), getMaxPosition, getLabelGain, get(isProvideTrainingMetric), get(metric), getEvalAt, get(minGainToSplit), get(maxDeltaStep), getMaxBinByFeature, get(minDataInLeaf), getSlotNames, getDelegate, getDartParams, - getExecutionParams(numTasksPerExec), getObjectiveParams) + getExecutionParams(numTasksPerExec), getObjectiveParams, getSeedParams) } def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRankerModel = { diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRegressor.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRegressor.scala index 6b070ccfe90..e087a68e1c7 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRegressor.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRegressor.scala @@ -70,7 +70,7 @@ class LightGBMRegressor(override val uid: String) getBoostFromAverage, getBoostingType, get(lambdaL1), get(lambdaL2), get(isProvideTrainingMetric), get(metric), get(minGainToSplit), get(maxDeltaStep), getMaxBinByFeature, get(minDataInLeaf), getSlotNames, getDelegate, - getDartParams, getExecutionParams(numTasksPerExec), getObjectiveParams) + getDartParams, getExecutionParams(numTasksPerExec), getObjectiveParams, getSeedParams) } def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRegressionModel = { diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMUtils.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMUtils.scala index f26c02533cd..f5cb3b0f138 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMUtils.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMUtils.scala @@ -116,6 +116,17 @@ object LightGBMUtils { idAsInt } + /** Returns the partition ID for the spark Dataset. + * + * Used to make operations deterministic on same dataset. + * + * @return Returns the partition id. + */ + def getPartitionId: Int = { + val ctx = TaskContext.get + ctx.partitionId + } + /** Returns true if spark is run in local mode. * @return True if spark is run in local mode. */ diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/dataset/DatasetAggregator.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/dataset/DatasetAggregator.scala index b5abca6a8fa..56ba3a56345 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/dataset/DatasetAggregator.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/dataset/DatasetAggregator.scala @@ -15,6 +15,7 @@ import org.apache.spark.sql.types.StructType import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable.ListBuffer +import scala.collection.concurrent.TrieMap private[lightgbm] object ChunkedArrayUtils { def copyChunkedArray[T: Numeric](chunkedArray: ChunkedArray[T], @@ -193,6 +194,8 @@ private[lightgbm] abstract class BaseAggregatedColumns(val chunkSize: Int) exten */ protected val rowCount = new AtomicLong(0L) protected val initScoreCount = new AtomicLong(0L) + protected val pIdToRowCountOffset = new TrieMap[Long, Long]() + protected val pIdToInitScoreCountOffset = new TrieMap[Long, Long]() protected var numCols = 0 @@ -216,7 +219,10 @@ private[lightgbm] abstract class BaseAggregatedColumns(val chunkSize: Int) exten def incrementCount(chunkedCols: BaseChunkedColumns): Unit = { rowCount.addAndGet(chunkedCols.rowCount) + pIdToRowCountOffset.update(LightGBMUtils.getPartitionId, chunkedCols.rowCount) initScoreCount.addAndGet(chunkedCols.numInitScores) + pIdToInitScoreCountOffset.update( + LightGBMUtils.getPartitionId, chunkedCols.numInitScores) } def addRows(chunkedCols: BaseChunkedColumns): Unit = { @@ -232,6 +238,18 @@ private[lightgbm] abstract class BaseAggregatedColumns(val chunkSize: Int) exten initScores = chunkedCols.initScores.map(_ => new DoubleSwigArray(isc)) initializeFeatures(chunkedCols, rc) groups = new Array[Any](rc.toInt) + updateConcurrentMapOffsets(pIdToRowCountOffset) + updateConcurrentMapOffsets(pIdToInitScoreCountOffset) + } + + protected def updateConcurrentMapOffsets(concurrentIdToOffset: TrieMap[Long, Long], + initialValue: Long = 0L): Unit = { + val sortedKeys = concurrentIdToOffset.keys.toSeq.sorted + sortedKeys.foldRight(initialValue: Long)((key, offset) => { + val partitionRowCount = concurrentIdToOffset(key) + concurrentIdToOffset.update(key, offset) + partitionRowCount + offset + }) } } @@ -254,12 +272,6 @@ private[lightgbm] trait DisjointAggregatedColumns extends BaseAggregatedColumns } private[lightgbm] trait SyncAggregatedColumns extends BaseAggregatedColumns { - /** - * Variables for current thread to use in order to update common arrays in parallel - */ - protected val threadRowStartIndex = new AtomicLong(0L) - protected val threadInitScoreStartIndex = new AtomicLong(0L) - /** Adds the rows to the internal data structure. */ override def addRows(chunkedCols: BaseChunkedColumns): Unit = { @@ -289,10 +301,9 @@ private[lightgbm] trait SyncAggregatedColumns extends BaseAggregatedColumns { var threadInitScoreStartIndex = 0L val featureIndexes = this.synchronized { - val labelsSize = chunkedCols.labels.getAddCount - threadRowStartIndex = this.threadRowStartIndex.getAndAdd(labelsSize.toInt) - val initScoreSize = chunkedCols.initScores.map(_.getAddCount) - initScoreSize.foreach(size => threadInitScoreStartIndex = this.threadInitScoreStartIndex.getAndAdd(size)) + val partitionId = LightGBMUtils.getPartitionId + threadRowStartIndex = pIdToRowCountOffset.get(partitionId).get + threadInitScoreStartIndex = chunkedCols.initScores.map(_ => pIdToInitScoreCountOffset(partitionId)).getOrElse(0) updateThreadLocalIndices(chunkedCols, threadRowStartIndex) } ChunkedArrayUtils.copyChunkedArray(chunkedCols.labels, labels, threadRowStartIndex, chunkSize) @@ -393,6 +404,8 @@ private[lightgbm] abstract class BaseSparseAggregatedColumns(chunkSize: Int) */ protected var indexesCount = new AtomicLong(0L) protected var indptrCount = new AtomicLong(0L) + protected val pIdToIndexesCountOffset = new TrieMap[Long, Long]() + protected val pIdToIndptrCountOffset = new TrieMap[Long, Long]() def getNumColsFromChunkedArray(chunkedCols: BaseChunkedColumns): Int = { chunkedCols.asInstanceOf[SparseChunkedColumns].numCols @@ -402,7 +415,9 @@ private[lightgbm] abstract class BaseSparseAggregatedColumns(chunkSize: Int) super.incrementCount(chunkedCols) val sparseChunkedCols = chunkedCols.asInstanceOf[SparseChunkedColumns] indexesCount.addAndGet(sparseChunkedCols.getNumIndexes) + pIdToIndexesCountOffset.update(LightGBMUtils.getPartitionId, sparseChunkedCols.getNumIndexes) indptrCount.addAndGet(sparseChunkedCols.getNumIndexPointers) + pIdToIndptrCountOffset.update(LightGBMUtils.getPartitionId, sparseChunkedCols.getNumIndexPointers) } protected def initializeFeatures(chunkedCols: BaseChunkedColumns, rowCount: Long): Unit = { @@ -412,6 +427,8 @@ private[lightgbm] abstract class BaseSparseAggregatedColumns(chunkSize: Int) values = new DoubleSwigArray(indexesCount) indexPointers = new IntSwigArray(indptrCount) indexPointers.setItem(0, 0) + updateConcurrentMapOffsets(pIdToIndexesCountOffset) + updateConcurrentMapOffsets(pIdToIndptrCountOffset, 1L) } def getIndexes: IntSwigArray = indexes @@ -489,12 +506,6 @@ private[lightgbm] final class SparseAggregatedColumns(chunkSize: Int) */ private[lightgbm] final class SparseSyncAggregatedColumns(chunkSize: Int) extends BaseSparseAggregatedColumns(chunkSize) with SyncAggregatedColumns { - /** - * Variables for current thread to use in order to update common arrays in parallel - */ - protected val threadIndexesStartIndex = new AtomicLong(0L) - protected val threadIndptrStartIndex = new AtomicLong(1L) - override protected def initializeRows(chunkedCols: BaseChunkedColumns): Unit = { // Add extra 0 for start of indptr in parallel case this.indptrCount.addAndGet(1L) @@ -502,12 +513,9 @@ private[lightgbm] final class SparseSyncAggregatedColumns(chunkSize: Int) } protected def updateThreadLocalIndices(chunkedCols: BaseChunkedColumns, threadRowStartIndex: Long): List[Long] = { - val sparseChunkedCols = chunkedCols.asInstanceOf[SparseChunkedColumns] - val indexesSize = sparseChunkedCols.indexes.getAddCount - val threadIndexesStartIndex = this.threadIndexesStartIndex.getAndAdd(indexesSize) - - val indPtrSize = sparseChunkedCols.indexPointers.getAddCount - val threadIndPtrStartIndex = this.threadIndptrStartIndex.getAndAdd(indPtrSize) + val partitionId = LightGBMUtils.getPartitionId + val threadIndexesStartIndex = pIdToIndexesCountOffset.get(partitionId).get + val threadIndPtrStartIndex = pIdToIndptrCountOffset.get(partitionId).get List(threadIndexesStartIndex, threadIndPtrStartIndex) } diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/dataset/DatasetUtils.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/dataset/DatasetUtils.scala index a6664fb88da..f481f2c9844 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/dataset/DatasetUtils.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/dataset/DatasetUtils.scala @@ -4,7 +4,6 @@ package com.microsoft.azure.synapse.ml.lightgbm.dataset import com.microsoft.azure.synapse.ml.lightgbm.ColumnParams -import com.microsoft.azure.synapse.ml.lightgbm.swig.DoubleChunkedArray import com.microsoft.ml.lightgbm.{doubleChunkedArray, floatChunkedArray} import org.apache.spark.ml.linalg.SQLDataTypes.VectorType import org.apache.spark.ml.linalg.{DenseVector, SparseVector} diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/params/LightGBMParams.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/params/LightGBMParams.scala index dc092aa0c92..a48981cccbd 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/params/LightGBMParams.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/params/LightGBMParams.scala @@ -317,12 +317,71 @@ trait LightGBMObjectiveParams extends Wrappable { def setFObj(value: FObjTrait): this.type = set(fobj, value) } +/** Defines common parameters related to seed and determinism + */ +trait LightGBMSeedParams extends Wrappable { + val seed = new IntParam(this, "seed", "Main seed, used to generate other seeds") + + def getSeed: Int = $(seed) + def setSeed(value: Int): this.type = set(seed, value) + + val deterministic = new BooleanParam(this, "deterministic", "Used only with cpu " + + "devide type. Setting this to true should ensure stable results when using the same data and the " + + "same parameters. Note: setting this to true may slow down training. To avoid potential instability " + + "due to numerical issues, please set force_col_wise=true or force_row_wise=true when setting " + + "deterministic=true") + setDefault(deterministic->false) + + def getDeterministic: Boolean = $(deterministic) + def setDeterministic(value: Boolean): this.type = set(deterministic, value) + + val baggingSeed = new IntParam(this, "baggingSeed", "Bagging seed") + setDefault(baggingSeed->3) + + def getBaggingSeed: Int = $(baggingSeed) + def setBaggingSeed(value: Int): this.type = set(baggingSeed, value) + + val featureFractionSeed = new IntParam(this, "featureFractionSeed", "Feature fraction seed") + setDefault(featureFractionSeed->2) + + def getFeatureFractionSeed: Int = $(featureFractionSeed) + def setFeatureFractionSeed(value: Int): this.type = set(featureFractionSeed, value) + + val extraSeed = new IntParam(this, "extraSeed", "Random seed for selecting threshold " + + "when extra_trees is true") + setDefault(extraSeed->6) + + def getExtraSeed: Int = $(extraSeed) + def setExtraSeed(value: Int): this.type = set(extraSeed, value) + + val dropSeed = new IntParam(this, "dropSeed", "Random seed to choose dropping models. " + + "Only used in dart.") + setDefault(dropSeed->4) + + def getDropSeed: Int = $(dropSeed) + def setDropSeed(value: Int): this.type = set(dropSeed, value) + + val dataRandomSeed = new IntParam(this, "dataRandomSeed", "Random seed for sampling " + + "data to construct histogram bins.") + setDefault(dataRandomSeed->1) + + def getDataRandomSeed: Int = $(dataRandomSeed) + def setDataRandomSeed(value: Int): this.type = set(dataRandomSeed, value) + + val objectiveSeed = new IntParam(this, "objectiveSeed", "Random seed for objectives, " + + "if random process is needed. Currently used only for rank_xendcg objective.") + setDefault(objectiveSeed->5) + + def getObjectiveSeed: Int = $(objectiveSeed) + def setObjectiveSeed(value: Int): this.type = set(objectiveSeed, value) +} + /** Defines common parameters across all LightGBM learners. */ trait LightGBMParams extends Wrappable with DefaultParamsWritable with HasWeightCol with HasValidationIndicatorCol with HasInitScoreCol with LightGBMExecutionParams with LightGBMSlotParams with LightGBMFractionParams with LightGBMBinParams with LightGBMLearnerParams - with LightGBMDartParams with LightGBMPredictionParams with LightGBMObjectiveParams { + with LightGBMDartParams with LightGBMPredictionParams with LightGBMObjectiveParams with LightGBMSeedParams { val numIterations = new IntParam(this, "numIterations", "Number of iterations, LightGBM constructs num_class * num_iterations trees") setDefault(numIterations->100) @@ -348,12 +407,6 @@ trait LightGBMParams extends Wrappable with DefaultParamsWritable with HasWeight def getBaggingFreq: Int = $(baggingFreq) def setBaggingFreq(value: Int): this.type = set(baggingFreq, value) - val baggingSeed = new IntParam(this, "baggingSeed", "Bagging seed") - setDefault(baggingSeed->3) - - def getBaggingSeed: Int = $(baggingSeed) - def setBaggingSeed(value: Int): this.type = set(baggingSeed, value) - val maxDepth = new IntParam(this, "maxDepth", "Max depth") setDefault(maxDepth-> -1) diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/params/TrainParams.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/params/TrainParams.scala index 81b5a0a3430..04a9f41d548 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/params/TrainParams.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/params/TrainParams.scala @@ -5,6 +5,22 @@ package com.microsoft.azure.synapse.ml.lightgbm.params import com.microsoft.azure.synapse.ml.lightgbm.{LightGBMConstants, LightGBMDelegate} +/** Helper utilities for converting params to a string, to be passed to LightGBM. */ +object ParamUtils { + def paramToString[T](paramName: String, paramValueOpt: Option[T]): String = { + paramValueOpt match { + case Some(paramValue) => s"$paramName=$paramValue" + case None => "" + } + } + + def paramsToString(paramNamesToValues: Array[(String, Option[_])]): String = { + paramNamesToValues.map { + case (paramName: String, paramValue: Option[_]) => paramToString(paramName, paramValue) + }.mkString(" ") + } +} + /** Defines the common Booster parameters passed to the LightGBM learners. */ abstract class TrainParams extends Serializable { @@ -43,25 +59,13 @@ abstract class TrainParams extends Serializable { def dartModeParams: DartModeParams def executionParams: ExecutionParams def objectiveParams: ObjectiveParams - - def paramToString[T](paramName: String, paramValueOpt: Option[T]): String = { - paramValueOpt match { - case Some(paramValue) => s"$paramName=$paramValue" - case None => "" - } - } - - def paramsToString(paramNamesToValues: Array[(String, Option[_])]): String = { - paramNamesToValues.map { - case (paramName: String, paramValue: Option[_]) => paramToString(paramName, paramValue) - }.mkString(" ") - } + def seedParams: SeedParams override def toString: String = { // Since passing `isProvideTrainingMetric` to LightGBM as a config parameter won't work, // let's fetch and print training metrics in `TrainUtils.scala` through JNI. s"is_pre_partition=True boosting_type=$boostingType tree_learner=$parallelism " + - paramsToString(Array(("top_k", topK), ("num_leaves", numLeaves), ("max_bin", maxBin), + ParamUtils.paramsToString(Array(("top_k", topK), ("num_leaves", numLeaves), ("max_bin", maxBin), ("bagging_fraction", baggingFraction), ("pos_bagging_fraction", posBaggingFraction), ("neg_bagging_fraction", negBaggingFraction), ("bagging_freq", baggingFreq), ("bagging_seed", baggingSeed), ("feature_fraction", featureFraction), ("max_depth", maxDepth), @@ -75,7 +79,8 @@ abstract class TrainParams extends Serializable { (if (categoricalFeatures.isEmpty) "" else s"categorical_feature=${categoricalFeatures.mkString(",")} ") + (if (maxBinByFeature.isEmpty) "" else s"max_bin_by_feature=${maxBinByFeature.mkString(",")} ") + (if (boostingType == "dart") s"${dartModeParams.toString()} " else "") + - executionParams.toString() + executionParams.toString() + + seedParams.toString() } } @@ -118,7 +123,8 @@ case class ClassifierTrainParams(parallelism: String, delegate: Option[LightGBMDelegate], dartModeParams: DartModeParams, executionParams: ExecutionParams, - objectiveParams: ObjectiveParams) + objectiveParams: ObjectiveParams, + seedParams: SeedParams) extends TrainParams { override def toString: String = { val extraStr = @@ -167,7 +173,8 @@ case class RegressorTrainParams(parallelism: String, delegate: Option[LightGBMDelegate], dartModeParams: DartModeParams, executionParams: ExecutionParams, - objectiveParams: ObjectiveParams) + objectiveParams: ObjectiveParams, + seedParams: SeedParams) extends TrainParams { override def toString: String = { s"alpha=$alpha tweedie_variance_power=$tweedieVariancePower boost_from_average=${boostFromAverage.toString} " + @@ -214,7 +221,8 @@ case class RankerTrainParams(parallelism: String, delegate: Option[LightGBMDelegate], dartModeParams: DartModeParams, executionParams: ExecutionParams, - objectiveParams: ObjectiveParams) + objectiveParams: ObjectiveParams, + seedParams: SeedParams) extends TrainParams { override def toString: String = { val labelGainStr = @@ -266,3 +274,34 @@ case class ObjectiveParams(objective: String, fobj: Option[FObjTrait]) extends S } } } + +/** Defines parameters related to seed and determinism for lightgbm. + * + * @param seed Main seed, used to generate other seeds. + * + * @param deterministic Setting this to true should ensure stable results when using the + * same data and the same parameters. + * @param baggingSeed Bagging seed. + * @param featureFractionSeed Feature fraction seed. + * @param extraSeed Random seed for selecting threshold when extra_trees is true. + * @param dropSeed Random seed to choose dropping models. Only used in dart. + * @param dataRandomSeed Random seed for sampling data to construct histogram bins. + * @param objectiveSeed Random seed for objectives, if random process is needed. + * Currently used only for rank_xendcg objective. + * @param boostingType Boosting type, used to determine if drop seed should be set. + * @param objective Objective, used to determine if objective seed should be set. + */ +case class SeedParams(seed: Option[Int], deterministic: Option[Boolean], + baggingSeed: Option[Int], featureFractionSeed: Option[Int], + extraSeed: Option[Int], dropSeed: Option[Int], + dataRandomSeed: Option[Int], objectiveSeed: Option[Int], + boostingType: String, objective: String) extends Serializable { + override def toString: String = { + ParamUtils.paramsToString(Array(("seed", seed), ("deterministic", deterministic), + ("bagging_seed", baggingSeed), ("feature_fraction_seed", featureFractionSeed), + ("extra_seed", extraSeed), ("data_random_seed", dataRandomSeed))) + + (if (boostingType == "dart" && dropSeed.isDefined) s"drop_seed=${dropSeed.toString()} " else "") + + (if (objective == "rank_xendcg" && objectiveSeed.isDefined) + s"objective_seed=${objectiveSeed.toString()} " else "") + } +} diff --git a/lightgbm/src/test/resources/log4j.properties b/lightgbm/src/test/resources/log4j.properties new file mode 100644 index 00000000000..d777fb49fb6 --- /dev/null +++ b/lightgbm/src/test/resources/log4j.properties @@ -0,0 +1,8 @@ +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.Target=System.out +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=%d{HH:mm:ss} %-5p %c{1}:%L - %m%n + +log4j.rootLogger=WARN, stdout +log4j.logger.org.apache.spark=WARN, stdout +log4j.logger.com.microsoft=INFO, stdout diff --git a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split1/VerifyLightGBMClassifier.scala b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split1/VerifyLightGBMClassifier.scala index 92860bc47c7..9684e8118ff 100644 --- a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split1/VerifyLightGBMClassifier.scala +++ b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split1/VerifyLightGBMClassifier.scala @@ -283,6 +283,10 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM assert(binaryEvaluator.evaluate(sdf1) < binaryEvaluator.evaluate(sdf2)) } + def assertBinaryEquality(sdf1: DataFrame, sdf2: DataFrame): Unit = { + assert(Math.abs(binaryEvaluator.evaluate(sdf1) - binaryEvaluator.evaluate(sdf2)) < 1e-10) + } + def assertMulticlassImprovement(sdf1: DataFrame, sdf2: DataFrame): Unit = { assert(multiclassEvaluator.evaluate(sdf1) < multiclassEvaluator.evaluate(sdf2)) } @@ -351,6 +355,14 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM assertBinaryImprovement(scoredDF1, scoredDF2) } + test("Verify LightGBM Classifier will give reproducible results when setting seed") { + val scoredDF1 = baseModel.setSeed(1).setDeterministic(true).fit(pimaDF).transform(pimaDF) + (1 to 10).foreach { i => + val scoredDF2 = baseModel.setSeed(1).setDeterministic(true).fit(pimaDF).transform(pimaDF) + assertBinaryEquality(scoredDF1, scoredDF2); + } + } + test("Verify LightGBM Classifier with dart mode parameters") { // Assert the dart parameters work without failing and setting them to tuned values improves performance val Array(train, test) = pimaDF.randomSplit(Array(0.8, 0.2), seed) @@ -688,7 +700,6 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM val fitModel = model.fit(df) val tdf = fitModel.transform(df) - assertProbabilities(tdf, model) assertImportanceLengths(fitModel, df) @@ -723,7 +734,6 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM val fitModel = model.fit(df) val tdf = fitModel.transform(df) - assertProbabilities(tdf, model) assertImportanceLengths(fitModel, df) @@ -782,8 +792,6 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM assert(resultsFromFile === resultsOriginal) } } - override def reader: MLReadable[_] = LightGBMClassifier - override def modelReader: MLReadable[_] = LightGBMClassificationModel }