diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/BaseDiffInDiffEstimator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/BaseDiffInDiffEstimator.scala index fcb10389e5..8f87f05cda 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/BaseDiffInDiffEstimator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/BaseDiffInDiffEstimator.scala @@ -5,11 +5,12 @@ package com.microsoft.azure.synapse.ml.causal import com.microsoft.azure.synapse.ml.causal.linalg.DVector import com.microsoft.azure.synapse.ml.codegen.Wrappable +import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} import com.microsoft.azure.synapse.ml.param.DataFrameParam import org.apache.spark.SparkException import org.apache.spark.ml.feature.VectorAssembler -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.util.Identifiable import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Estimator, Model} @@ -43,7 +44,7 @@ abstract class BaseDiffInDiffEstimator(override val uid: String) override def copy(extra: ParamMap): Estimator[DiffInDiffModel] = defaultCopy(extra) - private[causal] val interactionCol = "interaction" + private[causal] val findInteractionCol = DatasetExtensions.findUnusedColumnName("interaction") _ private[causal] def fitLinearModel(df: DataFrame, featureCols: Array[String], @@ -57,8 +58,10 @@ abstract class BaseDiffInDiffEstimator(override val uid: String) .map(new LinearRegression().setWeightCol) .getOrElse(new LinearRegression()) + val featuresCol = DatasetExtensions.findUnusedColumnName("features", df) + regression - .setFeaturesCol("features") + .setFeaturesCol(featuresCol) .setLabelCol(getOutcomeCol) .setFitIntercept(fitIntercept) .setLoss("squaredError") @@ -162,3 +165,8 @@ class DiffInDiffModel(override val uid: String) } object DiffInDiffModel extends ComplexParamsReadable[DiffInDiffModel] + +trait DiffInDiffEstimatorParams extends Params + with HasTreatmentCol + with HasOutcomeCol + with HasPostTreatmentCol diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DiffInDiffEstimator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DiffInDiffEstimator.scala index b1ea6652ab..6fe92fb5fd 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DiffInDiffEstimator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DiffInDiffEstimator.scala @@ -4,6 +4,7 @@ package com.microsoft.azure.synapse.ml.causal import com.microsoft.azure.synapse.ml.codegen.Wrappable +import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} import org.apache.spark.ml.util.Identifiable import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable} @@ -22,6 +23,7 @@ class DiffInDiffEstimator(override val uid: String) def this() = this(Identifiable.randomUID("DiffInDiffEstimator")) override def fit(dataset: Dataset[_]): DiffInDiffModel = logFit({ + val interactionCol = findInteractionCol(dataset.columns.toSet) val postTreatment = col(getPostTreatmentCol) val treatment = col(getTreatmentCol) val outcome = col(getOutcomeCol) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DiffInDiffEstimatorParams.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DiffInDiffEstimatorParams.scala deleted file mode 100644 index a26f6128c1..0000000000 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DiffInDiffEstimatorParams.scala +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See LICENSE in project root for information. - -package com.microsoft.azure.synapse.ml.causal - -import org.apache.spark.ml.param.Params - -trait DiffInDiffEstimatorParams extends Params - with HasTreatmentCol - with HasOutcomeCol - with HasPostTreatmentCol diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticControlEstimator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticControlEstimator.scala index c28b463588..befd2b4b76 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticControlEstimator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticControlEstimator.scala @@ -57,6 +57,8 @@ class SyntheticControlEstimator(override val uid: String) val indexedDf = df.join(unitIdx, df(getUnitCol) === unitIdx(getUnitCol), "left_outer") + val interactionCol = findInteractionCol(indexedDf.columns.toSet) + val didData = indexedDf.select( col(getTimeCol), col(UnitIdxCol), diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticDiffInDiffEstimator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticDiffInDiffEstimator.scala index d099f31877..5f81e0e552 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticDiffInDiffEstimator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticDiffInDiffEstimator.scala @@ -70,6 +70,8 @@ class SyntheticDiffInDiffEstimator(override val uid: String) val indexedDf = df.join(timeIdx, df(getTimeCol) === timeIdx(getTimeCol), "left_outer") .join(unitIdx, df(getUnitCol) === unitIdx(getUnitCol), "left_outer") + val interactionCol = findInteractionCol(indexedDf.columns.toSet) + val didData = indexedDf.select( col(UnitIdxCol), col(TimeIdxCol), diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticEstimatorParams.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticEstimatorParams.scala index 13a16df35c..68a1bbf316 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticEstimatorParams.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticEstimatorParams.scala @@ -3,11 +3,9 @@ package com.microsoft.azure.synapse.ml.causal -import org.apache.spark.ml.param.{IntParam, LongParam, Param, ParamValidators, Params} +import org.apache.spark.ml.param.{DoubleParam, IntParam, LongParam, Param, ParamValidators, Params} import org.apache.spark.ml.param.shared.{HasMaxIter, HasStepSize, HasTol} -import scala.util.Random - trait SyntheticEstimatorParams extends Params with HasUnitCol with HasTimeCol @@ -48,6 +46,15 @@ trait SyntheticEstimatorParams extends Params /** @group expertGetParam */ def setLocalSolverThreshold(value: Long): this.type = set(localSolverThreshold, value) + final val epsilon = new DoubleParam(this, "epsilon", + "This value is added to the weights when we fit the final linear model for " + + "SyntheticControlEstimator and SyntheticDiffInDiffEstimator in order to avoid " + + "zero weights.", ParamValidators.gt(0d)) + + def getEpsilon: Double = $(epsilon) + + def setEpsilon(value: Double): this.type = set(epsilon, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) def setStepSize(value: Double): this.type = set(stepSize, value) @@ -59,6 +66,7 @@ trait SyntheticEstimatorParams extends Params tol -> 1E-3, maxIter -> 100, handleMissingOutcome -> "zero", - localSolverThreshold -> 1000 * 1000 + localSolverThreshold -> 1000 * 1000, + epsilon -> 1E-10 ) }