-
Notifications
You must be signed in to change notification settings - Fork 834
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Refactoring Param system and adding new ones (#1444)
* Refactoring Param system and adding new ones * fixes and added more explicit params * remove extraneous files * test fixes * fix test * test fix * fix test * build fix * test fixes * doc fix * fix merge conflicts; * notebook test fixes * fix style error * style fix * added tests
- Loading branch information
Showing
33 changed files
with
1,415 additions
and
822 deletions.
There are no files selected for viewing
174 changes: 174 additions & 0 deletions
174
core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/ParamsStringBuilder.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
// Copyright (C) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. See LICENSE in project root for information. | ||
|
||
package com.microsoft.azure.synapse.ml.core.utils | ||
|
||
import org.apache.spark.ml.param._ | ||
|
||
/** Helper class for converting individual typed parameters to a single long string for passing to native libraries.
  *
  * Example:
  *   new ParamsStringBuilder(prefix = "--", delimiter = "=")
  *     .append("--first_param=a")
  *     .appendParamValueIfNotThere("first_param", Option("a2"))
  *     .appendParamValueIfNotThere("second_param", Option("b"))
  *     .appendParamValueIfNotThere("third_param", None)
  *     .result
  *
  *   result == "--first_param=a --second_param=b"
  *
  * This utility mimics a traditional StringBuilder, where you append parts and ask for the final result() at the end.
  * Before a parameter is appended, the current string is searched for it, to both avoid duplicates
  * and provide an override mechanism. The first occurrence of a param is the one that is kept.
  *
  * Note that the "already there" check is a plain search over the accumulated string: a name also counts as
  * present when it appears as part of a longer token (value checks require a following " " or "=",
  * flag and list checks do not require anything after the name).
  *
  * Use 'append' to add an unchecked string directly to end of current string.
  *
  * Use ParamGroup to create encapsulated subsets of params that can be incorporated into the parent
  * ParamsStringBuilder via appendParamGroup.
  *
  * There is also integration with the SparkML Params system. Construct with a parent Params object and use the methods
  * with Param arguments.
  *
  * @param parent Optional parent Params instance that Param values are read from.
  * @param prefix Prefix to put before parameter names (e.g. "--").
  * @param delimiter Delimiter to put between names and values (e.g. "=").
  */
class ParamsStringBuilder(parent: Option[Params], prefix: String, delimiter: String) {

  // A StringBuilder is a relatively inefficient way to implement this (e.g. HashTable would be better),
  // but it is simple to interpret/maintain and not on a critical perf path.
  private val sb: StringBuilder = new StringBuilder()

  /** Create a builder without a parent Params. Prefix defaults to "" and delimiter defaults to "=". */
  def this(prefix: String = "", delimiter: String = "=") = {
    this(None, prefix, delimiter)
  }

  /** Create a builder that reads Param values from the given parent (a null parent is treated as no parent). */
  def this(parent: Params, prefix: String, delimiter: String) = {
    this(Option(parent), prefix, delimiter)
  }

  /** Add a parameter name-value pair to the end of the current string.
    * Array values are appended as a comma-delimited list (and skipped when empty).
    * @param optionShort Short name of the parameter (only used to check against existing params).
    * @param optionLong Long name of the parameter. Will be used if it is not already set.
    * @param param The Param object with the value. Note that if this is not set, nothing will be appended to string.
    * @throws IllegalArgumentException if this builder was constructed without a parent Params.
    */
  def appendParamValueIfNotThere[T](optionShort: String, optionLong: String, param: Param[T]): ParamsStringBuilder = {
    // The value is resolved before anything else so that a missing parent is always reported,
    // even when the option already appears in the string.
    getParamValue(param).foreach { value =>
      // boost allows " " or "=" as separators
      val alreadySet = isPresent(s"-${quote(optionShort)}[ =]") ||
        isPresent(s"${quote(prefix + optionLong)}[ =]")
      if (!alreadySet) {
        value match {
          // Arrays have no useful toString, so render them through the list helper instead.
          case values: Array[_] => appendParamListIfNotThere(optionLong, values)
          case single => append(s"$prefix$optionLong$delimiter$single")
        }
      }
    }
    this
  }

  /** Add a parameter name-value pair to the end of the current string.
    * @param optionLong Long name of the parameter. Will be used if it is not already set.
    * @param param The Option object with the value. Note that if this is None, nothing will be appended to string.
    */
  def appendParamValueIfNotThere[T](optionLong: String, param: Option[T]): ParamsStringBuilder = {
    param.foreach { value =>
      // boost allows " " or "=" as separators
      if (!isPresent(s"${quote(prefix + optionLong)}[ =]")) {
        append(s"$prefix$optionLong$delimiter$value")
      }
    }
    this
  }

  /** Add a parameter name to the end of the current string. (i.e. a param that does not have a value)
    * @param name Name of the flag. Will be appended (with the prefix) if it is not already in the string.
    */
  def appendParamFlagIfNotThere(name: String): ParamsStringBuilder = {
    if (!isPresent(quote(prefix + name))) {
      append(s"$prefix$name")
    }
    this
  }

  /** Add a parameter name-list pair to the end of the current string. Values will be comma-delimited.
    * @param name Name of the parameter. Will be used if it is not already in the string.
    * @param values The Array of values. Note that if the array is empty, nothing will be appended to string.
    */
  def appendParamListIfNotThere[T](name: String, values: Array[T]): ParamsStringBuilder = {
    if (values.nonEmpty && !isPresent(quote(prefix + name))) {
      appendParamList(name, values)
    }
    this
  }

  /** Add a parameter name-list pair to the end of the current string with no checking of existing params.
    * @param name Name of the parameter.
    * @param values The Array of values, rendered comma-delimited.
    */
  def appendParamList[T](name: String, values: Array[T]): ParamsStringBuilder = {
    append(s"$prefix$name$delimiter${values.mkString(",")}")
  }

  /** Add a parameter group to the end of the current string.
    * @param paramGroup The ParamGroup to add.
    */
  def appendParamGroup(paramGroup: ParamGroup): ParamsStringBuilder = {
    paramGroup.appendParams(this)
  }

  /** Add a parameter group to the end of the current string conditionally.
    * @param paramGroup The ParamGroup to add.
    * @param condition Whether to add the group or not.
    */
  def appendParamGroup(paramGroup: ParamGroup, condition: Boolean): ParamsStringBuilder = {
    if (condition) paramGroup.appendParams(this) else this
  }

  /** Direct append a string with no checking of existing params.
    * Non-empty strings are separated from the existing content by a single space; empty strings are ignored.
    * @param str The string to add.
    */
  def append(str: String): ParamsStringBuilder = {
    if (str.nonEmpty) {
      if (sb.nonEmpty) sb.append(" ")
      sb.append(str)
    }
    this
  }

  /** The full parameter string accumulated so far. */
  def result(): String = sb.toString

  /** True if the given regular expression matches anywhere in the accumulated string. */
  private def isPresent(pattern: String): Boolean = {
    pattern.r.findFirstIn(sb.toString).isDefined
  }

  /** Escape a literal so that regex metacharacters in prefixes or names are matched verbatim. */
  private def quote(literal: String): String = java.util.regex.Pattern.quote(literal)

  private def getParent(): Params = {
    parent.getOrElse {
      throw new IllegalArgumentException("ParamsStringBuilder requires a parent for this operation")
    }
  }

  private def getParamValue[T](param: Param[T]): Option[T] = {
    getParent().get(param)
  }
}
|
||
/** Base trait for an encapsulated subset of params. Implementations write their params into a
  * ParamsStringBuilder, which lets a parent builder pull them in through appendParamGroup.
  * Useful for breaking a large set of params into smaller, self-contained groups.
  */
trait ParamGroup extends Serializable {
  /** Renders this group on its own, using a ParamsStringBuilder created with its default arguments. */
  override def toString: String = {
    val rendered = new ParamsStringBuilder().appendParamGroup(this)
    rendered.result()
  }

  /** Append this group's params to the given builder.
    * @param sb The builder to append to.
    * @return The builder to continue chaining on.
    */
  def appendParams(sb: ParamsStringBuilder): ParamsStringBuilder
}
|
||
|
||
|
||
|
221 changes: 221 additions & 0 deletions
221
.../src/test/scala/com/microsoft/azure/synapse/ml/core/utils/VerifyParamsStringBuilder.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
// Copyright (C) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. See LICENSE in project root for information. | ||
|
||
package com.microsoft.azure.synapse.ml.core.utils | ||
|
||
import com.microsoft.azure.synapse.ml.core.test.base.TestBase | ||
import org.apache.spark.ml.param._ | ||
|
||
/** Tests for ParamsStringBuilder: pass-through strings, duplicate suppression,
  * sparkml Param integration and ParamGroup support.
  */
class VerifyParamsStringBuilder extends TestBase {

  /** Builds an expected result string; fragments are separated by a single space. */
  private def joined(parts: String*): String = parts.mkString(" ")

  /** A TestParams container with every test Param explicitly set. */
  private def populatedParams(): TestParams = {
    val container = new TestParams()
    container.setString("some_string")
    container.setInt(3)
    container.setFloat(3.0f)
    container.setDouble(3.0)
    container.setStringArray(Array("one", "two", "three"))
    container.setIntArray(Array(1, 2, 3))
    container.setDoubleArray(Array(1.0, 2.0, 3.0))
    container
  }

  /** Runs every Param of the container through appendParamValueIfNotThere (short name == long name). */
  private def appendAllParams(psb: ParamsStringBuilder, container: TestParams): ParamsStringBuilder = {
    psb
      .appendParamValueIfNotThere("string_value", "string_value", container.testString)
      .appendParamValueIfNotThere("int_value", "int_value", container.testInt)
      .appendParamValueIfNotThere("float_value", "float_value", container.testFloat)
      .appendParamValueIfNotThere("double_value", "double_value", container.testDouble)
      .appendParamValueIfNotThere("string_array_value", "string_array_value", container.testStringArray)
      .appendParamValueIfNotThere("int_array_value", "int_array_value", container.testIntArray)
      .appendParamValueIfNotThere("double_array_value", "double_array_value", container.testDoubleArray)
  }

  // Test LightGBM style arguments, e.g. "objective=classifier"
  test("Verify ParamsStringBuilder handles basic string parameters") {
    val psb = new ParamsStringBuilder(prefix = "", delimiter = "=")
      .append("pass_through_param=custom")
      .append("pass_through_param_with_space custom2")
      .append("pass_through_list=1,2,3")
      .append("pass_through_flag")
      .append("-short_param=ok")
      .append("-short_param_with_space ok2")
      .appendParamFlagIfNotThere("some_param_flag")
      .appendParamFlagIfNotThere("pass_through_flag") // should not add
      .appendParamListIfNotThere("some_param_list", Array[Int](3, 4, 5))
      .appendParamListIfNotThere("some_empty_list", Array.empty[Int]) // should not add
      .appendParamListIfNotThere("pass_through_list", Array[Int](3, 4, 5)) // should not add
      .appendParamValueIfNotThere("null_param", None) // should not add
      .appendParamValueIfNotThere("some_param", Option(10))
      .appendParamValueIfNotThere("pass_through_param", Option("bad_val")) // should not add

    val expectedStr: String = joined(
      "pass_through_param=custom",
      "pass_through_param_with_space custom2",
      "pass_through_list=1,2,3",
      "pass_through_flag",
      "-short_param=ok",
      "-short_param_with_space ok2",
      "some_param_flag",
      "some_param_list=3,4,5",
      "some_param=10")

    assert(psb.result == expectedStr)
  }

  // Test Vowpal Wabbit style arguments, e.g. "--hash_seed 4 -l .1 --holdout_off"
  test("Verify ParamsStringBuilder handles different prefix and delimiter") {
    val psb = new ParamsStringBuilder(prefix = "--", delimiter = " ")
      .append("--pass_through_param custom")
      .append("--pass_through_list 1,2,3")
      .append("--pass_through_flag")
      .append("-short_param ok")
      .appendParamFlagIfNotThere("some_param_flag")
      .appendParamFlagIfNotThere("pass_through_flag") // should not add
      .appendParamListIfNotThere("some_param_list", Array[Int](3, 4, 5))
      .appendParamListIfNotThere("some_empty_list", Array.empty[Int]) // should not add
      .appendParamListIfNotThere("pass_through_list", Array[Int](3, 4, 5)) // should not add
      .appendParamValueIfNotThere("null_param", None) // should not add
      .appendParamValueIfNotThere("some_param", Option(10))
      .appendParamValueIfNotThere("pass_through_param", Option("bad_val")) // should not add

    val expectedStr: String = joined(
      "--pass_through_param custom",
      "--pass_through_list 1,2,3",
      "--pass_through_flag",
      "-short_param ok",
      "--some_param_flag",
      "--some_param_list 3,4,5",
      "--some_param 10")

    assert(psb.result == expectedStr)
  }

  test("Verify ParamsStringBuilder handles sparkml Params") {
    val testParamsContainer = populatedParams()

    val psb = appendAllParams(
      new ParamsStringBuilder(testParamsContainer, prefix = "--", delimiter = " "),
      testParamsContainer)

    val expectedStr: String = joined(
      "--string_value some_string",
      "--int_value 3",
      "--float_value 3.0",
      "--double_value 3.0",
      "--string_array_value one,two,three",
      "--int_array_value 1,2,3",
      "--double_array_value 1.0,2.0,3.0")

    assert(psb.result == expectedStr)
  }

  test("Verify ParamsStringBuilder ignores sparkml Params unset parameters") {
    val testParamsContainer = new TestParams()

    val psb = appendAllParams(
      new ParamsStringBuilder(testParamsContainer, prefix = "--", delimiter = " "),
      testParamsContainer)

    // Since no parameters were set, they should not have been appended to string
    assert(psb.result == "")
  }

  test("Verify ParamsStringBuilder throws when sparkml gets invalid Param") {
    val testParamsContainer = new TestParams()

    // A Param that was not declared by the container
    val unknownParam = new Param[String]("some parent", "unknown_param", "doc_string")
    assertThrows[IllegalArgumentException] {
      new ParamsStringBuilder(testParamsContainer, prefix = "--", delimiter = " ")
        .appendParamValueIfNotThere("string_value", "string_value", unknownParam)
    }

    // A builder without a parent cannot resolve Param values at all
    assertThrows[IllegalArgumentException] {
      new ParamsStringBuilder(prefix = "--", delimiter = " ")
        .appendParamValueIfNotThere("string_value", "string_value", testParamsContainer.testString)
    }
  }

  test("Verify ParamsStringBuilder handles sparkml Params custom overrides") {
    // Populate container with parameters that will be overridden
    val testParamsContainer = populatedParams()

    val overridden = new ParamsStringBuilder(testParamsContainer, prefix = "--", delimiter = " ")
      .append("--string_value some_override")
      .append("--int_value 99")
      .append("--float_value 99.0")
      .append("--double_value 99.0")
      .append("--string_array_value three,four,five")
      .append("--int_array_value 97,98,99")
      .append("--double_array_value 97.0,98.0,99.0")

    val psb = appendAllParams(overridden, testParamsContainer)

    // We expect the set Params to be ignored since an override exists, so we should only see the original overrides
    val expectedStr: String = joined(
      "--string_value some_override",
      "--int_value 99",
      "--float_value 99.0",
      "--double_value 99.0",
      "--string_array_value three,four,five",
      "--int_array_value 97,98,99",
      "--double_array_value 97.0,98.0,99.0")

    assert(psb.result == expectedStr)
  }

  test("Verify ParamsStringBuilder handles ParamGroup") {
    val testParamGroup = new TestParamGroup(1, None)

    val psb = new ParamsStringBuilder(prefix = "", delimiter = "=")
      .appendParamGroup(testParamGroup)

    // We expect the optional Int to be ignored since it wasn't set
    val expectedStr: String = "some_int=1"

    assert(psb.result == expectedStr)
  }
}
|
||
/** Minimal sparkml Params container used to exercise the Param-based overloads of ParamsStringBuilder.
  * Declares one Param per supported value kind, each with a matching setter.
  */
class TestParams extends Params {
  // Declared first so it is initialized before the Param constructors below receive `this`.
  override val uid: String = "some id"

  val testString: Param[String] = new Param[String](this, "testString", "Test String param")
  val testInt: IntParam = new IntParam(this, "testInt", "Test Int param")
  val testFloat: FloatParam = new FloatParam(this, "testFloat", "Test Float param")
  val testDouble: DoubleParam = new DoubleParam(this, "testDouble", "Test Double param")
  val testStringArray: StringArrayParam = new StringArrayParam(this, "testStringArray", "Test StringArray param")
  val testIntArray: IntArrayParam = new IntArrayParam(this, "testIntArray", "Test IntArray param")
  val testDoubleArray: DoubleArrayParam = new DoubleArrayParam(this, "testDoubleArray", "Test DoubleArray param")

  // Copying is not needed by the tests, so the instance itself is returned.
  override def copy(extra: ParamMap): Params = this

  def setString(value: String): this.type = set(testString, value)

  def setInt(value: Int): this.type = set(testInt, value)

  def setFloat(value: Float): this.type = set(testFloat, value)

  def setDouble(value: Double): this.type = set(testDouble, value)

  def setStringArray(value: Array[String]): this.type = set(testStringArray, value)

  def setIntArray(value: Array[Int]): this.type = set(testIntArray, value)

  def setDoubleArray(value: Array[Double]): this.type = set(testDoubleArray, value)
}
|
||
/** ParamGroup fixture with one always-present value and one optional value.
  *
  * @param someInt Appended as "some_int" (unless already present in the builder).
  * @param someOptionInt Appended as "some_option_int" only when defined.
  */
case class TestParamGroup(someInt: Int, someOptionInt: Option[Int]) extends ParamGroup {
  override def appendParams(sb: ParamsStringBuilder): ParamsStringBuilder = {
    val withRequired = sb.appendParamValueIfNotThere("some_int", Option(someInt))
    withRequired.appendParamValueIfNotThere("some_option_int", someOptionInt)
  }
}
Oops, something went wrong.