Skip to content

Commit 1680905

Browse files
committed
Added JavaLabeledPointSuite.java for spark.ml, and added constructor to LabeledPoint which defaults weight to 1.0
1 parent adbe50a commit 1680905

File tree

3 files changed

+68
-2
lines changed

3 files changed

+68
-2
lines changed

mllib/src/main/scala/org/apache/spark/ml/LabeledPoint.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ import org.apache.spark.mllib.linalg.Vector
3434
@BeanInfo
3535
case class LabeledPoint(label: Double, features: Vector, weight: Double) {
3636

37+
/** Constructor which sets instance weight to 1.0 */
38+
def this(label: Double, features: Vector) = this(label, features, 1.0)
39+
3740
override def toString: String = {
3841
"(%s,%s,%s)".format(label, features, weight)
3942
}
mllib/src/test/java/org/apache/spark/ml/JavaLabeledPointSuite.java

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package org.apache.spark.ml;
2+
3+
import java.util.List;
4+
5+
import org.junit.After;
6+
import org.junit.Before;
7+
import org.junit.Test;
8+
9+
import com.google.common.collect.Lists;
10+
11+
import org.apache.spark.api.java.JavaRDD;
12+
import org.apache.spark.api.java.JavaSparkContext;
13+
import org.apache.spark.mllib.linalg.Vector;
14+
import org.apache.spark.mllib.linalg.Vectors;
15+
import org.apache.spark.sql.api.java.JavaSQLContext;
16+
import org.apache.spark.sql.api.java.JavaSchemaRDD;
17+
import org.apache.spark.sql.api.java.Row;
18+
19+
/**
20+
* Test {@link LabeledPoint} in Java
21+
*/
22+
public class JavaLabeledPointSuite {
23+
24+
private transient JavaSparkContext jsc;
25+
private transient JavaSQLContext jsql;
26+
27+
@Before
28+
public void setUp() {
29+
jsc = new JavaSparkContext("local", "JavaLabeledPointSuite");
30+
jsql = new JavaSQLContext(jsc);
31+
}
32+
33+
@After
34+
public void tearDown() {
35+
jsc.stop();
36+
jsc = null;
37+
}
38+
39+
@Test
40+
public void labeledPointDefaultWeight() {
41+
double label = 1.0;
42+
Vector features = Vectors.dense(1.0, 2.0, 3.0);
43+
LabeledPoint lp1 = new LabeledPoint(label, features);
44+
LabeledPoint lp2 = new LabeledPoint(label, features, 1.0);
45+
assert(lp1.equals(lp2));
46+
}
47+
48+
@Test
49+
public void labeledPointSchemaRDD() {
50+
List<LabeledPoint> arr = Lists.newArrayList(
51+
new LabeledPoint(0.0, Vectors.dense(1.0, 2.0, 3.0)),
52+
new LabeledPoint(1.0, Vectors.dense(1.1, 2.1, 3.1)),
53+
new LabeledPoint(0.0, Vectors.dense(1.2, 2.2, 3.2)),
54+
new LabeledPoint(1.0, Vectors.dense(1.3, 2.3, 3.3)));
55+
JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
56+
JavaSchemaRDD schemaRDD = jsql.applySchema(rdd, LabeledPoint.class);
57+
schemaRDD.registerTempTable("points");
58+
List<Row> points = jsql.sql("SELECT label, features FROM points").collect();
59+
assert (points.size() == arr.size());
60+
}
61+
}

mllib/src/test/scala/org/apache/spark/ml/LabeledPointSuite.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@ package org.apache.spark.ml
1919

2020
import org.scalatest.FunSuite
2121

22-
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
2322
import org.apache.spark.mllib.linalg.Vectors
2423
import org.apache.spark.mllib.util.MLlibTestSparkContext
25-
import org.apache.spark.sql.{SQLContext, SchemaRDD}
24+
import org.apache.spark.sql.SQLContext
2625

26+
/**
27+
* Test [[LabeledPoint]]
28+
*/
2729
class LabeledPointSuite extends FunSuite with MLlibTestSparkContext {
2830

2931
@transient var sqlContext: SQLContext = _

0 commit comments

Comments
 (0)