Skip to content

Commit

Permalink
feat: [DistributionBalanceMeasure] Add implementation + unit tests fo…
Browse files Browse the repository at this point in the history
…r custom reference distribution (#1885)

* [DistributionBalanceMeasure] Add implementation + unit tests for custom reference distribution

* Handle edge cases for Entropy + Chi^2 calculations (addresses TODO), use isDefined instead of isEmpty

* Use empty map instead of null to specify uniform distribution (due to pyspark error with ArrayMapParam)

* Add setter that accepts Array[Map[String, Double]]; update unit tests to use it (fixes testGettersAndSetters failure)

---------

Co-authored-by: Patel, Kashyap M <[email protected]>
  • Loading branch information
ms-kashyap and kashmoneygt authored Mar 24, 2023
1 parent 412620a commit 979c629
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@ import breeze.stats.distributions.ChiSquared
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.param.ArrayMapParam
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

import java.util
import scala.collection.JavaConverters._
import scala.language.postfixOps

/** This transformer computes data balance measures based on a reference distribution.
Expand Down Expand Up @@ -56,6 +59,27 @@ class DistributionBalanceMeasure(override val uid: String)

def setFeatureNameCol(value: String): this.type = set(featureNameCol, value)

val referenceDistribution = new ArrayMapParam(
this,
"referenceDistribution",
"An ordered list of reference distributions that correspond to each of the sensitive columns."
)

val emptyReferenceDistribution: Array[Map[String, Double]] = Array.empty

def getReferenceDistribution: Array[Map[String, Double]] =
if (isDefined(referenceDistribution))
$(referenceDistribution).map(_.mapValues(_.asInstanceOf[Double]).map(identity))
else emptyReferenceDistribution

def setReferenceDistribution(value: Array[Map[String, Double]]): this.type =
set(referenceDistribution, value.map(_.mapValues(_.asInstanceOf[Any])))

def setReferenceDistribution(value: util.ArrayList[util.HashMap[String, Double]]): this.type = {
val arrayMap = value.asScala.toArray.map(_.asScala.toMap.mapValues(_.asInstanceOf[Any]))
set(referenceDistribution, arrayMap)
}

setDefault(
featureNameCol -> "FeatureName",
outputCol -> "DistributionBalanceMeasure"
Expand All @@ -68,6 +92,15 @@ class DistributionBalanceMeasure(override val uid: String)
}
}

private val customDistribution: Map[String, Double] => String => Double = {
dist: Map[String, Double] => {
// NOTE: If the custom distribution doesn't have the col value, return a default probability of 0
// This assumes that the reference distribution does not contain the col value at all
s: String =>
dist.getOrElse(s, 0d)
}
}

override def transform(dataset: Dataset[_]): DataFrame = {
logTransform[DataFrame]({
validateSchema(dataset.schema)
Expand All @@ -89,30 +122,30 @@ class DistributionBalanceMeasure(override val uid: String)
if (getVerbose)
featureStats.cache.show(numRows = 20, truncate = false) //scalastyle:ignore magic.number

// TODO (for v2): Introduce a referenceDistribution function param for user to override the uniform distribution
val referenceDistribution = uniformDistribution

df.unpersist
calculateDistributionMeasures(featureStats, featureProbCol, featureCountCol, numRows, referenceDistribution)
calculateDistributionMeasures(featureStats, featureProbCol, featureCountCol, numRows)
})
}

private def calculateDistributionMeasures(featureStats: DataFrame,
obsFeatureProbCol: String,
obsFeatureCountCol: String,
numRows: Double,
referenceDistribution: Int => String => Double): DataFrame = {
val distributionMeasures = getSensitiveCols.map {
sensitiveCol =>
numRows: Double): DataFrame = {
val distributionMeasures = getSensitiveCols.zipWithIndex.map {
case (sensitiveCol, i) =>
val observed = featureStats
.groupBy(sensitiveCol)
.agg(sum(obsFeatureProbCol).alias(obsFeatureProbCol), sum(obsFeatureCountCol).alias(obsFeatureCountCol))

val numFeatures = observed.count.toInt
val refDistFunc = udf(referenceDistribution(numFeatures))
val refFeatureProbCol = DatasetExtensions.findUnusedColumnName("refFeatureProb", featureStats.schema)
val refFeatureCountCol = DatasetExtensions.findUnusedColumnName("refFeatureCount", featureStats.schema)

val refDist: String => Double =
if (!isDefined(referenceDistribution) || getReferenceDistribution(i).isEmpty) uniformDistribution(numFeatures)
else customDistribution(getReferenceDistribution(i))
val refDistFunc = udf(refDist)

val observedWithRef = observed
.withColumn(refFeatureProbCol, refDistFunc(col(sensitiveCol)))
.withColumn(refFeatureCountCol, refDistFunc(col(sensitiveCol)) * lit(numRows))
Expand Down Expand Up @@ -146,6 +179,15 @@ class DistributionBalanceMeasure(override val uid: String)
Nil
)
}

override def validateSchema(schema: StructType): Unit = {
super.validateSchema(schema)

if (isDefined(referenceDistribution) && getReferenceDistribution.length != getSensitiveCols.length) {
throw new Exception("The reference distribution must have the same length and order as the sensitive columns: "
+ getSensitiveCols.mkString(", "))
}
}
}

object DistributionBalanceMeasure extends ComplexParamsReadable[DistributionBalanceMeasure]
Expand Down Expand Up @@ -212,23 +254,32 @@ private[exploratory] case class DistributionMetrics(numFeatures: Int,
}

// Calculates Pearson's chi-squared statistic
def chiSquaredTestStatistic: Column =
sum(pow(col(obsFeatureCountCol) - col(refFeatureCountCol), 2) / col(refFeatureCountCol))
def chiSquaredTestStatistic: Column = sum(
// If expected is zero and observed is not zero, the test assumes observed is impossible so Chi^2 value becomes +inf
when(col(refFeatureCountCol) === 0 && col(obsFeatureCountCol) =!= 0, lit(Double.PositiveInfinity))
.otherwise(pow(col(obsFeatureCountCol) - col(refFeatureCountCol), 2) / col(refFeatureCountCol)))

// Calculates left-tailed p-value from degrees of freedom and chi-squared test statistic
def chiSquaredPValue: Column = {
val degOfFreedom = numFeatures - 1
val scoreCol = chiSquaredTestStatistic
val chiSqPValueUdf = udf({
score: Double =>
1d - ChiSquared(degOfFreedom).cdf(score)
})
val chiSqPValueUdf = udf(
(score: Double) => score match {
// limit of CDF as x approaches +inf is 1 (https://en.wikipedia.org/wiki/Cumulative_distribution_function)
case Double.PositiveInfinity => 1d
case _ => 1 - ChiSquared(degOfFreedom).cdf(score)
}
)
chiSqPValueUdf(scoreCol)
}

private def entropy(distA: Column, distB: Option[Column] = None): Column = {
if (distB.isDefined) {
sum(distA * log(distA / distB.get))
// Using same cases as scipy (https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.rel_entr.html)
val entropies = when(distA === 0d && distB.get >= 0d, lit(0d))
.when(distA > 0d && distB.get > 0d, distA * log(distA / distB.get))
.otherwise(lit(Double.PositiveInfinity))
sum(entropies)
} else {
sum(distA * log(distA)) * -1d
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,11 @@ case class AggregateMetricsCalculator(featureProbabilities: Array[Double], epsil
}
}

case class DistributionMetricsCalculator(obsFeatureProbabilities: Array[Double],
case class DistributionMetricsCalculator(refFeatureProbabilities: Array[Double],
refFeatureCounts: Array[Double],
obsFeatureProbabilities: Array[Double],
obsFeatureCounts: Array[Double],
numRows: Double) {
val numFeatures: Double = obsFeatureProbabilities.length
val refFeatureProbabilities: Array[Double] = Array.fill(numFeatures.toInt)(1d / numFeatures)

numFeatures: Double) {
val absDiffObsRef: Array[Double] = (obsFeatureProbabilities, refFeatureProbabilities).zipped.map((a, b) => abs(a - b))

val klDivergence: Double = entropy(obsFeatureProbabilities, Some(refFeatureProbabilities))
Expand All @@ -126,16 +125,22 @@ case class DistributionMetricsCalculator(obsFeatureProbabilities: Array[Double],
val infNormDistance: Double = absDiffObsRef.max
val totalVariationDistance: Double = 0.5d * absDiffObsRef.sum
val wassersteinDistance: Double = absDiffObsRef.sum / absDiffObsRef.length
val chiSquaredTestStatistic: Double = {
val refFeatureCount = numRows / numFeatures
obsFeatureCounts.map(o => pow(o - refFeatureCount, 2) / refFeatureCount).sum
val chiSquaredTestStatistic: Double = (obsFeatureCounts, refFeatureCounts).zipped.map((a, b) => pow(a - b, 2) / b).sum
val chiSquaredPValue: Double = chiSquaredTestStatistic match {
// limit of CDF as x approaches +inf is 1 (https://en.wikipedia.org/wiki/Cumulative_distribution_function)
case Double.PositiveInfinity => 1
case _ => 1 - ChiSquared(numFeatures - 1).cdf(chiSquaredTestStatistic)
}
val chiSquaredPValue: Double = 1 - ChiSquared(numFeatures - 1).cdf(chiSquaredTestStatistic)

def entropy(distA: Array[Double], distB: Option[Array[Double]] = None): Double = {
if (distB.isDefined) {
val logQuotient = (distA, distB.get).zipped.map((a, b) => log(a / b))
(distA, logQuotient).zipped.map(_ * _).sum
(distA, distB.get).zipped.map((a, b) =>
// Using cases from scipy.special.rel_entr, which scipy.stats.entropy directly calls
// https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.rel_entr.html
if (a == 0.0 && b >= 0.0) 0.0
else if (a > 0.0 && b > 0) a * log(a / b)
else Double.PositiveInfinity
).sum
} else {
-1d * distA.map(x => x * log(x)).sum
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ class DistributionBalanceMeasureSuite extends DataBalanceTestBase with Transform

private object ExpectedFeature1 {
// Values were computed using:
// val CALC =
// DistributionMetricsCalculator(expectedFeature1.map(_._1), expectedFeature1.map(_._2), sensitiveFeaturesDf.count)
// val (numRows, numFeatures) = (sensitiveFeaturesDf.count.toDouble, expectedFeature1.length)
// val (obsProbs, obsCounts) = expectedFeature1.unzip
// val (refProbs, refCounts) = Array.fill(numFeatures.toInt)(numFeatures).map(n => (1d / n, numRows / n)).unzip
// val CALC = DistributionMetricsCalculator(refProbs, refCounts, obsProbs, obsCounts, numFeatures)
val KLDIVERGENCE = 0.03775534151008829
val JSDISTANCE = 0.09785224086736323
val INFNORMDISTANCE = 0.1111111111111111
Expand Down Expand Up @@ -83,14 +85,16 @@ class DistributionBalanceMeasureSuite extends DataBalanceTestBase with Transform

private object ExpectedFeature2 {
// Values were computed using:
// val CALC =
// DistributionMetricsCalculator(expectedFeature2.map(_._1), expectedFeature2.map(_._2), sensitiveFeaturesDf.count)
// val (numRows, numFeatures) = (sensitiveFeaturesDf.count.toDouble, expectedFeature2.length)
// val (obsProbs, obsCounts) = expectedFeature2.unzip
// val (refProbs, refCounts) = Array.fill(numFeatures.toInt)(numFeatures).map(n => (1d / n, numRows / n)).unzip
// val CALC = DistributionMetricsCalculator(refProbs, refCounts, obsProbs, obsCounts, numFeatures)
val KLDIVERGENCE = 0.07551068302017659
val JSDISTANCE = 0.14172745151398888
val INFNORMDISTANCE = 0.1388888888888889
val TOTALVARIATIONDISTANCE = 0.16666666666666666
val WASSERSTEINDISTANCE = 0.08333333333333333
val CHISQUAREDTESTSTATISTIC = 1.222222222222222
val CHISQUAREDTESTSTATISTIC = 1.2222222222222223
val CHISQUAREDPVALUE = 0.7476795872877147
}

Expand All @@ -105,4 +109,139 @@ class DistributionBalanceMeasureSuite extends DataBalanceTestBase with Transform
assert(actual(CHISQUAREDTESTSTATISTIC) === expected.CHISQUAREDTESTSTATISTIC)
assert(actual(CHISQUAREDPVALUE) === expected.CHISQUAREDPVALUE)
}

// For each feature in sensitiveFeaturesDf (["Gender", "Ethnicity"]), need to specify its corresponding distribution
private def customDistribution: Array[Map[String, Double]] = Array(
// Index 0: Gender (all unique values included)
Map("Male" -> 0.25, "Female" -> 0.4, "Other" -> 0.35),
// Index 1: Ethnicity ('Other' value purposefully left out, which signals a probability of 0.0)
Map("Asian" -> 1/3d, "White" -> 1/3d, "Black" -> 1/3d)
)

test("DistributionBalanceMeasure can use a custom reference distribution for multiple cols") {
val df = distributionBalanceMeasure
.setReferenceDistribution(customDistribution)
.transform(sensitiveFeaturesDf)

df.show(truncate = false)
df.printSchema()
}

test("DistributionBalanceMeasure can use a custom distribution for one col and uniform for another") {
val customDist = customDistribution
// Keep custom distribution for Gender (index 0), and use uniform distribution for Ethnicity (index 1)
// Specifying empty map defaults to the uniform distribution
customDist.update(1, Map())

val df = distributionBalanceMeasure
.setReferenceDistribution(customDist)
.transform(sensitiveFeaturesDf)

df.show(truncate = false)
df.printSchema()
}

test("DistributionBalanceMeasure expects the custom distribution to be the same length as sensitive columns") {
val emptyDist: Array[Map[String, Double]] = Array.empty
assertThrows[Exception] {
distributionBalanceMeasure
.setReferenceDistribution(emptyDist)
.transform(sensitiveFeaturesDf)
}

val mismatchedLenDist = Array(Map("ColA" -> 0.25))
assertThrows[Exception] {
distributionBalanceMeasure
.setReferenceDistribution(mismatchedLenDist)
.transform(sensitiveFeaturesDf)
}
}

private def actualCustomDist: DataFrame =
new DistributionBalanceMeasure()
.setSensitiveCols(features)
.setVerbose(true)
.setReferenceDistribution(customDistribution)
.transform(sensitiveFeaturesDf)

private def actualCustomDistFeature1: Map[String, Double] =
METRICS zip actualCustomDist.filter(col("FeatureName") === feature1)
.select(array(col("DistributionBalanceMeasure.*")))
.as[Array[Double]]
.head toMap

private def expectedCustomDistFeature1 = getFeatureStats(sensitiveFeaturesDf.groupBy(feature1))
.select(feature1, featureProbCol, featureCountCol)
.as[(String, Double, Double)].collect()

private object ExpectedCustomDistFeature1 {
// Values were computed using:
// val (numRows, numFeatures) = (sensitiveFeaturesDf.count.toDouble, expectedCustomDistFeature1.length)
// val (featureValues, obsProbs, obsCounts) = expectedCustomDistFeature1.unzip3
// val refProbs = featureValues.map(customDistribution.get(0).getOrDefault(_, 0.0)) // idx 0 = Gender
// val refCounts = refProbs.map(_ * numRows)
// val CALC = DistributionMetricsCalculator(refProbs, refCounts, obsProbs, obsCounts, numFeatures)
val KLDIVERGENCE = 0.09399792940857671
val JSDISTANCE = 0.15001917759832653
val INFNORMDISTANCE = 0.19444444444444442
val TOTALVARIATIONDISTANCE = 0.19444444444444445
val WASSERSTEINDISTANCE = 0.12962962962962962
val CHISQUAREDTESTSTATISTIC = 1.880952380952381
val CHISQUAREDPVALUE = 0.3904418663854293
}

test(s"DistributionBalanceMeasure can use a custom reference distribution with all values ($feature1)") {
// The custom reference distribution for Gender is Map("Male" -> 0.25, "Female" -> 0.4, "Other" -> 0.35)
// This includes all unique values of Gender in the dataframe being transformed
val actual = actualCustomDistFeature1
val expected = ExpectedCustomDistFeature1
assert(actual(KLDIVERGENCE) === expected.KLDIVERGENCE)
assert(actual(JSDISTANCE) === expected.JSDISTANCE)
assert(actual(INFNORMDISTANCE) === expected.INFNORMDISTANCE)
assert(actual(TOTALVARIATIONDISTANCE) === expected.TOTALVARIATIONDISTANCE)
assert(actual(WASSERSTEINDISTANCE) === expected.WASSERSTEINDISTANCE)
assert(actual(CHISQUAREDTESTSTATISTIC) === expected.CHISQUAREDTESTSTATISTIC)
assert(actual(CHISQUAREDPVALUE) === expected.CHISQUAREDPVALUE)
}

private def actualCustomDistFeature2: Map[String, Double] =
METRICS zip actualCustomDist.filter(col("FeatureName") === feature2)
.select(array(col("DistributionBalanceMeasure.*")))
.as[Array[Double]]
.head toMap

private def expectedCustomDistFeature2 = getFeatureStats(sensitiveFeaturesDf.groupBy(feature2))
.select(feature2, featureProbCol, featureCountCol)
.as[(String, Double, Double)].collect()

private object ExpectedCustomDistFeature2 {
// Values were computed using:
// val (numRows, numFeatures) = (sensitiveFeaturesDf.count.toDouble, expectedCustomDistFeature2.length)
// val (featureValues, obsProbs, obsCounts) = expectedCustomDistFeature2.unzip3
// val refProbs = featureValues.map(customDistribution.get(1).getOrDefault(_, 0.0)) // idx 1 = Ethnicity
// val refCounts = refProbs.map(_ * numRows)
// val CALC = DistributionMetricsCalculator(refProbs, refCounts, obsProbs, obsCounts, numFeatures)
val KLDIVERGENCE = Double.PositiveInfinity
val JSDISTANCE = 0.2100032735609124
val INFNORMDISTANCE = 0.1111111111111111
val TOTALVARIATIONDISTANCE = 0.1111111111111111
val WASSERSTEINDISTANCE = 0.05555555555555555
val CHISQUAREDTESTSTATISTIC = Double.PositiveInfinity
val CHISQUAREDPVALUE = 1d
}

test(s"DistributionBalanceMeasure can a custom reference distribution with missing values ($feature2)") {
// The custom reference distribution for Ethnicity is Map("Asian" -> 0.33, "White" -> 0.33, "Black" -> 0.33)
// This does NOT include all unique values in the dataframe being transformed; 'Other' is left out
// which means that it should default to a reference probability of 0.00
val actual = actualCustomDistFeature2
val expected = ExpectedCustomDistFeature2
assert(actual(KLDIVERGENCE) === expected.KLDIVERGENCE)
assert(actual(JSDISTANCE) === expected.JSDISTANCE)
assert(actual(INFNORMDISTANCE) === expected.INFNORMDISTANCE)
assert(actual(TOTALVARIATIONDISTANCE) === expected.TOTALVARIATIONDISTANCE)
assert(actual(WASSERSTEINDISTANCE) === expected.WASSERSTEINDISTANCE)
assert(actual(CHISQUAREDTESTSTATISTIC) === expected.CHISQUAREDTESTSTATISTIC)
assert(actual(CHISQUAREDPVALUE) === expected.CHISQUAREDPVALUE)
}
}

0 comments on commit 979c629

Please sign in to comment.