Skip to content

Commit

Permalink
feat: Add method to get lightgbm native model string directly (#1515)
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft authored May 22, 2022
1 parent 8e7151b commit 5f61e6f
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 15 deletions.
6 changes: 6 additions & 0 deletions lightgbm/src/main/python/synapse/ml/lightgbm/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ def saveNativeModel(self, filename, overwrite=True):
"""
self._java_obj.saveNativeModel(filename, overwrite)

def getNativeModel(self):
"""
Get the native model serialized representation as a string.
"""
return self._java_obj.getNativeModel()

def getFeatureImportances(self, importance_type="split"):
"""
Get the feature importances as a list. The importance_type can be "split" or "gain".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,6 @@ class LightGBMClassificationModel(override val uid: String)
udf(predict _).apply(col(getFeaturesCol))
}
}

def saveNativeModel(filename: String, overwrite: Boolean): Unit = {
val session = SparkSession.builder().getOrCreate()
getModel.saveNativeModel(session, filename, overwrite)
}
}

object LightGBMClassificationModel extends ComplexParamsReadable[LightGBMClassificationModel] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package com.microsoft.azure.synapse.ml.lightgbm
import com.microsoft.azure.synapse.ml.lightgbm.params.LightGBMModelParams
import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.SparkSession

/** Contains common LightGBM model methods across all LightGBM learner types.
*/
Expand Down Expand Up @@ -91,6 +92,22 @@ trait LightGBMModelMethods extends LightGBMModelParams with Logging {
getLightGBMBooster.numClasses
}

/** Saves the native model serialized representation to file.
* @param session The spark session
* @param filename The name of the file to save the model to
* @param overwrite Whether to overwrite if the file already exists
*/
def saveNativeModel(filename: String, overwrite: Boolean): Unit = {
val session = SparkSession.builder().getOrCreate()
getModel.saveNativeModel(session, filename, overwrite)
}

/** Gets the native model serialized representation as a string.
*/
def getNativeModel(): String = {
getModel.getNativeModel()
}

/**
* Protected method to predict leaf index.
* @param features The local instance or row to compute the leaf index for.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,6 @@ class LightGBMRankerModel(override val uid: String)
override def copy(extra: ParamMap): LightGBMRankerModel = defaultCopy(extra)

override def numFeatures: Int = getModel.numFeatures

def saveNativeModel(filename: String, overwrite: Boolean): Unit = {
val session = SparkSession.builder().getOrCreate()
getModel.saveNativeModel(session, filename, overwrite)
}
}

object LightGBMRankerModel extends ComplexParamsReadable[LightGBMRankerModel] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,6 @@ class LightGBMRegressionModel(override val uid: String)
}

override def copy(extra: ParamMap): LightGBMRegressionModel = defaultCopy(extra)

def saveNativeModel(filename: String, overwrite: Boolean): Unit = {
val session = SparkSession.builder().getOrCreate()
getModel.saveNativeModel(session, filename, overwrite)
}
}

object LightGBMRegressionModel extends ComplexParamsReadable[LightGBMRegressionModel] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,12 @@ class LightGBMBooster(val trainDataset: Option[LightGBMDataset] = None, val para
dataset.coalesce(1).write.mode(mode).text(filename)
}

/** Gets the native model serialized representation as a string.
*/
def getNativeModel(): String = {
modelStr.get
}

/** Dumps the native model pointer to file.
* @param session The spark session
* @param filename The name of the file to save the model to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -745,9 +745,13 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
val modelPath = targetDir.toString + "/" + outputFileName
FileUtils.deleteDirectory(new File(modelPath))
fitModel.saveNativeModel(modelPath, overwrite = true)
val retrievedModelStr = fitModel.getNativeModel()
assert(Files.exists(Paths.get(modelPath)), true)

val oldModelString = fitModel.getModel.modelStr.get
// Assert model string is equal when retrieved from booster and getNativeModel API
assert(retrievedModelStr == oldModelString)

// Verify model string contains some feature
colsToVerify.foreach(col => oldModelString.contains(col))

Expand Down

0 comments on commit 5f61e6f

Please sign in to comment.