Skip to content

Commit

Permalink
feat: Refactoring Param system and adding new ones (#1444)
Browse files Browse the repository at this point in the history
* Refactoring Param system and adding new ones

* fixes and added more explicit params

* remove extraneous files

* test fixes

* fix test

* test fix

* fix test

* build fix

* test fixes

* doc fix

* fix merge conflicts;

* notebook test fixes

* fix style error

* style fix

* added tests
  • Loading branch information
svotaw authored Mar 30, 2022
1 parent b734132 commit 339e516
Show file tree
Hide file tree
Showing 33 changed files with 1,415 additions and 822 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.core.utils

import org.apache.spark.ml.param._

/** Helper class for converting individual typed parameters to a single long string for passing to native libraries.
*
* Example:
* new ParamsStringBuilder(prefix = "--", delimiter = "=")
* .append("--first_param=a")
* .appendParamValueIfNotThere("first_param", Option("a2"))
* .appendParamValueIfNotThere("second_param", Option("b"))
* .appendParamValueIfNotThere("third_param", None)
* .result
*
* result == "--first_param=a --second_param=b"
*
* This utility mimics a traditional StringBuilder, where you append parts and ask for the final result() at the end.
* Each parameter appended is tracked and can be compared against the current string, to both avoid duplicates
* and provide an override mechanism. The first param added is considered the primary one to not be replaced.
*
* Use 'append' to add an unchecked string directly to end of current string.
*
* Use ParamsSet to create encapsulated subsets of params that can be incorporated into the parent ParamsStringBuilder.
*
* There is also integration with the SparkML Params system. Construct with a parent Params object and use the methods
* with Param arguments.
*
* @param parent Optional parent Params instance to validate each Param against.
* @param prefix Optional prefix to put before parameter names (e.g. "--"). Defaults to none.
* @param delimiter Optional delimiter to put between names and values (e.g. "="). Defaults to "=".
*/
class ParamsStringBuilder(parent: Option[Params], prefix: String, delimiter: String) {

// A StringBuilder is a relatively inefficient way to implement this (e.g. HashTable would be better),
// but it is simple to interpret/maintain and not on a critical perf path.
private val sb: StringBuilder = new StringBuilder()

def this(prefix: String = "", delimiter: String = "=") = {
this(None, prefix, delimiter)
}

def this(parent: Params, prefix: String, delimiter: String) = {
this(Option(parent), prefix, delimiter)
}

/** Add a parameter name-value pair to the end of the current string.
* @param optionShort Short name of the parameter (only used to check against existing params).
* @param optionLong Long name of the parameter. Will be used if it is not already set.
* @param param The Param object with the value. Note that if this is not set, nothing will be appended to string.
*/
def appendParamValueIfNotThere[T](optionShort: String, optionLong: String, param: Param[T]): ParamsStringBuilder = {
if (getParamValue(param).isDefined &&
// boost allows " " or "=" as separators
s"-$optionShort[ =]".r.findAllIn(sb.result).isEmpty &&
s"$prefix$optionLong[ =]".r.findAllIn(sb.result).isEmpty)
{
param match {
case _: IntArrayParam =>
appendParamListIfNotThere(optionLong, getParamValue(param).get.asInstanceOf[Array[Int]])
case _: DoubleArrayParam =>
appendParamListIfNotThere(optionLong, getParamValue(param).get.asInstanceOf[Array[Double]])
case _: StringArrayParam =>
appendParamListIfNotThere(optionLong, getParamValue(param).get.asInstanceOf[Array[String]])
//for (q <- getParamValue(param).get) TODO this is the old code for Array[String], but seems wrong
// append(s"$prefix$optionLong$delimiter$q")
case _ => append(s"$prefix$optionLong$delimiter${getParamValue(param).get}")
}
}
this
}

/** Add a parameter name-value pair to the end of the current string.
* @param optionLong Long name of the parameter. Will be used if it is not already set.
* @param param The Option object with the value. Note that if this is None, nothing will be appended to string.
*/
def appendParamValueIfNotThere[T](optionLong: String, param: Option[T]): ParamsStringBuilder = {
if (param.isDefined &&
// boost allow " " or "="
s"$prefix$optionLong[ =]".r.findAllIn(sb.result).isEmpty)
{
append(s"$prefix$optionLong$delimiter${param.get}")
}
sb.to
this
}

/** Add a parameter name to the end of the current string. (i.e. a param that does not have a value)
* @param optionLong Long name of the parameter. Will be used if it is not already set.
*/
def appendParamFlagIfNotThere(name: String): ParamsStringBuilder = {
if (s"$prefix$name".r.findAllIn(sb.result).isEmpty) {
append(s"$prefix$name")
}
this
}

/** Add a parameter name-list pair to the end of the current string. Values will be comma-delimited.
* @param optionLong Long name of the parameter. Will be used if it is not already set.
* @param values The Array of values. Note that if the array is empty, nothing will be appended to string.
*/
def appendParamListIfNotThere[T](name: String, values: Array[T]): ParamsStringBuilder = {
if (!values.isEmpty && s"$prefix$name".r.findAllIn(sb.result).isEmpty) {
appendParamList(name, values)
}
this
}

def appendParamList[T](name: String, values: Array[T]): ParamsStringBuilder = {
append(s"$prefix$name$delimiter${values.mkString(",")}")
}

/** Add a parameter group to the end of the current string.
* @param paramGroup The ParamGroup to add.
*/
def appendParamGroup(paramGroup: ParamGroup): ParamsStringBuilder = {
paramGroup.appendParams(this)
}

/** Add a parameter group to the end of the current string conditionally.
* @param paramGroup The ParamGroup to add.
* @param condition Whether to add the group or not.
*/
def appendParamGroup(paramGroup: ParamGroup, condition: Boolean): ParamsStringBuilder = {
if (condition) paramGroup.appendParams(this) else this
}

/** Direct append a string with no checking of existing params.
* @param str The string to add.
*/
def append(str: String): ParamsStringBuilder =
{
if (!str.isEmpty) {
if (!sb.isEmpty) sb.append(" ")
sb.append(str)
}
this
}

def result(): String =
{
sb.result
}

private def getParent(): Params = {
if (parent.isEmpty)
{
throw new IllegalArgumentException("ParamsStringBuilder requires a parent for this operation")
}
parent.get
}

private def getParamValue[T](param: Param[T]): Option[T] = {
getParent.get(param)
}
}

/** Derive from this to create an encapsulated subset of params that can be integrated
* into a parent ParamsStringBuilder with appendParamsSet, Useful for encapsulating Params into
* smaller subsets.
*/
trait ParamGroup extends Serializable {
override def toString: String = {
new ParamsStringBuilder().appendParamGroup(this).result
}

def appendParams(sb: ParamsStringBuilder): ParamsStringBuilder
}




Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.core.utils

import com.microsoft.azure.synapse.ml.core.test.base.TestBase
import org.apache.spark.ml.param._

class VerifyParamsStringBuilder extends TestBase {
// Test LightGBM style arguments, e.g. "objective=classifier"
test("Verify ParamsStringBuilder handles basic string parameters") {
val psb = new ParamsStringBuilder(prefix = "", delimiter = "=")
.append("pass_through_param=custom")
.append("pass_through_param_with_space custom2")
.append("pass_through_list=1,2,3")
.append("pass_through_flag")
.append("-short_param=ok")
.append("-short_param_with_space ok2")
.appendParamFlagIfNotThere("some_param_flag")
.appendParamFlagIfNotThere("pass_through_flag") // should not add
.appendParamListIfNotThere("some_param_list", Array[Int](3, 4, 5))
.appendParamListIfNotThere("some_empty_list", new Array(0)) // should not add
.appendParamListIfNotThere("pass_through_list", Array[Int](3, 4, 5)) // should not add
.appendParamValueIfNotThere("null_param", None) // should not add
.appendParamValueIfNotThere("some_param", Option(10))
.appendParamValueIfNotThere("pass_through_param", Option("bad_val")) // should not add

val expectedStr: String = "pass_through_param=custom" +
" pass_through_param_with_space custom2" +
" pass_through_list=1,2,3" +
" pass_through_flag" +
" -short_param=ok" +
" -short_param_with_space ok2" +
" some_param_flag" +
" some_param_list=3,4,5" +
" some_param=10"

assert(psb.result == expectedStr)
}

// Test Vowpal Wabbit style arguments, e.g. "--hash_seed 4 -l .1 --holdout_off"
test("Verify ParamsStringBuilder handles different prefix and delimiter") {
val psb = new ParamsStringBuilder(prefix = "--", delimiter = " ")
.append("--pass_through_param custom")
.append("--pass_through_list 1,2,3")
.append("--pass_through_flag")
.append("-short_param ok")
.appendParamFlagIfNotThere("some_param_flag")
.appendParamFlagIfNotThere("pass_through_flag") // should not add
.appendParamListIfNotThere("some_param_list", Array[Int](3, 4, 5))
.appendParamListIfNotThere("some_empty_list", new Array(0)) // should not add
.appendParamListIfNotThere("pass_through_list", Array[Int](3, 4, 5)) // should not add
.appendParamValueIfNotThere("null_param", None) // should not add
.appendParamValueIfNotThere("some_param", Option(10))
.appendParamValueIfNotThere("pass_through_param", Option("bad_val")) // should not add

val expectedStr: String = "--pass_through_param custom" +
" --pass_through_list 1,2,3" +
" --pass_through_flag" +
" -short_param ok" +
" --some_param_flag" +
" --some_param_list 3,4,5" +
" --some_param 10"

assert(psb.result == expectedStr)
}

test("Verify ParamsStringBuilder handles sparkml Params") {
val testParamsContainer = new TestParams()

// Populate container with parameters
testParamsContainer.setString("some_string")
testParamsContainer.setInt(3)
testParamsContainer.setFloat(3.0f)
testParamsContainer.setDouble(3.0)
testParamsContainer.setStringArray(Array("one", "two", "three"))
testParamsContainer.setIntArray(Array(1,2,3))
testParamsContainer.setDoubleArray(Array(1.0, 2.0, 3.0))

val psb = new ParamsStringBuilder(testParamsContainer, prefix = "--", delimiter = " ")
.appendParamValueIfNotThere("string_value", "string_value", testParamsContainer.testString)
.appendParamValueIfNotThere("int_value", "int_value", testParamsContainer.testInt)
.appendParamValueIfNotThere("float_value", "float_value", testParamsContainer.testFloat)
.appendParamValueIfNotThere("double_value", "double_value", testParamsContainer.testDouble)
.appendParamValueIfNotThere("string_array_value", "string_array_value", testParamsContainer.testStringArray)
.appendParamValueIfNotThere("int_array_value", "int_array_value", testParamsContainer.testIntArray)
.appendParamValueIfNotThere("double_array_value", "double_array_value", testParamsContainer.testDoubleArray)

val expectedStr: String = "--string_value some_string" +
" --int_value 3" +
" --float_value 3.0" +
" --double_value 3.0" +
" --string_array_value one,two,three" +
" --int_array_value 1,2,3" +
" --double_array_value 1.0,2.0,3.0"

assert(psb.result == expectedStr)
}

test("Verify ParamsStringBuilder ignores sparkml Params unset parameters") {
val testParamsContainer = new TestParams()

val psb = new ParamsStringBuilder(testParamsContainer, prefix = "--", delimiter = " ")
.appendParamValueIfNotThere("string_value", "string_value", testParamsContainer.testString)
.appendParamValueIfNotThere("int_value", "int_value", testParamsContainer.testInt)
.appendParamValueIfNotThere("float_value", "float_value", testParamsContainer.testFloat)
.appendParamValueIfNotThere("double_value", "double_value", testParamsContainer.testDouble)
.appendParamValueIfNotThere("string_array_value", "string_array_value", testParamsContainer.testStringArray)
.appendParamValueIfNotThere("int_array_value", "int_array_value", testParamsContainer.testIntArray)
.appendParamValueIfNotThere("double_array_value", "double_array_value", testParamsContainer.testDoubleArray)

val expectedStr: String = "" // Since no parameters were set, they should not have been appended to string

assert(psb.result == expectedStr)
}

test("Verify ParamsStringBuilder throws when sparkml gets invalid Param") {
val testParamsContainer = new TestParams()

val unknownParam = new Param[String]("some parent", "unknown_param", "doc_string")
assertThrows[IllegalArgumentException] {
val psb = new ParamsStringBuilder(testParamsContainer, prefix = "--", delimiter = " ")
.appendParamValueIfNotThere("string_value", "string_value", unknownParam)
}

assertThrows[IllegalArgumentException] {
val psb = new ParamsStringBuilder(prefix = "--", delimiter = " ")
.appendParamValueIfNotThere("string_value", "string_value", testParamsContainer.testString)
}
}

test("Verify ParamsStringBuilder handles sparkml Params custom overrrides") {
val testParamsContainer = new TestParams()

// Populate container with parameters that will be overridden
testParamsContainer.setString("some_string")
testParamsContainer.setInt(3)
testParamsContainer.setFloat(3.0f)
testParamsContainer.setDouble(3.0)
testParamsContainer.setStringArray(Array("one", "two", "three"))
testParamsContainer.setIntArray(Array(1,2,3))
testParamsContainer.setDoubleArray(Array(1.0, 2.0, 3.0))

val psb = new ParamsStringBuilder(testParamsContainer, prefix = "--", delimiter = " ")
.append("--string_value some_override")
.append("--int_value 99")
.append("--float_value 99.0")
.append("--double_value 99.0")
.append("--string_array_value three,four,five")
.append("--int_array_value 97,98,99")
.append("--double_array_value 97.0,98.0,99.0")
.appendParamValueIfNotThere("string_value", "string_value", testParamsContainer.testString)
.appendParamValueIfNotThere("int_value", "int_value", testParamsContainer.testInt)
.appendParamValueIfNotThere("float_value", "float_value", testParamsContainer.testFloat)
.appendParamValueIfNotThere("double_value", "double_value", testParamsContainer.testDouble)
.appendParamValueIfNotThere("int_array_value", "int_array_value", testParamsContainer.testIntArray)
.appendParamValueIfNotThere("double_array_value", "double_array_value", testParamsContainer.testDoubleArray)
.appendParamValueIfNotThere("string_array_value", "string_array_value", testParamsContainer.testStringArray)

// We expect the set Params to be ignored since an override exists, so we should only see the original overrides
val expectedStr: String = "--string_value some_override" +
" --int_value 99" +
" --float_value 99.0" +
" --double_value 99.0" +
" --string_array_value three,four,five" +
" --int_array_value 97,98,99" +
" --double_array_value 97.0,98.0,99.0"

assert(psb.result == expectedStr)
}

test("Verify ParamsStringBuilder handles ParamGroup") {
val testParamGroup = new TestParamGroup(1, None)

val psb = new ParamsStringBuilder(prefix = "", delimiter = "=")
.appendParamGroup(testParamGroup)

// We expect the optional Int to be ignored since it wasn't set
val expectedStr: String = "some_int=1"

assert(psb.result == expectedStr)
}
}

/** Test class for testing sparkml Param handling
* */
class TestParams extends Params {
override def copy(extra: ParamMap): Params = this // not needed
override val uid: String = "some id"

val testString = new Param[String](this, "testString", "Test String param")
def setString(value: String): this.type = set(testString, value)

val testInt = new IntParam(this, "testInt", "Test Int param")
def setInt(value: Int): this.type = set(testInt, value)

val testFloat = new FloatParam(this, "testFloat", "Test Float param")
def setFloat(value: Float): this.type = set(testFloat, value)

val testDouble = new DoubleParam(this, "testDouble", "Test Double param")
def setDouble(value: Double): this.type = set(testDouble, value)

val testStringArray = new StringArrayParam(this, "testStringArray", "Test StringArray param")
def setStringArray(value: Array[String]): this.type = set(testStringArray, value)

val testIntArray = new IntArrayParam(this, "testIntArray", "Test IntArray param")
def setIntArray(value: Array[Int]): this.type = set(testIntArray, value)

val testDoubleArray = new DoubleArrayParam(this, "testDoubleArray", "Test DoubleArray param")
def setDoubleArray(value: Array[Double]): this.type = set(testDoubleArray, value)
}

/** Test class for testing ParamGroup handling
* */
case class TestParamGroup (someInt: Int,
someOptionInt: Option[Int]) extends ParamGroup {
def appendParams(sb: ParamsStringBuilder): ParamsStringBuilder = {
sb.appendParamValueIfNotThere("some_int", Option(someInt))
.appendParamValueIfNotThere("some_option_int", someOptionInt)
}
}
Loading

0 comments on commit 339e516

Please sign in to comment.