Skip to content

Commit e9fcd49

Browse files
committed
add serializeLabeledPoint and tests
1 parent aea4ae3 commit e9fcd49

File tree

2 files changed

+75
-3
lines changed

2 files changed

+75
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
2727
import org.apache.spark.mllib.recommendation._
2828
import org.apache.spark.mllib.regression._
2929
import org.apache.spark.rdd.RDD
30+
import org.apache.spark.SparkException
3031

3132
/**
3233
* :: DeveloperApi ::
@@ -41,7 +42,7 @@ class PythonMLLibAPI extends Serializable {
4142
private val DENSE_MATRIX_MAGIC: Byte = 3
4243
private val LABELED_POINT_MAGIC: Byte = 4
4344

44-
private def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = {
45+
private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = {
4546
require(bytes.length - offset >= 5, "Byte array too short")
4647
val magic = bytes(offset)
4748
if (magic == DENSE_VECTOR_MAGIC) {
@@ -116,7 +117,7 @@ class PythonMLLibAPI extends Serializable {
116117
bytes
117118
}
118119

119-
private def serializeDoubleVector(vector: Vector): Array[Byte] = vector match {
120+
private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match {
120121
case s: SparseVector =>
121122
serializeSparseVector(s)
122123
case _ =>
@@ -167,7 +168,18 @@ class PythonMLLibAPI extends Serializable {
167168
bytes
168169
}
169170

170-
private def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = {
171+
/**
 * Serializes a labeled point into a byte array.
 *
 * The output consists of one magic byte (`LABELED_POINT_MAGIC`), the label
 * written as an 8-byte double in the platform's native byte order, and then
 * the bytes produced by `serializeDoubleVector` for the point's features.
 *
 * @param p the labeled point to encode
 * @return a newly allocated array of length `1 + 8 + <serialized features length>`
 */
private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = {
  val featureBytes = serializeDoubleVector(p.features)
  // 1 byte for the magic marker + 8 bytes for the label + the feature payload.
  val out = new Array[Byte](1 + 8 + featureBytes.length)
  // The buffer is only a writing cursor over `out`; every put lands in the array itself.
  ByteBuffer.wrap(out)
    .order(ByteOrder.nativeOrder())
    .put(LABELED_POINT_MAGIC)
    .putDouble(p.label)
    .put(featureBytes)
  out
}
181+
182+
private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = {
171183
require(bytes.length >= 9, "Byte array too short")
172184
val magic = bytes(0)
173185
if (magic != LABELED_POINT_MAGIC) {
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.api.python
19+
20+
import org.scalatest.FunSuite
21+
22+
import org.apache.spark.mllib.linalg.Vectors
23+
import org.apache.spark.mllib.regression.LabeledPoint
24+
25+
/**
 * Round-trip tests for the byte-level serializers of [[PythonMLLibAPI]]:
 * each value is serialized, deserialized again, and compared with the original.
 */
class PythonMLLibAPISuite extends FunSuite {
  val py = new PythonMLLibAPI

  test("vector serialization") {
    // Dense and sparse inputs, including zero-length and all-zero cases.
    val samples = Seq(
      Vectors.dense(Array.empty[Double]),
      Vectors.dense(0.0),
      Vectors.dense(0.0, -2.0),
      Vectors.sparse(0, Array.empty[Int], Array.empty[Double]),
      Vectors.sparse(1, Array.empty[Int], Array.empty[Double]),
      Vectors.sparse(2, Array(1), Array(-2.0)))
    for (original <- samples) {
      val restored = py.deserializeDoubleVector(py.serializeDoubleVector(original))
      // The concrete vector class must be preserved, not only the values.
      assert(restored.getClass === original.getClass)
      assert(restored === original)
    }
  }

  test("labeled point serialization") {
    // Same feature shapes as the vector test, each paired with a label.
    val samples = Seq(
      LabeledPoint(0.0, Vectors.dense(Array.empty[Double])),
      LabeledPoint(1.0, Vectors.dense(0.0)),
      LabeledPoint(-0.5, Vectors.dense(0.0, -2.0)),
      LabeledPoint(0.0, Vectors.sparse(0, Array.empty[Int], Array.empty[Double])),
      LabeledPoint(1.0, Vectors.sparse(1, Array.empty[Int], Array.empty[Double])),
      LabeledPoint(-0.5, Vectors.sparse(2, Array(1), Array(-2.0))))
    for (original <- samples) {
      val restored = py.deserializeLabeledPoint(py.serializeLabeledPoint(original))
      assert(restored.label === original.label)
      // The concrete feature-vector class must be preserved, not only the values.
      assert(restored.features.getClass === original.features.getClass)
      assert(restored.features === original.features)
    }
  }
}

0 commit comments

Comments
 (0)