diff --git a/lightgbm/src/main/python/synapse/ml/lightgbm/mixin.py b/lightgbm/src/main/python/synapse/ml/lightgbm/mixin.py index 84e9ff8463..242b23a29a 100644 --- a/lightgbm/src/main/python/synapse/ml/lightgbm/mixin.py +++ b/lightgbm/src/main/python/synapse/ml/lightgbm/mixin.py @@ -13,6 +13,12 @@ def saveNativeModel(self, filename, overwrite=True): """ self._java_obj.saveNativeModel(filename, overwrite) + def getNativeModel(self): + """ + Get the native model serialized representation as a string. + """ + return self._java_obj.getNativeModel() + def getFeatureImportances(self, importance_type="split"): """ Get the feature importances as a list. The importance_type can be "split" or "gain". diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMClassifier.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMClassifier.scala index 11548c1fcd..74a14ac35f 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMClassifier.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMClassifier.scala @@ -183,11 +183,6 @@ class LightGBMClassificationModel(override val uid: String) udf(predict _).apply(col(getFeaturesCol)) } } - - def saveNativeModel(filename: String, overwrite: Boolean): Unit = { - val session = SparkSession.builder().getOrCreate() - getModel.saveNativeModel(session, filename, overwrite) - } } object LightGBMClassificationModel extends ComplexParamsReadable[LightGBMClassificationModel] { diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMModelMethods.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMModelMethods.scala index 47e9979695..65ba78a89f 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMModelMethods.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMModelMethods.scala @@ -6,6 +6,7 @@ package com.microsoft.azure.synapse.ml.lightgbm import com.microsoft.azure.synapse.ml.lightgbm.params.LightGBMModelParams import org.apache.spark.internal.Logging import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.sql.SparkSession /** Contains common LightGBM model methods across all LightGBM learner types. */ @@ -91,6 +92,22 @@ trait LightGBMModelMethods extends LightGBMModelParams with Logging { getLightGBMBooster.numClasses } + /** Saves the native model serialized representation to file. + * @param session The spark session + * @param filename The name of the file to save the model to + * @param overwrite Whether to overwrite if the file already exists + */ + def saveNativeModel(filename: String, overwrite: Boolean): Unit = { + val session = SparkSession.builder().getOrCreate() + getModel.saveNativeModel(session, filename, overwrite) + } + + /** Gets the native model serialized representation as a string. + */ + def getNativeModel(): String = { + getModel.getNativeModel() + } + /** * Protected method to predict leaf index. * @param features The local instance or row to compute the leaf index for. diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRanker.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRanker.scala index 9f8a45cd78..76df4d4feb 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRanker.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRanker.scala @@ -163,11 +163,6 @@ class LightGBMRankerModel(override val uid: String) override def copy(extra: ParamMap): LightGBMRankerModel = defaultCopy(extra) override def numFeatures: Int = getModel.numFeatures - - def saveNativeModel(filename: String, overwrite: Boolean): Unit = { - val session = SparkSession.builder().getOrCreate() - getModel.saveNativeModel(session, filename, overwrite) - } } object LightGBMRankerModel extends ComplexParamsReadable[LightGBMRankerModel] { diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRegressor.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRegressor.scala index 5141bde0a6..4f82fa79d1 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRegressor.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRegressor.scala @@ -132,11 +132,6 @@ class LightGBMRegressionModel(override val uid: String) } override def copy(extra: ParamMap): LightGBMRegressionModel = defaultCopy(extra) - - def saveNativeModel(filename: String, overwrite: Boolean): Unit = { - val session = SparkSession.builder().getOrCreate() - getModel.saveNativeModel(session, filename, overwrite) - } } object LightGBMRegressionModel extends ComplexParamsReadable[LightGBMRegressionModel] { diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/booster/LightGBMBooster.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/booster/LightGBMBooster.scala index fe6981bf53..4fb81fd999 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/booster/LightGBMBooster.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/booster/LightGBMBooster.scala @@ -465,6 +465,12 @@ class LightGBMBooster(val trainDataset: Option[LightGBMDataset] = None, val para dataset.coalesce(1).write.mode(mode).text(filename) } + /** Gets the native model serialized representation as a string. + */ + def getNativeModel(): String = { + modelStr.get + } + /** Dumps the native model pointer to file. * @param session The spark session * @param filename The name of the file to save the model to diff --git a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split1/VerifyLightGBMClassifier.scala b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split1/VerifyLightGBMClassifier.scala index 23ba771a74..272d43ac45 100644 --- a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split1/VerifyLightGBMClassifier.scala +++ b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split1/VerifyLightGBMClassifier.scala @@ -745,9 +745,13 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM val modelPath = targetDir.toString + "/" + outputFileName FileUtils.deleteDirectory(new File(modelPath)) fitModel.saveNativeModel(modelPath, overwrite = true) + val retrievedModelStr = fitModel.getNativeModel() assert(Files.exists(Paths.get(modelPath)), true) val oldModelString = fitModel.getModel.modelStr.get + // Assert model string is equal when retrieved from booster and getNativeModel API + assert(retrievedModelStr == oldModelString) + // Verify model string contains some feature colsToVerify.foreach(col => oldModelString.contains(col))