Skip to content

Commit e43803b

Browse files
committed
[SPARK-6948] [MLLIB] compress vectors in VectorAssembler
The compression is based on storage: the assembled vector is returned in whichever representation (dense or sparse) takes less space. brkyvz

Author: Xiangrui Meng <[email protected]>

Closes #5985 from mengxr/SPARK-6948 and squashes the following commits:

df56a00 [Xiangrui Meng] update python tests
6d90d45 [Xiangrui Meng] compress vectors in VectorAssembler
1 parent 658a478 commit e43803b

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,6 @@ object VectorAssembler {
102102
case o =>
103103
throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
104104
}
105-
Vectors.sparse(cur, indices.result(), values.result())
105+
Vectors.sparse(cur, indices.result(), values.result()).compressed
106106
}
107107
}

mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
2020
import org.scalatest.FunSuite
2121

2222
import org.apache.spark.SparkException
23-
import org.apache.spark.mllib.linalg.{Vector, Vectors}
23+
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
2424
import org.apache.spark.mllib.util.MLlibTestSparkContext
2525
import org.apache.spark.sql.{Row, SQLContext}
2626

@@ -48,6 +48,14 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
4848
}
4949
}
5050

51+
test("assemble should compress vectors") {
52+
import org.apache.spark.ml.feature.VectorAssembler.assemble
53+
val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
54+
assert(v1.isInstanceOf[SparseVector])
55+
val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0)))
56+
assert(v2.isInstanceOf[DenseVector])
57+
}
58+
5159
test("VectorAssembler") {
5260
val df = sqlContext.createDataFrame(Seq(
5361
(0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)

python/pyspark/ml/feature.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,12 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
121121
>>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF()
122122
>>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
123123
>>> vecAssembler.transform(df).head().features
124-
SparseVector(3, {0: 1.0, 2: 3.0})
124+
DenseVector([1.0, 0.0, 3.0])
125125
>>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
126-
SparseVector(3, {0: 1.0, 2: 3.0})
126+
DenseVector([1.0, 0.0, 3.0])
127127
>>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
128128
>>> vecAssembler.transform(df, params).head().vector
129-
SparseVector(2, {1: 1.0})
129+
DenseVector([0.0, 1.0])
130130
"""
131131

132132
_java_class = "org.apache.spark.ml.feature.VectorAssembler"

0 commit comments

Comments
 (0)