Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add method to get lightgbm native model string directly #1515

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions lightgbm/src/main/python/synapse/ml/lightgbm/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ def saveNativeModel(self, filename, overwrite=True):
"""
self._java_obj.saveNativeModel(filename, overwrite)

def getNativeModel(self):
"""
Get the native model serialized representation as a string.
"""
return self._java_obj.getNativeModel()

def getFeatureImportances(self, importance_type="split"):
"""
Get the feature importances as a list. The importance_type can be "split" or "gain".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,6 @@ class LightGBMClassificationModel(override val uid: String)
udf(predict _).apply(col(getFeaturesCol))
}
}

def saveNativeModel(filename: String, overwrite: Boolean): Unit = {
val session = SparkSession.builder().getOrCreate()
getModel.saveNativeModel(session, filename, overwrite)
}
}

object LightGBMClassificationModel extends ComplexParamsReadable[LightGBMClassificationModel] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package com.microsoft.azure.synapse.ml.lightgbm
import com.microsoft.azure.synapse.ml.lightgbm.params.LightGBMModelParams
import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.SparkSession

/** Contains common LightGBM model methods across all LightGBM learner types.
*/
Expand Down Expand Up @@ -91,6 +92,22 @@ trait LightGBMModelMethods extends LightGBMModelParams with Logging {
getLightGBMBooster.numClasses
}

/** Saves the native model serialized representation to file.
* @param session The spark session
* @param filename The name of the file to save the model to
* @param overwrite Whether to overwrite if the file already exists
*/
def saveNativeModel(filename: String, overwrite: Boolean): Unit = {
val session = SparkSession.builder().getOrCreate()
getModel.saveNativeModel(session, filename, overwrite)
}

/** Gets the native model serialized representation as a string.
*/
def getNativeModel(): String = {
getModel.getNativeModel()
}

/**
* Protected method to predict leaf index.
* @param features The local instance or row to compute the leaf index for.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,6 @@ class LightGBMRankerModel(override val uid: String)
override def copy(extra: ParamMap): LightGBMRankerModel = defaultCopy(extra)

override def numFeatures: Int = getModel.numFeatures

def saveNativeModel(filename: String, overwrite: Boolean): Unit = {
val session = SparkSession.builder().getOrCreate()
getModel.saveNativeModel(session, filename, overwrite)
}
}

object LightGBMRankerModel extends ComplexParamsReadable[LightGBMRankerModel] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,6 @@ class LightGBMRegressionModel(override val uid: String)
}

override def copy(extra: ParamMap): LightGBMRegressionModel = defaultCopy(extra)

def saveNativeModel(filename: String, overwrite: Boolean): Unit = {
val session = SparkSession.builder().getOrCreate()
getModel.saveNativeModel(session, filename, overwrite)
}
}

object LightGBMRegressionModel extends ComplexParamsReadable[LightGBMRegressionModel] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,12 @@ class LightGBMBooster(val trainDataset: Option[LightGBMDataset] = None, val para
dataset.coalesce(1).write.mode(mode).text(filename)
}

/** Gets the native model serialized representation as a string.
*/
def getNativeModel(): String = {
modelStr.get
}

/** Dumps the native model pointer to file.
* @param session The spark session
* @param filename The name of the file to save the model to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -745,9 +745,13 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
val modelPath = targetDir.toString + "/" + outputFileName
FileUtils.deleteDirectory(new File(modelPath))
fitModel.saveNativeModel(modelPath, overwrite = true)
val retrievedModelStr = fitModel.getNativeModel()
assert(Files.exists(Paths.get(modelPath)), true)

val oldModelString = fitModel.getModel.modelStr.get
// Assert model string is equal when retrieved from booster and getNativeModel API
assert(retrievedModelStr == oldModelString)

// Verify model string contains some feature
colsToVerify.foreach(col => oldModelString.contains(col))

Expand Down