feat: Measure transformers for Data Balance Analysis (#1218)
* Initial implementation of AssociationGapTransformer
* Add normalized PMI
* Add other association metrics
* Add names of measures
* Change counts to probabilities in transformer, add dp unit test
* Add other association measure unit tests, fix sPmi and krc bugs
* Unit tests cleanup
* Refactor code
* [AggregateMeasureTransformer] Initial implementation
* [AggregateMeasureTransformer] Checkpoint
* Changes by jasowang
* [AggregateMeasureTransformer] Move agg calcs to AggregateMetrics
* [AssociationGapTransformer] Quick fix
* [AggregateMeasureTransformer] Add unit tests
* Add DistributionMeasures, update DataImbalance tests
* Changes by jasowang
* Update all DataImbalance transformers based on jasowang's comments
* Remove use of null
* Make fixes to build Python package
* Move common code to DataImbalanceParams and DataImbalanceTestBase, address jasowang's PR comments, pSensitive -> pFeature
* [DistributionMeasures] Introduce uniform distribution and treat reference values as columns
* AssociationGaps -> ParityMeasures
* Add TransformerFuzzing to all unit test suites, refactor chi-squared p-value, address memoryz's PR comments, merge from master (new namespace)
* DataImbalance -> DataBalance (to match the official name, Data Balance Analysis)
* Add documentation
* Refactor unit tests and introduce imbalance namespace based on PR comments
* Change measures output from MapType to nested struct
* Update scaladocs based on MapType -> nested struct change
* [ParityMeasures] Fix bug for calculating p(PositiveAndSensitive)
* [ParityMeasures] Update unit tests
* Cosmetic changes
* Add the experimental tag
* [DataBalanceAnalysis] Naming changes
* [DataBalanceAnalysis] Use === in unit tests

Co-authored-by: Kashyap Patel <[email protected]>
Co-authored-by: Jason Wang <[email protected]>
Commit c5e1742 (parent 73c6a65). Showing 10 changed files with 1,185 additions and 1 deletion.
core/src/main/scala/com/microsoft/azure/synapse/ml/exploratory/AggregateBalanceMeasure.scala (170 additions, 0 deletions)
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.exploratory

import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.logging.BasicLogging
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/** This transformer computes a set of aggregated balance measures that represents how balanced
  * the given dataframe is along the given sensitive features.
  *
  * The output is a dataframe that contains one column:
  *   - A struct containing measure names and their values, where higher values indicate higher inequality.
  *     The following measures are computed:
  *       - Atkinson Index - https://en.wikipedia.org/wiki/Atkinson_index
  *       - Theil Index (L and T) - https://en.wikipedia.org/wiki/Theil_index
  *
  * The output dataframe contains one row.
  *
  * @param uid The unique ID.
  */
@org.apache.spark.annotation.Experimental
class AggregateBalanceMeasure(override val uid: String)
  extends Transformer
    with DataBalanceParams
    with ComplexParamsWritable
    with Wrappable
    with BasicLogging {

  logClass()

  def this() = this(Identifiable.randomUID("AggregateBalanceMeasure"))

  val epsilon = new DoubleParam(
    this,
    "epsilon",
    "Epsilon value for Atkinson Index. Complement of alpha (1 - alpha)."
  )

  def getEpsilon: Double = $(epsilon)

  def setEpsilon(value: Double): this.type = set(epsilon, value)

  val errorTolerance = new DoubleParam(
    this,
    "errorTolerance",
    "Error tolerance value for Atkinson Index."
  )

  def getErrorTolerance: Double = $(errorTolerance)

  def setErrorTolerance(value: Double): this.type = set(errorTolerance, value)

  setDefault(
    outputCol -> "AggregateBalanceMeasure",
    epsilon -> 1d,
    errorTolerance -> 1e-12
  )

  override def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({
      validateSchema(dataset.schema)

      val df = dataset.cache
      val numRows = df.count.toDouble

      val featureCountCol = DatasetExtensions.findUnusedColumnName("featureCount", df.schema)
      val rowCountCol = DatasetExtensions.findUnusedColumnName("rowCount", df.schema)
      val featureProbCol = DatasetExtensions.findUnusedColumnName("featureProb", df.schema)

      val featureStats = df
        .groupBy(getSensitiveCols map col: _*)
        .agg(count("*").cast(DoubleType).alias(featureCountCol))
        .withColumn(rowCountCol, lit(numRows))
        .withColumn(featureProbCol, col(featureCountCol) / col(rowCountCol)) // P(sensitive)

      //noinspection ScalaStyle
      if (getVerbose)
        featureStats.cache.show(numRows = 20, truncate = false)

      df.unpersist
      calculateAggregateMeasures(featureStats, featureProbCol)
    })
  }

  private def calculateAggregateMeasures(featureStats: DataFrame, featureProbCol: String): DataFrame = {
    val Row(numFeatures: Double, meanFeatures: Double) =
      featureStats.agg(count("*").cast(DoubleType), mean(featureProbCol).cast(DoubleType)).head

    val metricsCols = AggregateMetrics(
      featureProbCol, numFeatures, meanFeatures, getEpsilon, getErrorTolerance).toColumnMap.values.toSeq
    val aggDf = featureStats.agg(metricsCols.head, metricsCols.tail: _*)

    if (getVerbose)
      aggDf.cache.show(truncate = false)

    val measureTuples = AggregateMetrics.METRICS.map(col)
    aggDf.withColumn(getOutputCol, struct(measureTuples: _*)).select(getOutputCol)
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = {
    validateSchema(schema)

    StructType(
      StructField(getOutputCol,
        StructType(AggregateMetrics.METRICS.map(StructField(_, DoubleType, nullable = true))), nullable = false) ::
        Nil
    )
  }
}

object AggregateBalanceMeasure extends ComplexParamsReadable[AggregateBalanceMeasure]

//noinspection SpellCheckingInspection
private[exploratory] object AggregateMetrics {
  val ATKINSONINDEX = "atkinson_index"
  val THEILLINDEX = "theil_l_index"
  val THEILTINDEX = "theil_t_index"

  val METRICS = Seq(ATKINSONINDEX, THEILLINDEX, THEILTINDEX)
}

//noinspection SpellCheckingInspection
private[exploratory] case class AggregateMetrics(featureProbCol: String,
                                                 numFeatures: Double,
                                                 meanFeatures: Double,
                                                 epsilon: Double,
                                                 errorTolerance: Double) {

  import AggregateMetrics._

  // Feature probabilities normalized by their mean: x_i = p_i / mu
  private val normFeatureProbCol = col(featureProbCol) / meanFeatures

  def toColumnMap: Map[String, Column] = Map(
    ATKINSONINDEX -> atkinsonIndex.alias(ATKINSONINDEX),
    THEILLINDEX -> theilLIndex.alias(THEILLINDEX),
    THEILTINDEX -> theilTIndex.alias(THEILTINDEX)
  )

  def atkinsonIndex: Column = {
    val alpha = 1d - epsilon
    val productExpression = exp(sum(log(normFeatureProbCol)))
    val powerMeanExpression = sum(pow(normFeatureProbCol, alpha)) / numFeatures
    // When alpha ~ 0 (i.e. epsilon ~ 1), the power mean degenerates to the geometric mean.
    when(
      abs(lit(alpha)) < errorTolerance,
      lit(1d) - pow(productExpression, 1d / numFeatures)
    ).otherwise(
      lit(1d) - pow(powerMeanExpression, 1d / alpha)
    )
  }

  def theilLIndex: Column = {
    val negativeSumLog = sum(log(normFeatureProbCol) * -1d)
    negativeSumLog / numFeatures
  }

  def theilTIndex: Column = {
    val sumLog = sum(normFeatureProbCol * log(normFeatureProbCol))
    sumLog / numFeatures
  }
}
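For reference, the measures computed by AggregateMetrics above reduce to the following closed forms, where p_i is the observed probability of the i-th sensitive feature value, mu is their mean, x_i = p_i / mu, and N is the number of distinct feature values (this notation is introduced here to summarize the code; it is not taken from the commit):

$$A_\varepsilon = 1 - \left(\frac{1}{N}\sum_{i=1}^{N} x_i^{\,1-\varepsilon}\right)^{\frac{1}{1-\varepsilon}} \quad (\varepsilon \neq 1), \qquad A_1 = 1 - \left(\prod_{i=1}^{N} x_i\right)^{1/N}$$

$$T_L = \frac{1}{N}\sum_{i=1}^{N} \ln\frac{\mu}{p_i}, \qquad T_T = \frac{1}{N}\sum_{i=1}^{N} \frac{p_i}{\mu}\,\ln\frac{p_i}{\mu}$$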
core/src/main/scala/com/microsoft/azure/synapse/ml/exploratory/DataBalanceParams.scala (46 additions, 0 deletions)
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.exploratory

import org.apache.spark.ml.param.shared.HasOutputCol
import org.apache.spark.ml.param.{BooleanParam, Params, StringArrayParam}
import org.apache.spark.sql.types._

trait DataBalanceParams extends Params with HasOutputCol {
  val sensitiveCols = new StringArrayParam(
    this,
    "sensitiveCols",
    "Sensitive columns to use."
  )

  def getSensitiveCols: Array[String] = $(sensitiveCols)

  def setSensitiveCols(values: Array[String]): this.type = set(sensitiveCols, values)

  val verbose = new BooleanParam(
    this,
    "verbose",
    "Whether to show intermediate measures and calculations, such as Positive Rate."
  )

  def getVerbose: Boolean = $(verbose)

  def setVerbose(value: Boolean): this.type = set(verbose, value)

  def setOutputCol(value: String): this.type = set(outputCol, value)

  setDefault(
    verbose -> false
  )

  def validateSchema(schema: StructType): Unit = {
    getSensitiveCols.foreach {
      c =>
        schema(c).dataType match {
          case ByteType | ShortType | IntegerType | LongType | StringType =>
          case _ => throw new Exception(s"The sensitive column named $c does not contain integral or string values.")
        }
    }
  }
}
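A minimal usage sketch of the new transformer, assuming a running SparkSession; the column names Gender and Ethnicity and the sample rows are illustrative, not part of this commit:

import org.apache.spark.sql.SparkSession
import com.microsoft.azure.synapse.ml.exploratory.AggregateBalanceMeasure

val spark = SparkSession.builder().appName("DataBalanceExample").getOrCreate()
import spark.implicits._

// Illustrative data; any integral or string columns pass validateSchema.
val df = Seq(
  ("Male", "Asian"),
  ("Male", "White"),
  ("Female", "White"),
  ("Female", "Black"),
  ("Female", "Asian")
).toDF("Gender", "Ethnicity")

val measures = new AggregateBalanceMeasure()
  .setSensitiveCols(Array("Gender", "Ethnicity"))
  .setVerbose(true) // show the intermediate feature-probability table
  .transform(df)

// One row, one struct column holding atkinson_index, theil_l_index, theil_t_index.
measures.select("AggregateBalanceMeasure.*").show(truncate = false)

With the defaults above, all three measures are 0 exactly when every combination of sensitive feature values occurs with equal probability, and grow as the distribution becomes more imbalanced.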