Skip to content

Commit

Permalink
addressing comments
Browse files Browse the repository at this point in the history
  • Loading branch information
memoryz committed Oct 30, 2023
1 parent 6553653 commit 97d55dd
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ package com.microsoft.azure.synapse.ml.causal

import com.microsoft.azure.synapse.ml.causal.linalg.DVector
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.DataFrameParam
import org.apache.spark.SparkException
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Estimator, Model}
Expand Down Expand Up @@ -43,7 +44,7 @@ abstract class BaseDiffInDiffEstimator(override val uid: String)

override def copy(extra: ParamMap): Estimator[DiffInDiffModel] = defaultCopy(extra)

private[causal] val interactionCol = "interaction"
private[causal] val findInteractionCol = DatasetExtensions.findUnusedColumnName("interaction") _

private[causal] def fitLinearModel(df: DataFrame,
featureCols: Array[String],
Expand All @@ -57,8 +58,10 @@ abstract class BaseDiffInDiffEstimator(override val uid: String)
.map(new LinearRegression().setWeightCol)
.getOrElse(new LinearRegression())

val featuresCol = DatasetExtensions.findUnusedColumnName("features", df)

regression
.setFeaturesCol("features")
.setFeaturesCol(featuresCol)
.setLabelCol(getOutcomeCol)
.setFitIntercept(fitIntercept)
.setLoss("squaredError")
Expand Down Expand Up @@ -162,3 +165,8 @@ class DiffInDiffModel(override val uid: String)
}

/** Companion object providing `ComplexParamsReadable` read/load support for [[DiffInDiffModel]]. */
object DiffInDiffModel extends ComplexParamsReadable[DiffInDiffModel]

/** Parameters shared by diff-in-diff estimators: the treatment indicator column,
  * the outcome column, and the post-treatment indicator column.
  */
trait DiffInDiffEstimatorParams extends Params
with HasTreatmentCol
with HasOutcomeCol
with HasPostTreatmentCol
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package com.microsoft.azure.synapse.ml.causal

import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable}
Expand All @@ -22,6 +23,7 @@ class DiffInDiffEstimator(override val uid: String)
def this() = this(Identifiable.randomUID("DiffInDiffEstimator"))

override def fit(dataset: Dataset[_]): DiffInDiffModel = logFit({
val interactionCol = findInteractionCol(dataset.columns.toSet)
val postTreatment = col(getPostTreatmentCol)
val treatment = col(getTreatmentCol)
val outcome = col(getOutcomeCol)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class SyntheticControlEstimator(override val uid: String)

val indexedDf = df.join(unitIdx, df(getUnitCol) === unitIdx(getUnitCol), "left_outer")

val interactionCol = findInteractionCol(indexedDf.columns.toSet)

val didData = indexedDf.select(
col(getTimeCol),
col(UnitIdxCol),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class SyntheticDiffInDiffEstimator(override val uid: String)
val indexedDf = df.join(timeIdx, df(getTimeCol) === timeIdx(getTimeCol), "left_outer")
.join(unitIdx, df(getUnitCol) === unitIdx(getUnitCol), "left_outer")

val interactionCol = findInteractionCol(indexedDf.columns.toSet)

val didData = indexedDf.select(
col(UnitIdxCol),
col(TimeIdxCol),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@

package com.microsoft.azure.synapse.ml.causal

import org.apache.spark.ml.param.{IntParam, LongParam, Param, ParamValidators, Params}
import org.apache.spark.ml.param.{DoubleParam, IntParam, LongParam, Param, ParamValidators, Params}
import org.apache.spark.ml.param.shared.{HasMaxIter, HasStepSize, HasTol}

import scala.util.Random

trait SyntheticEstimatorParams extends Params
with HasUnitCol
with HasTimeCol
Expand Down Expand Up @@ -48,6 +46,15 @@ trait SyntheticEstimatorParams extends Params
/** @group expertGetParam */
def setLocalSolverThreshold(value: Long): this.type = set(localSolverThreshold, value)

/** Small positive constant added to the fitted weights before the final linear
  * regression in SyntheticControlEstimator and SyntheticDiffInDiffEstimator, so
  * that no observation ends up with an exactly-zero weight. Must be strictly > 0.
  */
final val epsilon = new DoubleParam(this, "epsilon",
"This value is added to the weights when we fit the final linear model for " +
"SyntheticControlEstimator and SyntheticDiffInDiffEstimator in order to avoid " +
"zero weights.", ParamValidators.gt(0d))

/** @group getParam */
def getEpsilon: Double = $(epsilon)

/** @group setParam */
def setEpsilon(value: Double): this.type = set(epsilon, value)

/** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value)

/** @group setParam */
def setStepSize(value: Double): this.type = set(stepSize, value)
Expand All @@ -59,6 +66,7 @@ trait SyntheticEstimatorParams extends Params
tol -> 1E-3,
maxIter -> 100,
handleMissingOutcome -> "zero",
localSolverThreshold -> 1000 * 1000
localSolverThreshold -> 1000 * 1000,
epsilon -> 1E-10
)
}

0 comments on commit 97d55dd

Please sign in to comment.