From 7cd5c400cfee6c3d9595313557ce8f9fbdcc26fd Mon Sep 17 00:00:00 2001
From: Mark Hamilton
Date: Mon, 21 Jun 2021 18:38:12 -0400
Subject: [PATCH] merge in changes

---
 build.sbt                                        | 11 ++---------
 .../ml/spark/cognitive/RESTHelpers.scala         |  4 ++--
 .../ml/spark/cognitive/SpeechAPI.scala           |  4 ++--
 .../ml/spark/cognitive/split1/FaceAPI.scala      |  4 ++--
 .../cognitive/split2/SearchWriterSuite.scala     |  3 ++-
 .../core/test/base/SparkSessionFactory.scala     | 19 ++++++++++++-------
 .../ml/spark/core/test/base/TestBase.scala       |  3 ++-
 .../ml/spark/core/test/fuzzing/Fuzzing.scala     |  2 +-
 .../spark/core/test/fuzzing/FuzzingTest.scala    |  5 +++--
 .../explainers/split1/SamplerSuite.scala         |  4 ++--
 .../ml/spark/featurize/VerifyFeaturize.scala     |  2 +-
 .../flaky/PartitionConsolidatorSuite.scala       |  4 ++--
 .../ml/spark/image/ImageTestUtils.scala          |  2 +-
 .../spark/io/split2/ContinuousHTTPSuite.scala    |  2 --
 .../io/split2/DistributedHTTPSuite.scala         |  2 +-
 .../ml/spark/nbtest/DatabricksUtilities.scala    |  4 ++--
 .../spark/train/VerifyTrainClassifier.scala      |  2 +-
 .../ml/spark/downloader/ModelDownloader.scala    |  8 ++++----
 .../explainers/ImageExplainersSuite.scala        | 10 +++++-----
 .../split2}/ImageLIMEExplainerSuite.scala        |  4 ++--
 .../split3}/ImageSHAPExplainerSuite.scala        |  4 ++--
 .../spark/lightgbm/PartitionProcessor.scala      |  0
 .../spark/lightgbm/dataset/DatasetUtils.scala    |  7 ++++---
 23 files changed, 55 insertions(+), 55 deletions(-)
 rename {core/src/main => deep-learning/src/test}/scala/com/microsoft/ml/spark/explainers/ImageExplainersSuite.scala (86%)
 rename {core/src/test/scala/com/microsoft/ml/spark/explainers/split3 => deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split2}/ImageLIMEExplainerSuite.scala (98%)
 rename {core/src/test/scala/com/microsoft/ml/spark/explainers/split2 => deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split3}/ImageSHAPExplainerSuite.scala (97%)
 rename {core => lightgbm}/src/main/scala/com/microsoft/ml/spark/lightgbm/PartitionProcessor.scala (100%)
 rename {core => lightgbm}/src/main/scala/com/microsoft/ml/spark/lightgbm/dataset/DatasetUtils.scala (98%)

diff --git a/build.sbt b/build.sbt
index a38129297ff..62eaa77d97d 100644
--- a/build.sbt
+++ b/build.sbt
@@ -1,24 +1,17 @@
-import java.io.{File, PrintWriter}
+import java.io.File
 import java.net.URL
-
 import org.apache.commons.io.FileUtils
 import sbt.ExclusionRule
-
 import scala.xml.{Node => XmlNode, NodeSeq => XmlNodeSeq, _}
 import scala.xml.transform.{RewriteRule, RuleTransformer}
 import BuildUtils._
-import CodegenPlugin.autoImport.pythonizedVersion
-import sbt.Project.projectToRef
 import xerial.sbt.Sonatype._

 val condaEnvName = "mmlspark"
-name := "mmlspark"
-organization := "com.microsoft.ml.spark"
-scalaVersion := "2.12.10"
 val sparkVersion = "3.1.2"
+name := "mmlspark"
 ThisBuild / organization := "com.microsoft.ml.spark"
 ThisBuild / scalaVersion := "2.12.10"
-val sparkVersion = "3.0.1"
 val scalaMajorVersion = 2.12

diff --git a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/RESTHelpers.scala b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/RESTHelpers.scala
index 3e4dc4e4a14..01de211a8e0 100644
--- a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/RESTHelpers.scala
+++ b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/RESTHelpers.scala
@@ -59,11 +59,11 @@ object RESTHelpers {
         response
       } else {
         val requestBodyOpt = Try(request match {
-          case er: HttpEntityEnclosingRequestBase => IOUtils.toString(er.getEntity.getContent)
+          case er: HttpEntityEnclosingRequestBase => IOUtils.toString(er.getEntity.getContent, "UTF-8")
           case _ => ""
         }).get

-        val responseBodyOpt = Try(IOUtils.toString(response.getEntity.getContent)).getOrElse("")
+        val responseBodyOpt = Try(IOUtils.toString(response.getEntity.getContent, "UTF-8")).getOrElse("")

         throw new RuntimeException(
           s"Failed: " +
diff --git a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/SpeechAPI.scala b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/SpeechAPI.scala
index 361c63507cf..b240da1a95f 100644
--- a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/SpeechAPI.scala
+++ b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/SpeechAPI.scala
@@ -32,7 +32,7 @@ object SpeechAPI {
       using(Client.execute(request)) { response =>
         if (!response.getStatusLine.getStatusCode.toString.startsWith("2")) {
           val bodyOpt = request match {
-            case er: HttpEntityEnclosingRequestBase => IOUtils.toString(er.getEntity.getContent)
+            case er: HttpEntityEnclosingRequestBase => IOUtils.toString(er.getEntity.getContent, "UTF-8")
             case _ => ""
           }
           throw new RuntimeException(
@@ -40,7 +40,7 @@ object SpeechAPI {
           s"requestUrl: ${request.getURI}" +
           s"requestBody: $bodyOpt")
       }
-      IOUtils.toString(response.getEntity.getContent)
+      IOUtils.toString(response.getEntity.getContent, "UTF-8")
         .parseJson.asJsObject().fields("Signature").compactPrint
     }.get
   })
diff --git a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/FaceAPI.scala b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/FaceAPI.scala
index 3cc8c4eefcf..3b1744c63f4 100644
--- a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/FaceAPI.scala
+++ b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/FaceAPI.scala
@@ -35,7 +35,7 @@ object FaceUtils extends CognitiveKey {
     using(Client.execute(request)) { response =>
       if (!response.getStatusLine.getStatusCode.toString.startsWith("2")) {
         val bodyOpt = request match {
-          case er: HttpEntityEnclosingRequestBase => IOUtils.toString(er.getEntity.getContent)
+          case er: HttpEntityEnclosingRequestBase => IOUtils.toString(er.getEntity.getContent, "UTF-8")
           case _ => ""
         }
         throw new RuntimeException(
@@ -43,7 +43,7 @@ object FaceUtils extends CognitiveKey {
         s"requestUrl: ${request.getURI}" +
         s"requestBody: $bodyOpt")
     }
-    IOUtils.toString(response.getEntity.getContent)
+    IOUtils.toString(response.getEntity.getContent, "UTF-8")
   }.get
 })
 }
diff --git a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split2/SearchWriterSuite.scala b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split2/SearchWriterSuite.scala
index 0f543420bd2..9b8d91af8ae 100644
--- a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split2/SearchWriterSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split2/SearchWriterSuite.scala
@@ -14,7 +14,8 @@ import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
 import org.apache.http.client.methods.HttpDelete
 import org.apache.spark.ml.util.MLReadable
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions.{lit, udf, col, split}
+import org.apache.spark.sql.functions.{col, lit, split, udf}
+
 import scala.collection.mutable
 import scala.concurrent.blocking
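Every commons-io change in this patch follows the same pattern: the charset-less IOUtils overloads (toString, readLines, toInputStream) are deprecated because they decode with the JVM's default charset, which varies from host to host, while the charset-explicit overloads are deterministic. A minimal sketch of the difference, assuming only commons-io on the classpath (object and variable names are illustrative):

    import java.io.ByteArrayInputStream
    import java.nio.charset.StandardCharsets

    import org.apache.commons.io.IOUtils

    object CharsetSketch {
      def main(args: Array[String]): Unit = {
        val bytes = "résumé".getBytes(StandardCharsets.UTF_8)
        // The deprecated no-charset overload decodes with the platform default
        // charset, so the same bytes can read back differently per machine:
        //   IOUtils.toString(new ByteArrayInputStream(bytes))
        // The charset-explicit overload yields the same string everywhere:
        val decoded = IOUtils.toString(new ByteArrayInputStream(bytes), "UTF-8")
        assert(decoded == "résumé")
      }
    }

The same reasoning applies to the DistributedHTTPSuite, DatabricksUtilities, and ModelDownloader hunks below.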
diff --git a/core/src/test/scala/com/microsoft/ml/spark/core/test/base/SparkSessionFactory.scala b/core/src/test/scala/com/microsoft/ml/spark/core/test/base/SparkSessionFactory.scala
index 4b2de249739..031d1b333e4 100644
--- a/core/src/test/scala/com/microsoft/ml/spark/core/test/base/SparkSessionFactory.scala
+++ b/core/src/test/scala/com/microsoft/ml/spark/core/test/base/SparkSessionFactory.scala
@@ -31,19 +31,24 @@ object SparkSessionFactory {
     if (File.separator != "\\") path
     else path.replaceFirst("[A-Z]:", "").replace("\\", "/")
   }
+
   def currentDir(): String = System.getProperty("user.dir")

   def getSession(name: String, logLevel: String = "WARN",
                  numRetries: Int = 1, numCores: Option[Int] = None): SparkSession = {
     val cores = numCores.map(_.toString).getOrElse("*")
     val conf = new SparkConf()
-      .setAppName(name)
-      .setMaster(if (numRetries == 1){s"local[$cores]"}else{s"local[$cores, $numRetries]"})
-      .set("spark.logConf", "true")
-      .set("spark.sql.shuffle.partitions", "20")
-      .set("spark.driver.maxResultSize", "6g")
-      .set("spark.sql.warehouse.dir", SparkSessionFactory.LocalWarehousePath)
-      .set("spark.sql.crossJoin.enabled", "true")
+      .setAppName(name)
+      .setMaster(if (numRetries == 1) {
+        s"local[$cores]"
+      } else {
+        s"local[$cores, $numRetries]"
+      })
+      .set("spark.logConf", "true")
+      .set("spark.sql.shuffle.partitions", "20")
+      .set("spark.driver.maxResultSize", "6g")
+      .set("spark.sql.warehouse.dir", SparkSessionFactory.LocalWarehousePath)
+      .set("spark.sql.crossJoin.enabled", "true")
     val sess = SparkSession.builder()
       .config(conf)
       .getOrCreate()
diff --git a/core/src/test/scala/com/microsoft/ml/spark/core/test/base/TestBase.scala b/core/src/test/scala/com/microsoft/ml/spark/core/test/base/TestBase.scala
index 097c120581b..84a2bfcd08f 100644
--- a/core/src/test/scala/com/microsoft/ml/spark/core/test/base/TestBase.scala
+++ b/core/src/test/scala/com/microsoft/ml/spark/core/test/base/TestBase.scala
@@ -3,6 +3,8 @@

 package com.microsoft.ml.spark.core.test.base

+import java.nio.file.Files
+
 import breeze.linalg.norm.Impl
 import breeze.linalg.{norm, DenseVector => BDV}
 import breeze.math.Field
@@ -17,7 +19,6 @@ import org.scalatest._
 import org.scalatest.concurrent.TimeLimits
 import org.scalatest.time.{Seconds, Span}

-import java.nio.file.Files
 import scala.concurrent._
 import scala.reflect.ClassTag
diff --git a/core/src/test/scala/com/microsoft/ml/spark/core/test/fuzzing/Fuzzing.scala b/core/src/test/scala/com/microsoft/ml/spark/core/test/fuzzing/Fuzzing.scala
index 9ee92739c58..7c6540c8861 100644
--- a/core/src/test/scala/com/microsoft/ml/spark/core/test/fuzzing/Fuzzing.scala
+++ b/core/src/test/scala/com/microsoft/ml/spark/core/test/fuzzing/Fuzzing.scala
@@ -9,13 +9,13 @@ import java.nio.file.Files

 import com.microsoft.ml.spark.codegen.CodegenConfig
 import com.microsoft.ml.spark.core.env.FileUtilities
-import com.microsoft.ml.spark.core.test.base.TestBase
 import org.apache.commons.io.FileUtils
 import org.apache.spark.ml._
 import org.apache.spark.ml.param.{DataFrameEquality, ExternalPythonWrappableParam, ParamPair}
 import org.apache.spark.ml.util.{MLReadable, MLWritable}
 import org.apache.spark.sql.DataFrame
 import com.microsoft.ml.spark.codegen.GenerationUtils._
+import com.microsoft.ml.spark.core.test.base.TestBase

 /**
   * Class for holding test information, call by name to avoid unnecessary computations during test generation
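The SparkSessionFactory hunk above keeps the factory's two master URLs: plain local[$cores] when each task should run at most once, and local[$cores, $numRetries], Spark's local[N, M] form, where M sets the task maxFailures so a flaky task can be retried before the local job aborts. A hedged sketch of the same construction outside the factory (object and method names are illustrative, not the project's):

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession

    object LocalSessionSketch {
      // Same master-string logic as the factory: local[N] runs each task once;
      // local[N, M] allows up to M failures per task before the run gives up.
      def masterUrl(cores: String, numRetries: Int): String =
        if (numRetries == 1) s"local[$cores]" else s"local[$cores, $numRetries]"

      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setAppName("local-session-sketch")
          .setMaster(masterUrl("*", numRetries = 3))
        val spark = SparkSession.builder().config(conf).getOrCreate()
        println(spark.sparkContext.master) // prints local[*, 3]
        spark.stop()
      }
    }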
diff --git a/core/src/test/scala/com/microsoft/ml/spark/core/test/fuzzing/FuzzingTest.scala b/core/src/test/scala/com/microsoft/ml/spark/core/test/fuzzing/FuzzingTest.scala
index c213e7a7447..4e5a12fb09d 100644
--- a/core/src/test/scala/com/microsoft/ml/spark/core/test/fuzzing/FuzzingTest.scala
+++ b/core/src/test/scala/com/microsoft/ml/spark/core/test/fuzzing/FuzzingTest.scala
@@ -4,13 +4,14 @@
 package com.microsoft.ml.spark.core.test.fuzzing

 import com.microsoft.ml.spark.core.contracts.{HasFeaturesCol, HasInputCol, HasLabelCol, HasOutputCol}
-import com.microsoft.ml.spark.core.test.base.TestBase
 import com.microsoft.ml.spark.core.utils.JarLoadingUtils
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util.{MLReadable, MLWritable}
-
 import java.lang.reflect.ParameterizedType
+
+import com.microsoft.ml.spark.core.test.base.TestBase
+
 import scala.language.existentials

 /** Tests to validate fuzzing of modules. */
diff --git a/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/SamplerSuite.scala b/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/SamplerSuite.scala
index 4606bedcf70..0c4ea711ed0 100644
--- a/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/SamplerSuite.scala
+++ b/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/SamplerSuite.scala
@@ -6,7 +6,6 @@ package com.microsoft.ml.spark.explainers.split1
 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
 import breeze.stats.distributions.RandBasis
 import breeze.stats.{mean, stddev}
-import com.microsoft.ml.spark.core.test.base.TestBase
 import com.microsoft.ml.spark.explainers.BreezeUtils._
 import com.microsoft.ml.spark.explainers._
 import com.microsoft.ml.spark.io.image.ImageUtils
@@ -17,8 +16,9 @@ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.types._
 import org.scalactic.{Equality, TolerantNumerics}
 import org.scalatest.Matchers._
-
 import java.nio.file.{Files, Paths}
+
+import com.microsoft.ml.spark.core.test.base.TestBase
 import javax.imageio.ImageIO

 class SamplerSuite extends TestBase {
diff --git a/core/src/test/scala/com/microsoft/ml/spark/featurize/VerifyFeaturize.scala b/core/src/test/scala/com/microsoft/ml/spark/featurize/VerifyFeaturize.scala
index 7ff30e5c723..72168f2badc 100644
--- a/core/src/test/scala/com/microsoft/ml/spark/featurize/VerifyFeaturize.scala
+++ b/core/src/test/scala/com/microsoft/ml/spark/featurize/VerifyFeaturize.scala
@@ -13,7 +13,7 @@ import com.microsoft.ml.spark.core.test.fuzzing.{EstimatorFuzzing, TestObject}
 import org.apache.commons.io.FileUtils
 import org.apache.spark.ml.PipelineModel
 import org.apache.spark.ml.feature.StringIndexer
-import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vectors, Vector}
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.util.MLReadable
 import org.apache.spark.sql._
diff --git a/core/src/test/scala/com/microsoft/ml/spark/flaky/PartitionConsolidatorSuite.scala b/core/src/test/scala/com/microsoft/ml/spark/flaky/PartitionConsolidatorSuite.scala
index 197f85a6fb5..9c014e715a8 100644
--- a/core/src/test/scala/com/microsoft/ml/spark/flaky/PartitionConsolidatorSuite.scala
+++ b/core/src/test/scala/com/microsoft/ml/spark/flaky/PartitionConsolidatorSuite.scala
@@ -3,13 +3,13 @@

 package com.microsoft.ml.spark.flaky

-import com.microsoft.ml.spark.core.test.base.{SparkSessionFactory, TestBase, TimeLimitedFlaky}
+import com.microsoft.ml.spark.core.test.base.{TestBase, TimeLimitedFlaky}
 import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
 import com.microsoft.ml.spark.stages.PartitionConsolidator
 import org.apache.spark.ml.util.MLReadable
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.types.{DoubleType, StructType}
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.scalatest.Assertion

 class PartitionConsolidatorSuite extends TransformerFuzzing[PartitionConsolidator] with TimeLimitedFlaky {
diff --git a/core/src/test/scala/com/microsoft/ml/spark/image/ImageTestUtils.scala b/core/src/test/scala/com/microsoft/ml/spark/image/ImageTestUtils.scala
index 84e516c5498..63dbea62576 100644
--- a/core/src/test/scala/com/microsoft/ml/spark/image/ImageTestUtils.scala
+++ b/core/src/test/scala/com/microsoft/ml/spark/image/ImageTestUtils.scala
@@ -8,9 +8,9 @@ import java.net.URL

 import com.microsoft.ml.spark.build.BuildInfo
 import com.microsoft.ml.spark.core.env.FileUtilities
+import com.microsoft.ml.spark.core.test.base.TestBase
 import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.sql.{DataFrame, SparkSession}
-import com.microsoft.ml.spark.core.test.base.TestBase
 import com.microsoft.ml.spark.io.IOImplicits.dfrToDfre
 import org.apache.commons.io.FileUtils
 import org.apache.spark.sql.functions.col
diff --git a/core/src/test/scala/com/microsoft/ml/spark/io/split2/ContinuousHTTPSuite.scala b/core/src/test/scala/com/microsoft/ml/spark/io/split2/ContinuousHTTPSuite.scala
index 5507196ee7b..40cf3936191 100644
--- a/core/src/test/scala/com/microsoft/ml/spark/io/split2/ContinuousHTTPSuite.scala
+++ b/core/src/test/scala/com/microsoft/ml/spark/io/split2/ContinuousHTTPSuite.scala
@@ -5,7 +5,6 @@ package com.microsoft.ml.spark.io.split2

 import java.io.File
 import java.util.UUID
-import java.util.concurrent.TimeUnit

 import com.microsoft.ml.spark.core.test.base.{Flaky, TestBase}
 import com.microsoft.ml.spark.io.IOImplicits._
@@ -15,7 +14,6 @@ import org.apache.spark.sql.streaming.{DataStreamReader, StreamingQuery, Trigger}
 import org.apache.spark.sql.types.BinaryType

 import scala.concurrent.Await
-import scala.concurrent.duration.Duration

 // scalastyle:off magic.number
 class ContinuousHTTPSuite extends TestBase with Flaky with HTTPTestUtils {
diff --git a/core/src/test/scala/com/microsoft/ml/spark/io/split2/DistributedHTTPSuite.scala b/core/src/test/scala/com/microsoft/ml/spark/io/split2/DistributedHTTPSuite.scala
index 5dd5b437408..d5d106315b8 100644
--- a/core/src/test/scala/com/microsoft/ml/spark/io/split2/DistributedHTTPSuite.scala
+++ b/core/src/test/scala/com/microsoft/ml/spark/io/split2/DistributedHTTPSuite.scala
@@ -354,7 +354,7 @@ class DistributedHTTPSuite extends TestBase with Flaky with HTTPTestUtils {
     processes.foreach { p =>
       p.waitFor
-      val error = IOUtils.toString(p.getErrorStream)
+      val error = IOUtils.toString(p.getErrorStream, "UTF-8")
       assert(error === "")
     }
   }
 }
diff --git a/core/src/test/scala/com/microsoft/ml/spark/nbtest/DatabricksUtilities.scala b/core/src/test/scala/com/microsoft/ml/spark/nbtest/DatabricksUtilities.scala
index b72081cde1e..bda6857db7c 100644
--- a/core/src/test/scala/com/microsoft/ml/spark/nbtest/DatabricksUtilities.scala
+++ b/core/src/test/scala/com/microsoft/ml/spark/nbtest/DatabricksUtilities.scala
@@ -86,7 +86,7 @@ object DatabricksUtilities extends HasHttpClient {
       if (response.getStatusLine.getStatusCode != 200) {
         throw new RuntimeException(s"Failed: response: $response")
       }
-      IOUtils.toString(response.getEntity.getContent).parseJson
+      IOUtils.toString(response.getEntity.getContent, "UTF-8").parseJson
     }.get
   })
 }
@@ -102,7 +102,7 @@ object DatabricksUtilities extends HasHttpClient {
         val entity = IOUtils.toString(response.getEntity.getContent, "UTF-8")
         throw new RuntimeException(s"Failed:\n entity:$entity \n response: $response")
       }
-      IOUtils.toString(response.getEntity.getContent).parseJson
+      IOUtils.toString(response.getEntity.getContent, "UTF-8").parseJson
     }.get
   })
 }
diff --git a/core/src/test/scala/com/microsoft/ml/spark/train/VerifyTrainClassifier.scala b/core/src/test/scala/com/microsoft/ml/spark/train/VerifyTrainClassifier.scala
index 959486a9093..387eb04e375 100644
--- a/core/src/test/scala/com/microsoft/ml/spark/train/VerifyTrainClassifier.scala
+++ b/core/src/test/scala/com/microsoft/ml/spark/train/VerifyTrainClassifier.scala
@@ -6,7 +6,6 @@ package com.microsoft.ml.spark.train
 import java.io.File

 import com.microsoft.ml.spark.core.schema.SchemaConstants
-import com.microsoft.ml.spark.core.test.base.TestBase
 import com.microsoft.ml.spark.core.test.benchmarks.Benchmarks
 import com.microsoft.ml.spark.core.test.fuzzing.{EstimatorFuzzing, TestObject}
 import com.microsoft.ml.spark.featurize.ValueIndexer
@@ -18,6 +17,7 @@ import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.{DataFrame, Row}
 import com.microsoft.ml.spark.codegen.GenerationUtils
+import com.microsoft.ml.spark.core.test.base.TestBase

 object ClassifierTestUtils {
diff --git a/deep-learning/src/main/scala/com/microsoft/ml/spark/downloader/ModelDownloader.scala b/deep-learning/src/main/scala/com/microsoft/ml/spark/downloader/ModelDownloader.scala
index 54f890242b4..8c2a46c55e6 100644
--- a/deep-learning/src/main/scala/com/microsoft/ml/spark/downloader/ModelDownloader.scala
+++ b/deep-learning/src/main/scala/com/microsoft/ml/spark/downloader/ModelDownloader.scala
@@ -63,7 +63,7 @@ private[spark] class HDFSRepo[S <: Schema](val uri: URI, val hconf: HadoopConf)
       .filter(status => status.isFile & status.getPath.toString.endsWith(".meta"))
       .map(status =>
-        IOUtils.toString(fs.open(status.getPath).getWrappedStream))
+        IOUtils.toString(fs.open(status.getPath).getWrappedStream, "UTF-8"))

     schemaStrings.map(s => s.parseJson.convertTo[S]).toList
   }
@@ -94,7 +94,7 @@ private[spark] class HDFSRepo[S <: Schema](val uri: URI, val hconf: HadoopConf)
     val newSchema = schema.updateURI(location)
     val schemaPath = new Path(location.getPath + ".meta")
     val osSchema = fs.create(schemaPath)
-    val schemaIs = IOUtils.toInputStream(newSchema.toJson.prettyPrint)
+    val schemaIs = IOUtils.toInputStream(newSchema.toJson.prettyPrint, "UTF-8")
     try {
       HUtils.copyBytes(schemaIs, osSchema, hconf)
     } finally {
@@ -130,9 +130,9 @@ private[spark] class DefaultModelRepo(val baseURL: URL) extends Repository[ModelSchema] {
     val url = join(baseURL, "MANIFEST")
     val manifestStream = toStream(url)
     try {
-      val modelStreams = IOUtils.readLines(manifestStream).asScala.map(fn => toStream(join(baseURL, fn)))
+      val modelStreams = IOUtils.readLines(manifestStream, "UTF-8").asScala.map(fn => toStream(join(baseURL, fn)))
       try {
-        modelStreams.map(s => IOUtils.toString(s).parseJson.convertTo[ModelSchema])
+        modelStreams.map(s => IOUtils.toString(s, "UTF-8").parseJson.convertTo[ModelSchema])
       } finally {
         modelStreams.foreach(_.close())
       }
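The ModelDownloader hunks above pair the UTF-8 reads with spray-json: each .meta file is read as text, parsed with parseJson, and converted to a schema case class with convertTo. A small sketch of that round trip, substituting a hypothetical ModelMeta case class for the project's Schema types (which are not shown in this diff):

    import spray.json._
    import spray.json.DefaultJsonProtocol._

    object SchemaParseSketch {
      // Hypothetical stand-in for the downloader's schema case classes.
      case class ModelMeta(name: String, size: Long)
      implicit val modelMetaFormat: RootJsonFormat[ModelMeta] = jsonFormat2(ModelMeta)

      def main(args: Array[String]): Unit = {
        val text = """{"name": "resnet", "size": 1024}"""
        val meta = text.parseJson.convertTo[ModelMeta] // string -> JsValue -> case class
        assert(meta == ModelMeta("resnet", 1024L))
        // The write path in the hunk is the reverse: toJson.prettyPrint to a stream.
        assert(meta.toJson.prettyPrint.nonEmpty)
      }
    }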
diff --git a/core/src/main/scala/com/microsoft/ml/spark/explainers/ImageExplainersSuite.scala b/deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/ImageExplainersSuite.scala
similarity index 86%
rename from core/src/main/scala/com/microsoft/ml/spark/explainers/ImageExplainersSuite.scala
rename to deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/ImageExplainersSuite.scala
index ae7103f2bd1..3623c88051f 100644
--- a/core/src/main/scala/com/microsoft/ml/spark/explainers/ImageExplainersSuite.scala
+++ b/deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/ImageExplainersSuite.scala
@@ -3,16 +3,16 @@

 package com.microsoft.ml.spark.explainers

+import java.io.File
+import java.net.URL
+
 import com.microsoft.ml.spark.core.test.base.TestBase
-import com.microsoft.ml.spark.image.{ImageFeaturizer, NetworkUtils}
+import com.microsoft.ml.spark.image.{ImageFeaturizer, TrainedCNTKModelUtils}
 import com.microsoft.ml.spark.io.IOImplicits._
 import org.apache.commons.io.FileUtils
 import org.apache.spark.sql.DataFrame

-import java.io.File
-import java.net.URL
-
-abstract class ImageExplainersSuite extends TestBase with NetworkUtils {
+abstract class ImageExplainersSuite extends TestBase with TrainedCNTKModelUtils {
   lazy val greyhoundImageLocation: String = {
     val loc = "/tmp/greyhound.jpg"
     val f = new File(loc)
diff --git a/core/src/test/scala/com/microsoft/ml/spark/explainers/split3/ImageLIMEExplainerSuite.scala b/deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split2/ImageLIMEExplainerSuite.scala
similarity index 98%
rename from core/src/test/scala/com/microsoft/ml/spark/explainers/split3/ImageLIMEExplainerSuite.scala
rename to deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split2/ImageLIMEExplainerSuite.scala
index 41bc9b21ab2..131b69f6fdb 100644
--- a/core/src/test/scala/com/microsoft/ml/spark/explainers/split3/ImageLIMEExplainerSuite.scala
+++ b/deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split2/ImageLIMEExplainerSuite.scala
@@ -1,13 +1,13 @@
 // Copyright (C) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License. See LICENSE in project root for information.

-package com.microsoft.ml.spark.explainers.split3
+package com.microsoft.ml.spark.explainers.split2

 import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
 import com.microsoft.ml.spark.explainers.BreezeUtils._
 import com.microsoft.ml.spark.explainers.{ImageExplainersSuite, ImageFormat, ImageLIME, LocalExplainer}
-import com.microsoft.ml.spark.lime.SuperpixelData
 import com.microsoft.ml.spark.io.IOImplicits._
+import com.microsoft.ml.spark.lime.SuperpixelData
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.util.MLReadable
 import org.apache.spark.sql.functions.col
diff --git a/core/src/test/scala/com/microsoft/ml/spark/explainers/split2/ImageSHAPExplainerSuite.scala b/deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split3/ImageSHAPExplainerSuite.scala
similarity index 97%
rename from core/src/test/scala/com/microsoft/ml/spark/explainers/split2/ImageSHAPExplainerSuite.scala
rename to deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split3/ImageSHAPExplainerSuite.scala
index 59fba17bb7a..1de490a4a8e 100644
--- a/core/src/test/scala/com/microsoft/ml/spark/explainers/split2/ImageSHAPExplainerSuite.scala
+++ b/deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split3/ImageSHAPExplainerSuite.scala
@@ -1,11 +1,11 @@
 // Copyright (C) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License. See LICENSE in project root for information.

-package com.microsoft.ml.spark.explainers.split2
+package com.microsoft.ml.spark.explainers.split3

 import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
-import com.microsoft.ml.spark.explainers.{ImageExplainersSuite, ImageFormat, ImageSHAP, LocalExplainer}
 import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.explainers.{ImageExplainersSuite, ImageFormat, ImageSHAP, LocalExplainer}
 import com.microsoft.ml.spark.lime.SuperpixelData
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.util.MLReadable
diff --git a/core/src/main/scala/com/microsoft/ml/spark/lightgbm/PartitionProcessor.scala b/lightgbm/src/main/scala/com/microsoft/ml/spark/lightgbm/PartitionProcessor.scala
similarity index 100%
rename from core/src/main/scala/com/microsoft/ml/spark/lightgbm/PartitionProcessor.scala
rename to lightgbm/src/main/scala/com/microsoft/ml/spark/lightgbm/PartitionProcessor.scala
diff --git a/core/src/main/scala/com/microsoft/ml/spark/lightgbm/dataset/DatasetUtils.scala b/lightgbm/src/main/scala/com/microsoft/ml/spark/lightgbm/dataset/DatasetUtils.scala
similarity index 98%
rename from core/src/main/scala/com/microsoft/ml/spark/lightgbm/dataset/DatasetUtils.scala
rename to lightgbm/src/main/scala/com/microsoft/ml/spark/lightgbm/dataset/DatasetUtils.scala
index a404a42e37f..02ba5b698e1 100644
--- a/core/src/main/scala/com/microsoft/ml/spark/lightgbm/dataset/DatasetUtils.scala
+++ b/lightgbm/src/main/scala/com/microsoft/ml/spark/lightgbm/dataset/DatasetUtils.scala
@@ -134,8 +134,9 @@ object DatasetUtils {

   /**
     * Sample the first several rows to determine whether to construct sparse or dense matrix in lightgbm native code.
-    * @param rowsIter Iterator of rows.
-    * @param schema The schema.
+    *
+    * @param rowsIter     Iterator of rows.
+    * @param schema       The schema.
     * @param columnParams The column parameters.
     * @return A reconstructed iterator with the same original rows and whether the matrix should be sparse or dense.
     */
@@ -158,7 +159,7 @@ object DatasetUtils {
   }

   def addFeaturesToChunkedArray(featuresChunkedArrayOpt: Option[doubleChunkedArray], numCols: Int,
-                                 rowAsDoubleArray: Array[Double]): Unit = {
+                                rowAsDoubleArray: Array[Double]): Unit = {
     featuresChunkedArrayOpt.foreach { featuresChunkedArray =>
       rowAsDoubleArray.foreach { doubleVal =>
         featuresChunkedArray.add(doubleVal
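The DatasetUtils scaladoc above documents a peek-then-reconstruct contract: sample the first several rows to choose between a sparse and a dense native matrix, then return an iterator that still yields every original row. A hedged sketch of that contract (the helper name and predicate are illustrative, not the project's implementation):

    object SampleRowsSketch {
      // Inspect the first k elements to pick a representation, then return an
      // iterator that still yields every original element.
      def sampleAndRebuild[T](rows: Iterator[T], k: Int)(isSparse: T => Boolean): (Iterator[T], Boolean) = {
        val sample = scala.collection.mutable.ArrayBuffer.empty[T]
        while (sample.length < k && rows.hasNext) sample += rows.next()
        val useSparse = sample.nonEmpty && sample.forall(isSparse)
        // Reconstructed iterator: replay the sampled prefix, then the untouched rest.
        (sample.iterator ++ rows, useSparse)
      }

      def main(args: Array[String]): Unit = {
        val (rebuilt, sparse) = sampleAndRebuild(Iterator(1, 2, 3, 4), k = 2)(_ % 2 == 0)
        assert(rebuilt.toList == List(1, 2, 3, 4)) // no rows are lost
        assert(!sparse) // 1 fails the predicate, so the prefix is not all sparse
      }
    }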