From 6aecdf1c0c212950344f210f11aea2dfb8760009 Mon Sep 17 00:00:00 2001 From: jtinkus <35308202+jtinkus@users.noreply.github.com> Date: Tue, 1 Jun 2021 20:13:14 +0300 Subject: [PATCH] Add sparse vector support to KNN. (#1063) Co-authored-by: Jako Tinkus Co-authored-by: Mark Hamilton --- .../ml/spark/nn/ConditionalKNN.scala | 9 +++-- .../scala/com/microsoft/ml/spark/nn/KNN.scala | 10 +++--- .../injections/OptimizedCKNNFitting.scala | 11 +++---- .../microsoft/ml/spark/nn/BallTreeTest.scala | 15 +++++++++ .../com/microsoft/ml/spark/nn/KNNTest.scala | 33 ++++++++++++------- 5 files changed, 50 insertions(+), 28 deletions(-) diff --git a/src/main/scala/com/microsoft/ml/spark/nn/ConditionalKNN.scala b/src/main/scala/com/microsoft/ml/spark/nn/ConditionalKNN.scala index d79ff557a3..7b25be7bd6 100644 --- a/src/main/scala/com/microsoft/ml/spark/nn/ConditionalKNN.scala +++ b/src/main/scala/com/microsoft/ml/spark/nn/ConditionalKNN.scala @@ -8,11 +8,11 @@ import com.microsoft.ml.spark.core.contracts.HasLabelCol import com.microsoft.ml.spark.logging.BasicLogging import org.apache.spark.broadcast.Broadcast import org.apache.spark.injections.UDFUtils -import org.apache.spark.ml.linalg.DenseVector +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.{ConditionalBallTreeParam, Param, ParamMap} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Estimator, Model} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ import org.apache.spark.sql.types.injections.OptimizedCKNNFitting import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -63,11 +63,10 @@ class ConditionalKNN(override val uid: String) extends Estimator[ConditionalKNNM private[ml] object KNNFuncHolder { def queryFunc[L, V](bbt: Broadcast[ConditionalBallTree[L, V]], k: Int) - (dv: DenseVector, conditioner: Seq[L]): Seq[Row] = { - bbt.value.findMaximumInnerProducts(new BDV(dv.values), conditioner.toSet, k) + (v: Vector, conditioner: Seq[L]): Seq[Row] = { + bbt.value.findMaximumInnerProducts(new BDV(v.toDense.values), conditioner.toSet, k) .map(bm => Row(bbt.value.values(bm.index), bm.distance, bbt.value.labels(bm.index))) } - } class ConditionalKNNModel(val uid: String) extends Model[ConditionalKNNModel] diff --git a/src/main/scala/com/microsoft/ml/spark/nn/KNN.scala b/src/main/scala/com/microsoft/ml/spark/nn/KNN.scala index 9fa3b109f9..a4c3973a79 100644 --- a/src/main/scala/com/microsoft/ml/spark/nn/KNN.scala +++ b/src/main/scala/com/microsoft/ml/spark/nn/KNN.scala @@ -4,16 +4,16 @@ package com.microsoft.ml.spark.nn import breeze.linalg.{DenseVector => BDV} -import com.microsoft.ml.spark.core.contracts.{HasFeaturesCol, HasOutputCol} import com.microsoft.ml.spark.codegen.Wrappable +import com.microsoft.ml.spark.core.contracts.{HasFeaturesCol, HasOutputCol} import com.microsoft.ml.spark.logging.BasicLogging import org.apache.spark.broadcast.Broadcast import org.apache.spark.injections.UDFUtils import org.apache.spark.ml._ -import org.apache.spark.ml.linalg.DenseVector +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ import org.apache.spark.sql.types.injections.OptimizedKNNFitting import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -101,9 +101,9 @@ class KNNModel(val uid: String) extends Model[KNNModel] if (broadcastedModelOption.isEmpty) { broadcastedModelOption = Some(dataset.sparkSession.sparkContext.broadcast(getBallTree)) } - val getNeighborUDF = UDFUtils.oldUdf({ dv: DenseVector => + val getNeighborUDF = UDFUtils.oldUdf({ v: Vector => val bt = broadcastedModelOption.get.value - bt.findMaximumInnerProducts(new BDV(dv.values), getK) + bt.findMaximumInnerProducts(new BDV(v.toDense.values), getK) .map(bm => Row(bt.values(bm.index), bm.distance)) }, ArrayType(new StructType() .add("value", dataset.schema(getValuesCol).dataType) diff --git a/src/main/scala/org/apache/spark/sql/types/injections/OptimizedCKNNFitting.scala b/src/main/scala/org/apache/spark/sql/types/injections/OptimizedCKNNFitting.scala index fdbf4e1773..42d167750d 100644 --- a/src/main/scala/org/apache/spark/sql/types/injections/OptimizedCKNNFitting.scala +++ b/src/main/scala/org/apache/spark/sql/types/injections/OptimizedCKNNFitting.scala @@ -3,12 +3,11 @@ package org.apache.spark.sql.types.injections -import com.microsoft.ml.spark.nn._ -import org.apache.spark.ml.linalg.DenseVector -import org.apache.spark.sql.Dataset import breeze.linalg.{DenseVector => BDV} import com.microsoft.ml.spark.logging.BasicLogging -import org.apache.spark.internal.Logging +import com.microsoft.ml.spark.nn._ +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.sql.Dataset import org.apache.spark.sql.types._ trait OptimizedCKNNFitting extends ConditionalKNNParams with BasicLogging { @@ -17,7 +16,7 @@ trait OptimizedCKNNFitting extends ConditionalKNNParams with BasicLogging { val kvlTriples = dataset.toDF().select(getFeaturesCol, getValuesCol, getLabelCol).collect() .map { row => - val bdv = new BDV(row.getAs[DenseVector](getFeaturesCol).values) + val bdv = new BDV(row.getAs[Vector](getFeaturesCol).toDense.values) val value = row.getAs[V](getValuesCol) val label = row.getAs[L](getLabelCol) (bdv, value, label) @@ -54,7 +53,7 @@ trait OptimizedKNNFitting extends KNNParams with BasicLogging { val kvlTuples = dataset.toDF().select(getFeaturesCol, getValuesCol).collect() .map { row => - val bdv = new BDV(row.getAs[DenseVector](getFeaturesCol).values) + val bdv = new BDV(row.getAs[Vector](getFeaturesCol).toDense.values) val value = row.getAs[V](getValuesCol) (bdv, value) } diff --git a/src/test/scala/com/microsoft/ml/spark/nn/BallTreeTest.scala b/src/test/scala/com/microsoft/ml/spark/nn/BallTreeTest.scala index 81f54846dd..e389d54aa2 100644 --- a/src/test/scala/com/microsoft/ml/spark/nn/BallTreeTest.scala +++ b/src/test/scala/com/microsoft/ml/spark/nn/BallTreeTest.scala @@ -7,6 +7,7 @@ import breeze.linalg.DenseVector import com.microsoft.ml.spark.core.test.base.TestBase import org.apache.spark.ml.linalg.{DenseVector => SDV} import org.apache.spark.sql.functions.lit + import scala.collection.immutable trait BallTreeTestBase extends TestBase { @@ -28,6 +29,7 @@ trait BallTreeTestBase extends TestBase { def twoClassStringLabels(data: IndexedSeq[_]): IndexedSeq[String] = twoClassLabels(data).map(_.toString) + def randomClassLabels(data: IndexedSeq[_], nClasses: Int): IndexedSeq[Int] = { val r = scala.util.Random data.map(_ => r.nextInt(nClasses)) @@ -54,6 +56,12 @@ trait BallTreeTestBase extends TestBase { )) .toDF("features", "values", "labels") + lazy val dfSparse = spark + .createDataFrame(uniformData.zip(uniformLabels).map(p => + (new SDV(p._1.data).toSparse, "foo", p._2) + )) + .toDF("features", "values", "labels") + lazy val stringDF = spark .createDataFrame(uniformData.zip(uniformLabels).map(p => (new SDV(p._1.data), "foo", "class1") @@ -67,6 +75,13 @@ trait BallTreeTestBase extends TestBase { .toDF("features", "values", "labels") .withColumn("conditioner", lit(Array(0, 1))) + lazy val testDFSparse = spark + .createDataFrame(uniformData.zip(uniformLabels).take(5).map(p => + (new SDV(p._1.data).toSparse, "foo", p._2) + )) + .toDF("features", "values", "labels") + .withColumn("conditioner", lit(Array(0, 1))) + lazy val testStringDF = spark .createDataFrame(uniformData.zip(uniformLabels).take(5).map(p => (new SDV(p._1.data), "foo", "class1") diff --git a/src/test/scala/com/microsoft/ml/spark/nn/KNNTest.scala b/src/test/scala/com/microsoft/ml/spark/nn/KNNTest.scala index aee4c0e6eb..735a37ee43 100644 --- a/src/test/scala/com/microsoft/ml/spark/nn/KNNTest.scala +++ b/src/test/scala/com/microsoft/ml/spark/nn/KNNTest.scala @@ -7,7 +7,6 @@ import com.microsoft.ml.spark.core.test.fuzzing.{EstimatorFuzzing, TestObject} import org.apache.spark.ml.util.MLReadable import org.apache.spark.sql.{DataFrame, Row} import org.scalactic.Equality -import org.scalatest.Assertion class KNNTest extends EstimatorFuzzing[KNN] with BallTreeTestBase { @@ -15,17 +14,19 @@ class KNNTest extends EstimatorFuzzing[KNN] with BallTreeTestBase { val results = new KNN().setOutputCol("matches") .fit(df).transform(testDF) .select("matches").collect() - val sparkResults = results.map(r => + val resultsSparse = new KNN().setOutputCol("matches") + .fit(dfSparse).transform(testDFSparse) + .select("matches").collect() + val sparkResults = List(results, resultsSparse).map(_.map(r => r.getSeq[Row](0).map(mr => mr.getDouble(1)) - ) + )) val tree = BallTree(uniformData, uniformData.indices) val nonSparkResults = uniformData.take(5).map( point => tree.findMaximumInnerProducts(point, 5) ) - - sparkResults.zip(nonSparkResults).foreach { case (sr, nsr) => + sparkResults.map(_.zip(nonSparkResults).foreach { case (sr, nsr) => assert(sr === nsr.map(_.distance)) - } + }) } override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { @@ -36,7 +37,9 @@ class KNNTest extends EstimatorFuzzing[KNN] with BallTreeTestBase { } override def testObjects(): Seq[TestObject[KNN]] = - List(new TestObject(new KNN().setOutputCol("matches"), df, testDF)) + List( + new TestObject(new KNN().setOutputCol("matches"), df, testDF), + new TestObject(new KNN().setOutputCol("matches"), dfSparse, testDFSparse)) override def reader: MLReadable[_] = KNN @@ -49,18 +52,22 @@ class ConditionalKNNTest extends EstimatorFuzzing[ConditionalKNN] with BallTreeT val results = new ConditionalKNN().setOutputCol("matches") .fit(df).transform(testDF) .select("matches").collect() - val sparkResults = results.map(r => + val resultsSparse = new ConditionalKNN().setOutputCol("matches") + .fit(dfSparse).transform(testDFSparse) + .select("matches").collect() + + val sparkResults = List(results, resultsSparse).map(_.map(r => r.getSeq[Row](0).map(mr => (mr.getDouble(1), mr.getInt(2))) - ) + )) val tree = ConditionalBallTree(uniformData, uniformData.indices, uniformLabels) val nonSparkResults = uniformData.take(5).map( point => tree.findMaximumInnerProducts(point, Set(0, 1), 5) ) - sparkResults.zip(nonSparkResults).foreach { case (sr, nsr) => + sparkResults.map(_.zip(nonSparkResults).foreach { case (sr, nsr) => assert(sr.map(_._1) === nsr.map(_.distance)) assert(sr.forall(p => Set(1, 0)(p._2))) - } + }) } override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { @@ -71,7 +78,9 @@ class ConditionalKNNTest extends EstimatorFuzzing[ConditionalKNN] with BallTreeT } override def testObjects(): Seq[TestObject[ConditionalKNN]] = - List(new TestObject(new ConditionalKNN().setOutputCol("matches"), df, testDF)) + List( + new TestObject(new ConditionalKNN().setOutputCol("matches"), df, testDF), + new TestObject(new ConditionalKNN().setOutputCol("matches"), dfSparse, testDFSparse)) override def reader: MLReadable[_] = ConditionalKNN