Skip to content

Commit

Permalink
Add sparse vector support to KNN. (#1063)
Browse files Browse the repository at this point in the history
Co-authored-by: Jako Tinkus <[email protected]>
Co-authored-by: Mark Hamilton <[email protected]>
  • Loading branch information
3 people authored Jun 1, 2021
1 parent ab15ca4 commit 6aecdf1
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 28 deletions.
9 changes: 4 additions & 5 deletions src/main/scala/com/microsoft/ml/spark/nn/ConditionalKNN.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import com.microsoft.ml.spark.core.contracts.HasLabelCol
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.injections.UDFUtils
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{ConditionalBallTreeParam, Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Estimator, Model}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.injections.OptimizedCKNNFitting
import org.apache.spark.sql.{DataFrame, Dataset, Row}
Expand Down Expand Up @@ -63,11 +63,10 @@ class ConditionalKNN(override val uid: String) extends Estimator[ConditionalKNNM

private[ml] object KNNFuncHolder {
  /** Queries a broadcast conditional ball tree for the best k matches.
    *
    * Accepts any Spark ML `Vector` (dense or sparse); sparse inputs are
    * densified via `toDense` before being handed to Breeze, which is the
    * point of this commit's sparse-vector support.
    *
    * NOTE(review): the scraped diff interleaved the removed `DenseVector`
    * overload with the added `Vector` one; this is the post-commit form.
    *
    * @param bbt         broadcast ball tree holding indexed values and labels
    * @param k           number of matches to return
    * @param v           query feature vector
    * @param conditioner labels the search is restricted to
    * @return one Row of (value, distance, label) per match
    */
  def queryFunc[L, V](bbt: Broadcast[ConditionalBallTree[L, V]], k: Int)
                     (v: Vector, conditioner: Seq[L]): Seq[Row] = {
    bbt.value.findMaximumInnerProducts(new BDV(v.toDense.values), conditioner.toSet, k)
      .map(bm => Row(bbt.value.values(bm.index), bm.distance, bbt.value.labels(bm.index)))
  }
}

class ConditionalKNNModel(val uid: String) extends Model[ConditionalKNNModel]
Expand Down
10 changes: 5 additions & 5 deletions src/main/scala/com/microsoft/ml/spark/nn/KNN.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
package com.microsoft.ml.spark.nn

import breeze.linalg.{DenseVector => BDV}
import com.microsoft.ml.spark.core.contracts.{HasFeaturesCol, HasOutputCol}
import com.microsoft.ml.spark.codegen.Wrappable
import com.microsoft.ml.spark.core.contracts.{HasFeaturesCol, HasOutputCol}
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.injections.UDFUtils
import org.apache.spark.ml._
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.injections.OptimizedKNNFitting
import org.apache.spark.sql.{DataFrame, Dataset, Row}
Expand Down Expand Up @@ -101,9 +101,9 @@ class KNNModel(val uid: String) extends Model[KNNModel]
if (broadcastedModelOption.isEmpty) {
broadcastedModelOption = Some(dataset.sparkSession.sparkContext.broadcast(getBallTree))
}
val getNeighborUDF = UDFUtils.oldUdf({ dv: DenseVector =>
val getNeighborUDF = UDFUtils.oldUdf({ v: Vector =>
val bt = broadcastedModelOption.get.value
bt.findMaximumInnerProducts(new BDV(dv.values), getK)
bt.findMaximumInnerProducts(new BDV(v.toDense.values), getK)
.map(bm => Row(bt.values(bm.index), bm.distance))
}, ArrayType(new StructType()
.add("value", dataset.schema(getValuesCol).dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@

package org.apache.spark.sql.types.injections

import com.microsoft.ml.spark.nn._
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.Dataset
import breeze.linalg.{DenseVector => BDV}
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.internal.Logging
import com.microsoft.ml.spark.nn._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types._

trait OptimizedCKNNFitting extends ConditionalKNNParams with BasicLogging {
Expand All @@ -17,7 +16,7 @@ trait OptimizedCKNNFitting extends ConditionalKNNParams with BasicLogging {

val kvlTriples = dataset.toDF().select(getFeaturesCol, getValuesCol, getLabelCol).collect()
.map { row =>
val bdv = new BDV(row.getAs[DenseVector](getFeaturesCol).values)
val bdv = new BDV(row.getAs[Vector](getFeaturesCol).toDense.values)
val value = row.getAs[V](getValuesCol)
val label = row.getAs[L](getLabelCol)
(bdv, value, label)
Expand Down Expand Up @@ -54,7 +53,7 @@ trait OptimizedKNNFitting extends KNNParams with BasicLogging {

val kvlTuples = dataset.toDF().select(getFeaturesCol, getValuesCol).collect()
.map { row =>
val bdv = new BDV(row.getAs[DenseVector](getFeaturesCol).values)
val bdv = new BDV(row.getAs[Vector](getFeaturesCol).toDense.values)
val value = row.getAs[V](getValuesCol)
(bdv, value)
}
Expand Down
15 changes: 15 additions & 0 deletions src/test/scala/com/microsoft/ml/spark/nn/BallTreeTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import breeze.linalg.DenseVector
import com.microsoft.ml.spark.core.test.base.TestBase
import org.apache.spark.ml.linalg.{DenseVector => SDV}
import org.apache.spark.sql.functions.lit

import scala.collection.immutable

trait BallTreeTestBase extends TestBase {
Expand All @@ -28,6 +29,7 @@ trait BallTreeTestBase extends TestBase {

def twoClassStringLabels(data: IndexedSeq[_]): IndexedSeq[String] =
twoClassLabels(data).map(_.toString)

def randomClassLabels(data: IndexedSeq[_], nClasses: Int): IndexedSeq[Int] = {
val r = scala.util.Random
data.map(_ => r.nextInt(nClasses))
Expand All @@ -54,6 +56,12 @@ trait BallTreeTestBase extends TestBase {
))
.toDF("features", "values", "labels")

// Same fixture as `df`, but each feature vector is converted to sparse
// storage (`toSparse`) to exercise the sparse-vector code path in KNN.
lazy val dfSparse = spark
  .createDataFrame(uniformData.zip(uniformLabels).map(p =>
    (new SDV(p._1.data).toSparse, "foo", p._2)
  ))
  .toDF("features", "values", "labels")

lazy val stringDF = spark
.createDataFrame(uniformData.zip(uniformLabels).map(p =>
(new SDV(p._1.data), "foo", "class1")
Expand All @@ -67,6 +75,13 @@ trait BallTreeTestBase extends TestBase {
.toDF("features", "values", "labels")
.withColumn("conditioner", lit(Array(0, 1)))

// Sparse-vector counterpart of `testDF`: first 5 uniform points with
// features stored sparsely, plus the constant `conditioner` column used
// by the conditional-KNN tests.
lazy val testDFSparse = spark
  .createDataFrame(uniformData.zip(uniformLabels).take(5).map(p =>
    (new SDV(p._1.data).toSparse, "foo", p._2)
  ))
  .toDF("features", "values", "labels")
  .withColumn("conditioner", lit(Array(0, 1)))

lazy val testStringDF = spark
.createDataFrame(uniformData.zip(uniformLabels).take(5).map(p =>
(new SDV(p._1.data), "foo", "class1")
Expand Down
33 changes: 21 additions & 12 deletions src/test/scala/com/microsoft/ml/spark/nn/KNNTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,26 @@ import com.microsoft.ml.spark.core.test.fuzzing.{EstimatorFuzzing, TestObject}
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.{DataFrame, Row}
import org.scalactic.Equality
import org.scalatest.Assertion

class KNNTest extends EstimatorFuzzing[KNN] with BallTreeTestBase {

test("matches non spark result") {
  // Fit/transform on both the dense and the sparse fixtures; the sparse
  // path must yield the same neighbour distances as the dense one.
  // NOTE(review): scrape interleaved the pre-commit `sparkResults` line
  // with the post-commit List-based version; this is the deinterleaved form.
  val results = new KNN().setOutputCol("matches")
    .fit(df).transform(testDF)
    .select("matches").collect()
  val resultsSparse = new KNN().setOutputCol("matches")
    .fit(dfSparse).transform(testDFSparse)
    .select("matches").collect()
  // Per result set, extract each row's match distances (struct field 1).
  val sparkResults = List(results, resultsSparse).map(_.map(r =>
    r.getSeq[Row](0).map(mr => mr.getDouble(1))
  ))
  // Reference answer computed directly against the in-memory ball tree.
  val tree = BallTree(uniformData, uniformData.indices)
  val nonSparkResults = uniformData.take(5).map(
    point => tree.findMaximumInnerProducts(point, 5)
  )
  // `foreach` rather than `map`: this pass exists only for its assertions,
  // so discarding the (all-Unit) result makes the intent explicit.
  sparkResults.foreach(_.zip(nonSparkResults).foreach { case (sr, nsr) =>
    assert(sr === nsr.map(_.distance))
  })
}

override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
Expand All @@ -36,7 +37,9 @@ class KNNTest extends EstimatorFuzzing[KNN] with BallTreeTestBase {
}

// Fuzzing fixtures: cover both the dense and the sparse feature-vector
// inputs so serialization round-trips exercise the new sparse path.
// NOTE(review): deinterleaved — the scrape kept the removed single-entry
// list alongside the added two-entry one.
override def testObjects(): Seq[TestObject[KNN]] =
  List(
    new TestObject(new KNN().setOutputCol("matches"), df, testDF),
    new TestObject(new KNN().setOutputCol("matches"), dfSparse, testDFSparse))

override def reader: MLReadable[_] = KNN

Expand All @@ -49,18 +52,22 @@ class ConditionalKNNTest extends EstimatorFuzzing[ConditionalKNN] with BallTreeT
val results = new ConditionalKNN().setOutputCol("matches")
.fit(df).transform(testDF)
.select("matches").collect()
val sparkResults = results.map(r =>
val resultsSparse = new ConditionalKNN().setOutputCol("matches")
.fit(dfSparse).transform(testDFSparse)
.select("matches").collect()

val sparkResults = List(results, resultsSparse).map(_.map(r =>
r.getSeq[Row](0).map(mr => (mr.getDouble(1), mr.getInt(2)))
)
))
val tree = ConditionalBallTree(uniformData, uniformData.indices, uniformLabels)
val nonSparkResults = uniformData.take(5).map(
point => tree.findMaximumInnerProducts(point, Set(0, 1), 5)
)

sparkResults.zip(nonSparkResults).foreach { case (sr, nsr) =>
sparkResults.map(_.zip(nonSparkResults).foreach { case (sr, nsr) =>
assert(sr.map(_._1) === nsr.map(_.distance))
assert(sr.forall(p => Set(1, 0)(p._2)))
}
})
}

override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
Expand All @@ -71,7 +78,9 @@ class ConditionalKNNTest extends EstimatorFuzzing[ConditionalKNN] with BallTreeT
}

// Fuzzing fixtures for ConditionalKNN, mirroring the KNN test class:
// one dense and one sparse input pair.
// NOTE(review): deinterleaved — the scrape kept the removed single-entry
// list alongside the added two-entry one.
override def testObjects(): Seq[TestObject[ConditionalKNN]] =
  List(
    new TestObject(new ConditionalKNN().setOutputCol("matches"), df, testDF),
    new TestObject(new ConditionalKNN().setOutputCol("matches"), dfSparse, testDFSparse))

override def reader: MLReadable[_] = ConditionalKNN

Expand Down

0 comments on commit 6aecdf1

Please sign in to comment.