Skip to content

Commit 19aa523

Browse files
committed
update toString and add parsers for Vectors and LabeledPoint
1 parent fdae095 commit 19aa523

File tree

4 files changed

+110
-3
lines changed

4 files changed

+110
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import java.util.Arrays
2222

2323
import scala.annotation.varargs
2424
import scala.collection.JavaConverters._
25+
import scala.util.parsing.combinator.JavaTokenParsers
2526

2627
import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV}
2728

@@ -124,6 +125,8 @@ object Vectors {
124125
}.toSeq)
125126
}
126127

128+
private[mllib] def parse(s: String): Vector = VectorParsers(s)
129+
127130
/**
128131
* Creates a vector instance from a breeze vector.
129132
*/
@@ -171,8 +174,11 @@ class SparseVector(
171174
val indices: Array[Int],
172175
val values: Array[Double]) extends Vector {
173176

177+
require(indices.length == values.length)
178+
174179
override def toString: String = {
175-
"(" + size + "," + indices.zip(values).mkString("[", "," ,"]") + ")"
180+
Seq(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]"))
181+
.mkString("(", ",", ")")
176182
}
177183

178184
override def toArray: Array[Double] = {
@@ -188,3 +194,28 @@ class SparseVector(
188194

189195
private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
190196
}
197+
198+
/**
199+
* Parsers for string representation of [[org.apache.spark.mllib.linalg.Vector]].
200+
*/
201+
private[mllib] class VectorParsers extends JavaTokenParsers {
202+
lazy val indices: Parser[Array[Int]] = "[" ~ repsep(wholeNumber, ",") ~ "]" ^^ {
203+
case "[" ~ ii ~ "]" => ii.map(_.toInt).toArray
204+
}
205+
lazy val values: Parser[Array[Double]] = "[" ~ repsep(floatingPointNumber, ",") ~ "]" ^^ {
206+
case "[" ~ vv ~ "]" => vv.map(_.toDouble).toArray
207+
}
208+
lazy val denseVector: Parser[DenseVector] = values ^^ {
209+
case vv => new DenseVector(vv)
210+
}
211+
lazy val sparseVector: Parser[SparseVector] =
212+
"(" ~ wholeNumber ~ "," ~ indices ~ "," ~ values ~ ")" ^^ {
213+
case "(" ~ size ~ "," ~ ii ~ "," ~ vv ~ ")" => new SparseVector(size.toInt, ii, vv)
214+
}
215+
lazy val vector: Parser[Vector] = denseVector | sparseVector
216+
}
217+
218+
private[mllib] object VectorParsers extends VectorParsers {
219+
/** Parses a string into an [[org.apache.spark.mllib.linalg.Vector]]. */
220+
def apply(s: String): Vector = parse(vector, s).get
221+
}

mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.mllib.regression
1919

20-
import org.apache.spark.mllib.linalg.Vector
20+
import org.apache.spark.mllib.linalg.{Vector, VectorParsers}
2121

2222
/**
2323
* Class that represents the features and labels of a data point.
@@ -27,6 +27,25 @@ import org.apache.spark.mllib.linalg.Vector
2727
*/
2828
case class LabeledPoint(label: Double, features: Vector) {
2929
override def toString: String = {
30-
"LabeledPoint(%s, %s)".format(label, features)
30+
Seq(label, features).mkString("(", ",", ")")
3131
}
3232
}
33+
34+
object LabeledPoint {
35+
/** Parses a string into an [[org.apache.spark.mllib.regression.LabeledPoint]]. */
36+
def parse(s: String) = LabeledPointParsers.parse(s)
37+
}
38+
39+
/**
40+
* Parsers for string representation of [[org.apache.spark.mllib.regression.LabeledPoint]].
41+
*/
42+
private[mllib] class LabeledPointParsers extends VectorParsers {
43+
lazy val labeledPoint: Parser[LabeledPoint] = "(" ~ floatingPointNumber ~ "," ~ vector ~ ")" ^^ {
44+
case "(" ~ l ~ "," ~ v ~ ")" => LabeledPoint(l.toDouble, v)
45+
}
46+
}
47+
48+
private[mllib] object LabeledPointParsers extends LabeledPointParsers {
49+
/** Parses a string into an [[org.apache.spark.mllib.regression.LabeledPoint]]. */
50+
def parse(s: String): LabeledPoint = parse(labeledPoint, s).get
51+
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,27 @@ class VectorsSuite extends FunSuite {
100100
assert(vec2(6) === 4.0)
101101
assert(vec2(7) === 0.0)
102102
}
103+
104+
test("parse vectors") {
105+
val vectors = Seq(
106+
Vectors.dense(Array.empty[Double]),
107+
Vectors.dense(1.0),
108+
Vectors.dense(1.0, 0.0, -2.0),
109+
Vectors.sparse(0, Array.empty[Int], Array.empty[Double]),
110+
Vectors.sparse(1, Array(0), Array(1.0)),
111+
Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)))
112+
vectors.foreach { v =>
113+
val v1 = Vectors.parse(v.toString)
114+
assert(v.getClass === v1.getClass)
115+
assert(v === v1)
116+
}
117+
118+
val malformatted = Seq("1", "[1,]", "[1,2", "(1,[1,2])", "(1,[1],[2.0,1.0])")
119+
malformatted.foreach { s =>
120+
intercept[RuntimeException] {
121+
Vectors.parse(s)
122+
println(s"Didn't detect malformatted string $s.")
123+
}
124+
}
125+
}
103126
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.regression
19+
20+
import org.scalatest.FunSuite
21+
22+
import org.apache.spark.mllib.linalg.Vectors
23+
24+
class LabeledPointSuite extends FunSuite {
25+
26+
test("parse labeled points") {
27+
val points = Seq(
28+
LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
29+
LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0))))
30+
points.foreach { p =>
31+
assert(p === LabeledPoint.parse(p.toString))
32+
}
33+
}
34+
}

0 commit comments

Comments
 (0)