Skip to content

Commit

Permalink
fix: Fix failing geospatial tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mhamilton723 committed Dec 10, 2021
1 parent f2bf567 commit de7da4f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ trait HasAddressInput extends HasServiceParams with HasSubscriptionKey with HasU
val post = new HttpPost(new URI(getUrl + queryParams))
post.setHeader("Content-Type", "application/json")
post.setHeader("User-Agent", s"synapseml/${BuildInfo.version}${HeaderValues.PlatformInfo}")
val addressesCol = getValueOpt(row, address)
val encodedAddresses = addressesCol.get.map(x => URLEncoder.encode(x, "UTF-8")).toList
val encodedAddresses = getValue(row, address).map(x => URLEncoder.encode(x, "UTF-8")).toList
val payloadItems = encodedAddresses.map(x => s"""{ "query": "?query=$x&limit=1" }""").mkString(",")
val payload = s"""{ "batchItems": [ $payloadItems ] }"""
post.setEntity(new StringEntity(payload))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import com.microsoft.azure.synapse.ml.stages.{FixedMiniBatchTransformer, Flatten
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.scalactic.Equality

trait AzureMapsKey {
lazy val azureMapsKey: String = sys.env.getOrElse("AZURE_MAPS_KEY", Secrets.AzureMapsKey)
Expand Down Expand Up @@ -41,24 +42,33 @@ class AzureMapSearchSuite extends TransformerFuzzing[AddressGeocoder] with Azure
.setAddressCol("address")
.setOutputCol("output")

/** Projects a geocoder output frame down to the address and the coordinates of its first result.
  *
  * @param df frame containing an `address` column and an `output.response.results` column
  * @return a frame with exactly three columns: `address`, `latitude` and `longitude`
  */
def extractFields(df: DataFrame): DataFrame = {
  // `position` of item 0 of the results; both coordinates are read from this one expression.
  val firstPosition = col("output.response.results").getItem(0).getField("position")
  val latitude = firstPosition.getField("lat").as("latitude")
  val longitude = firstPosition.getField("lon").as("longitude")
  df.select(col("address"), latitude, longitude)
}

// Mini-batches `df` in groups of 5, runs `batchGeocodeAddresses` over the batches, flattens the
// output, and checks the collected row count and the first row's latitude.
// NOTE(review): this span is rendered from a diff whose +/- markers were lost, so both sides of
// the change appear one after the other. Going by the commit's addition/deletion counts, the
// inline `.select(...)` definition and the `toSeq(0)` assertion look like the removed lines and
// the `extractFields(...)` / `toSeq.head` lines their replacements — confirm against the repo.
test("Basic Batch Geocode Usage") {
val batchedDF = batchGeocodeAddresses.transform(new FixedMiniBatchTransformer().setBatchSize(5).transform(df))
// Presumably the pre-change form: projects address / latitude / longitude inline.
val flattenedResults = new FlattenBatch().transform(batchedDF)
.select(
col("address"),
col("output.response.results").getItem(0).getField("position")
.getField("lat").as("latitude"),
col("output.response.results").getItem(0).getField("position")
.getField("lon").as("longitude"))
// Presumably the post-change form: the same projection routed through `extractFields`.
val flattenedResults = extractFields(new FlattenBatch().transform(batchedDF))
.collect()

assert(flattenedResults != null)
assert(flattenedResults.length == 15)
// Column index 1 is `latitude` in the projected frame (address, latitude, longitude).
assert(flattenedResults.toSeq(0).get(1) == 47.64016)
assert(flattenedResults.toSeq.head.get(1) == 47.64016)
}

/** Compares two frames after reducing each one with `extractFields`, then delegates to the
  * parent `assertDFEq` using the supplied `Equality`.
  *
  * @param df1 first frame to compare
  * @param df2 second frame to compare
  * @param eq  equality instance forwarded unchanged to the parent implementation
  */
override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
  val projected1 = extractFields(df1)
  val projected2 = extractFields(df2)
  super.assertDFEq(projected1, projected2)(eq)
}

// Fixtures handed to the fuzzing harness: `batchGeocodeAddresses` paired with an input frame.
// NOTE(review): rendered from a diff whose +/- markers were lost. The one-line `Seq(...)` over
// plain `df` appears to be the removed body and the three-line `Seq(...)` over `df` passed
// through a batch-size-5 `FixedMiniBatchTransformer` its replacement — confirm against the repo.
override def testObjects(): Seq[TestObject[AddressGeocoder]] =
Seq(new TestObject[AddressGeocoder](batchGeocodeAddresses, df))
Seq(new TestObject[AddressGeocoder](
batchGeocodeAddresses,
new FixedMiniBatchTransformer().setBatchSize(5).transform(df)))

// Supplies `AddressGeocoder` as the `MLReadable` for this suite.
override def reader: MLReadable[_] = AddressGeocoder
}

0 comments on commit de7da4f

Please sign in to comment.