23 changes: 8 additions & 15 deletions mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
@@ -19,17 +19,15 @@ package org.apache.spark.ml.feature

import scala.beans.BeanInfo

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.{DataFrame, Row}


@BeanInfo
case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])

class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
class NGramSuite extends MLTest with DefaultReadWriteTest {

import org.apache.spark.ml.feature.NGramSuite._
import testImplicits._

test("default behavior yields bigram features") {
@@ -83,16 +81,11 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
.setN(3)
testDefaultReadWrite(t)
}
}

object NGramSuite extends SparkFunSuite {

def testNGram(t: NGram, dataset: Dataset[_]): Unit = {
t.transform(dataset)
.select("nGrams", "wantedNGrams")
.collect()
.foreach { case Row(actualNGrams, wantedNGrams) =>
def testNGram(t: NGram, dataFrame: DataFrame): Unit = {
testTransformer[(Seq[String], Seq[String])](dataFrame, t, "nGrams", "wantedNGrams") {
case Row(actualNGrams : Seq[String], wantedNGrams: Seq[String]) =>
assert(actualNGrams === wantedNGrams)
}
}
}
}
mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -17,21 +17,17 @@

package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}


class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
class NormalizerSuite extends MLTest with DefaultReadWriteTest {

import testImplicits._

@transient var data: Array[Vector] = _
@transient var dataFrame: DataFrame = _
@transient var normalizer: Normalizer = _
@transient var l1Normalized: Array[Vector] = _
@transient var l2Normalized: Array[Vector] = _

@@ -62,49 +58,40 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
Vectors.dense(0.897906166, 0.113419726, 0.42532397),
Vectors.sparse(3, Seq())
)

dataFrame = data.map(NormalizerSuite.FeatureData).toSeq.toDF()
normalizer = new Normalizer()
.setInputCol("features")
.setOutputCol("normalized_features")
}

def collectResult(result: DataFrame): Array[Vector] = {
result.select("normalized_features").collect().map {
case Row(features: Vector) => features
}
}

def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
assert((lhs, rhs).zipped.forall {
def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = {
assert((lhs, rhs) match {
case (v1: DenseVector, v2: DenseVector) => true
case (v1: SparseVector, v2: SparseVector) => true
case _ => false
}, "The vector type should be preserved after normalization.")
}

def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
assert((lhs, rhs).zipped.forall { (vector1, vector2) =>
vector1 ~== vector2 absTol 1E-5
}, "The vector value is not correct after normalization.")
def assertValues(lhs: Vector, rhs: Vector): Unit = {
assert(lhs ~== rhs absTol 1E-5, "The vector value is not correct after normalization.")
}

test("Normalization with default parameter") {
val result = collectResult(normalizer.transform(dataFrame))

assertTypeOfVector(data, result)
val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized")
val dataFrame: DataFrame = data.zip(l2Normalized).seq.toDF("features", "expected")

assertValues(result, l2Normalized)
testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") {
case Row(features: Vector, normalized: Vector, expected: Vector) =>
assertTypeOfVector(normalized, features)
assertValues(normalized, expected)
}
}

test("Normalization with setter") {
normalizer.setP(1)
val dataFrame: DataFrame = data.zip(l1Normalized).seq.toDF("features", "expected")
val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized").setP(1)

val result = collectResult(normalizer.transform(dataFrame))

assertTypeOfVector(data, result)

assertValues(result, l1Normalized)
testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") {
case Row(features: Vector, normalized: Vector, expected: Vector) =>
assertTypeOfVector(normalized, features)
assertValues(normalized, expected)
}
}

test("read/write") {
@@ -115,7 +102,3 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
testDefaultReadWrite(t)
}
}

private object NormalizerSuite {
case class FeatureData(features: Vector)
}
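
Both suites above converge on the same MLTest.testTransformer pattern: build a typed DataFrame whose first column feeds the transformer and whose remaining columns carry the expected values, then assert on each output Row. The sketch below illustrates that pattern with a hypothetical Tokenizer smoke test; the suite name and test body are illustrative and not part of this change, and the note that testTransformer also exercises the transformer against streaming input is an assumption about the MLTest helper rather than something shown in this diff.

package org.apache.spark.ml.feature

import org.apache.spark.ml.util.MLTest
import org.apache.spark.sql.Row

// Illustrative only: not part of this pull request.
class TokenizerSmokeSuite extends MLTest {

  import testImplicits._

  test("whitespace tokenization (sketch)") {
    // First column feeds the transformer; "expected" holds the reference output.
    val df = Seq(("a b c", Seq("a", "b", "c"))).toDF("text", "expected")
    val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("tokens")

    // testTransformer transforms df, selects the listed result columns, and applies
    // the check function to every Row (presumably for both batch and streaming input).
    testTransformer[(String, Seq[String])](df, tokenizer, "tokens", "expected") {
      case Row(tokens: Seq[_], expected: Seq[_]) =>
        assert(tokens === expected)
    }
  }
}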