feat: Measure transformers for Data Balance Analysis (#1218)
* Initial implementation of AssociationGapTransformer

* Add normalized PMI

* Add other association metrics

* Add names of measures

* Change counts to probabilities in transformer, Add dp unit test

* Add other association measure unit tests, Fix sPmi and krc bugs

* Unit tests cleanup

* Refactor code

* [AggregateMeasureTransformer] Initial implementation

* [AggregateMeasureTransformer] Checkpoint

* Changes by jasowang

* [AggregateMeasureTransformer] Move agg calcs to AggregateMetrics

* [AssociationGapTransformer] Quick fix

* [AggregateMeasureTransformer] Add unit tests

* Add DistributionMeasures, update DataImbalance tests

* Changes by jasowang

* Update all DataImbalance transformers based on jasowang's comments

* Remove use of null

* Make fixes to build Python package

* Move common code to DataImbalanceParams and DataImbalanceTestBase, address jasowang's PR comments, pSensitive -> pFeature

* [DistributionMeasures] Introduce uniform distribution and treat reference values as columns

* AssociationGaps -> ParityMeasures

* Add TransformerFuzzing to all unit test suites, refactor chi-squared p-value, address memoryz's PR comments, merge from master (new namespace)

* DataImbalance -> DataBalance (to go with official name of Data Balance Analysis)

* Add documentation

* {Refactor unit tests, introduce imbalance namespace} based on PR comments

* Change measures output from MapType to nested struct

* Update scaladocs based on MapType -> nested struct change

* [ParityMeasures] Fix bug for calculating p(PositiveAndSensitive)

* [ParityMeasures] Update unit tests

* Cosmetic changes

* Add the experimental tag

* [DataBalanceAnalysis] Naming changes

* [DataBalanceAnalysis] Use === in unit tests

Co-authored-by: Kashyap Patel <[email protected]>
Co-authored-by: Jason Wang <[email protected]>
3 people authored Oct 22, 2021
1 parent 73c6a65 commit c5e1742
Showing 10 changed files with 1,185 additions and 1 deletion.
AggregateBalanceMeasure.scala
@@ -0,0 +1,170 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.exploratory

import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.logging.BasicLogging
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/** This transformer computes a set of aggregated balance measures that represent how balanced
  * the given dataframe is along the given sensitive features.
  *
  * The output is a dataframe that contains one column:
  *   - A struct containing the measure names and their values; higher values indicate greater inequality.
  *
  * The following measures are computed:
  *   - Atkinson Index - https://en.wikipedia.org/wiki/Atkinson_index
  *   - Theil Index (L and T) - https://en.wikipedia.org/wiki/Theil_index
  *
  * The output dataframe contains one row.
  *
  * @param uid The unique ID.
  */
@org.apache.spark.annotation.Experimental
class AggregateBalanceMeasure(override val uid: String)
  extends Transformer
    with DataBalanceParams
    with ComplexParamsWritable
    with Wrappable
    with BasicLogging {

  logClass()

  def this() = this(Identifiable.randomUID("AggregateBalanceMeasure"))

  val epsilon = new DoubleParam(
    this,
    "epsilon",
    "Epsilon value for the Atkinson Index, equal to 1 - alpha (where alpha is the index's inequality aversion parameter)."
  )

  def getEpsilon: Double = $(epsilon)

  def setEpsilon(value: Double): this.type = set(epsilon, value)

  val errorTolerance = new DoubleParam(
    this,
    "errorTolerance",
    "Error tolerance value for Atkinson Index."
  )

  def getErrorTolerance: Double = $(errorTolerance)

  def setErrorTolerance(value: Double): this.type = set(errorTolerance, value)

  setDefault(
    outputCol -> "AggregateBalanceMeasure",
    epsilon -> 1d,
    errorTolerance -> 1e-12
  )

  override def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({
      validateSchema(dataset.schema)

      val df = dataset.cache
      val numRows = df.count.toDouble

      val featureCountCol = DatasetExtensions.findUnusedColumnName("featureCount", df.schema)
      val rowCountCol = DatasetExtensions.findUnusedColumnName("rowCount", df.schema)
      val featureProbCol = DatasetExtensions.findUnusedColumnName("featureProb", df.schema)

      // Count each combination of sensitive feature values, then convert counts to probabilities
      val featureStats = df
        .groupBy(getSensitiveCols map col: _*)
        .agg(count("*").cast(DoubleType).alias(featureCountCol))
        .withColumn(rowCountCol, lit(numRows))
        .withColumn(featureProbCol, col(featureCountCol) / col(rowCountCol)) // P(sensitive)

      //noinspection ScalaStyle
      if (getVerbose)
        featureStats.cache.show(numRows = 20, truncate = false)

      df.unpersist
      calculateAggregateMeasures(featureStats, featureProbCol)
    })
  }

  private def calculateAggregateMeasures(featureStats: DataFrame, featureProbCol: String): DataFrame = {
    val Row(numFeatures: Double, meanFeatures: Double) =
      featureStats.agg(count("*").cast(DoubleType), mean(featureProbCol).cast(DoubleType)).head

    val metricsCols = AggregateMetrics(
      featureProbCol, numFeatures, meanFeatures, getEpsilon, getErrorTolerance).toColumnMap.values.toSeq
    val aggDf = featureStats.agg(metricsCols.head, metricsCols.tail: _*)

    if (getVerbose)
      aggDf.cache.show(truncate = false)

    val measureTuples = AggregateMetrics.METRICS.map(col)
    aggDf.withColumn(getOutputCol, struct(measureTuples: _*)).select(getOutputCol)
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = {
    validateSchema(schema)

    StructType(
      StructField(getOutputCol,
        StructType(AggregateMetrics.METRICS.map(StructField(_, DoubleType, nullable = true))), nullable = false) ::
        Nil
    )
  }
}

object AggregateBalanceMeasure extends ComplexParamsReadable[AggregateBalanceMeasure]

//noinspection SpellCheckingInspection
private[exploratory] object AggregateMetrics {
  val ATKINSONINDEX = "atkinson_index"
  val THEILLINDEX = "theil_l_index"
  val THEILTINDEX = "theil_t_index"

  val METRICS = Seq(ATKINSONINDEX, THEILLINDEX, THEILTINDEX)
}

//noinspection SpellCheckingInspection
private[exploratory] case class AggregateMetrics(featureProbCol: String,
                                                 numFeatures: Double,
                                                 meanFeatures: Double,
                                                 epsilon: Double,
                                                 errorTolerance: Double) {

  import AggregateMetrics._

  // Each feature probability normalized by the mean probability: p_i / mu
  private val normFeatureProbCol = col(featureProbCol) / meanFeatures

  def toColumnMap: Map[String, Column] = Map(
    ATKINSONINDEX -> atkinsonIndex.alias(ATKINSONINDEX),
    THEILLINDEX -> theilLIndex.alias(THEILLINDEX),
    THEILTINDEX -> theilTIndex.alias(THEILTINDEX)
  )

  // Atkinson Index with inequality aversion alpha = 1 - epsilon:
  //   A = 1 - ((1/N) * sum((p_i / mu)^alpha))^(1/alpha)  for alpha != 0
  //   A = 1 - (prod(p_i / mu))^(1/N)                     for alpha == 0 (i.e. epsilon == 1)
  def atkinsonIndex: Column = {
    val alpha = 1d - epsilon
    val productExpression = exp(sum(log(normFeatureProbCol)))
    val powerMeanExpression = sum(pow(normFeatureProbCol, alpha)) / numFeatures
    when(
      abs(lit(alpha)) < errorTolerance,
      lit(1d) - pow(productExpression, 1d / numFeatures)
    ).otherwise(
      lit(1d) - pow(powerMeanExpression, 1d / alpha)
    )
  }

  // Theil L Index (mean log deviation): L = (1/N) * sum(ln(mu / p_i))
  def theilLIndex: Column = {
    val negativeSumLog = sum(log(normFeatureProbCol) * -1d)
    negativeSumLog / numFeatures
  }

  // Theil T Index: T = (1/N) * sum((p_i / mu) * ln(p_i / mu))
  def theilTIndex: Column = {
    val sumLog = sum(normFeatureProbCol * log(normFeatureProbCol))
    sumLog / numFeatures
  }
}
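
For orientation, here is a minimal usage sketch of AggregateBalanceMeasure (not part of this commit; the DataFrame df and the sensitive columns "Gender" and "Ethnicity" are hypothetical, and the settings simply echo the defaults above):

import com.microsoft.azure.synapse.ml.exploratory.AggregateBalanceMeasure

// df is a hypothetical DataFrame with string-valued "Gender" and "Ethnicity" columns.
val aggregateMeasures = new AggregateBalanceMeasure()
  .setSensitiveCols(Array("Gender", "Ethnicity"))
  .setVerbose(true) // also print the intermediate per-feature counts and probabilities

// Produces a single row holding a struct of
// {atkinson_index, theil_l_index, theil_t_index}, all DoubleType.
aggregateMeasures.transform(df).show(truncate = false)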
DataBalanceParams.scala
@@ -0,0 +1,46 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.exploratory

import org.apache.spark.ml.param.shared.HasOutputCol
import org.apache.spark.ml.param.{BooleanParam, Params, StringArrayParam}
import org.apache.spark.sql.types._

trait DataBalanceParams extends Params with HasOutputCol {
  val sensitiveCols = new StringArrayParam(
    this,
    "sensitiveCols",
    "Sensitive columns to use."
  )

  def getSensitiveCols: Array[String] = $(sensitiveCols)

  def setSensitiveCols(values: Array[String]): this.type = set(sensitiveCols, values)

  val verbose = new BooleanParam(
    this,
    "verbose",
    "Whether to show intermediate measures and calculations, such as Positive Rate."
  )

  def getVerbose: Boolean = $(verbose)

  def setVerbose(value: Boolean): this.type = set(verbose, value)

  def setOutputCol(value: String): this.type = set(outputCol, value)

  setDefault(
    verbose -> false
  )

  def validateSchema(schema: StructType): Unit = {
    getSensitiveCols.foreach {
      c =>
        schema(c).dataType match {
          case ByteType | ShortType | IntegerType | LongType | StringType =>
          case _ => throw new Exception(s"The sensitive column named $c does not contain integral or string values.")
        }
    }
  }
}
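
A short sketch of how validateSchema gates the sensitive columns, via the transformSchema override above (again not part of this commit; the schemas and column names are hypothetical):

import org.apache.spark.sql.types._
import com.microsoft.azure.synapse.ml.exploratory.AggregateBalanceMeasure

val okSchema = StructType(Seq(
  StructField("Gender", StringType),
  StructField("Age", IntegerType)))
val badSchema = StructType(Seq(StructField("Score", DoubleType)))

val measure = new AggregateBalanceMeasure().setSensitiveCols(Array("Gender", "Age"))
measure.transformSchema(okSchema) // passes: string and integral columns are accepted

val badMeasure = new AggregateBalanceMeasure().setSensitiveCols(Array("Score"))
// badMeasure.transformSchema(badSchema) would throw:
// "The sensitive column named Score does not contain integral or string values."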