Skip to content

Commit bc654e1

Browse files
committed
Added spark.ml LinearRegressionSuite
1 parent 8d13233 commit bc654e1

File tree

2 files changed

+131
-4
lines changed

2 files changed

+131
-4
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.classification;
19+
20+
import scala.Tuple2;
21+
22+
import java.io.Serializable;
23+
import java.util.ArrayList;
24+
import java.util.List;
25+
26+
import org.junit.After;
27+
import org.junit.Before;
28+
import org.junit.Test;
29+
30+
import org.apache.spark.api.java.JavaRDD;
31+
import org.apache.spark.api.java.JavaSparkContext;
32+
import org.apache.spark.api.java.function.Function;
33+
import org.apache.spark.ml.LabeledPoint;
34+
import org.apache.spark.ml.regression.LinearRegression;
35+
import org.apache.spark.ml.regression.LinearRegressionModel;
36+
import static org.apache.spark.mllib.classification.LogisticRegressionSuite
37+
.generateLogisticInputAsList;
38+
import org.apache.spark.mllib.linalg.Vector;
39+
import org.apache.spark.sql.api.java.JavaSQLContext;
40+
import org.apache.spark.sql.api.java.JavaSchemaRDD;
41+
import org.apache.spark.sql.api.java.Row;
42+
43+
44+
public class JavaLinearRegressionSuite implements Serializable {
45+
46+
private transient JavaSparkContext jsc;
47+
private transient JavaSQLContext jsql;
48+
private transient JavaSchemaRDD dataset;
49+
private transient JavaRDD<LabeledPoint> datasetRDD;
50+
private transient JavaRDD<Vector> featuresRDD;
51+
private double eps = 1e-5;
52+
53+
@Before
54+
public void setUp() {
55+
jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
56+
jsql = new JavaSQLContext(jsc);
57+
List<LabeledPoint> points = new ArrayList<LabeledPoint>();
58+
for (org.apache.spark.mllib.regression.LabeledPoint lp:
59+
generateLogisticInputAsList(1.0, 1.0, 100, 42)) {
60+
points.add(new LabeledPoint(lp.label(), lp.features()));
61+
}
62+
datasetRDD = jsc.parallelize(points, 2);
63+
featuresRDD = datasetRDD.map(new Function<LabeledPoint, Vector>() {
64+
@Override public Vector call(LabeledPoint lp) { return lp.features(); }
65+
});
66+
dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
67+
dataset.registerTempTable("dataset");
68+
}
69+
70+
@After
71+
public void tearDown() {
72+
jsc.stop();
73+
jsc = null;
74+
}
75+
76+
@Test
77+
public void linearRegressionDefaultParams() {
78+
LinearRegression lr = new LinearRegression();
79+
assert(lr.getLabelCol().equals("label"));
80+
LinearRegressionModel model = lr.fit(dataset);
81+
model.transform(dataset).registerTempTable("prediction");
82+
JavaSchemaRDD predictions = jsql.sql("SELECT label, prediction FROM prediction");
83+
predictions.collect();
84+
// Check defaults
85+
assert(model.getFeaturesCol().equals("features"));
86+
assert(model.getPredictionCol().equals("prediction"));
87+
}
88+
89+
@Test
90+
public void linearRegressionWithSetters() {
91+
// Set params, train, and check as many params as we can.
92+
LinearRegression lr = new LinearRegression()
93+
.setMaxIter(10)
94+
.setRegParam(1.0);
95+
LinearRegressionModel model = lr.fit(dataset);
96+
assert(model.fittingParamMap().get(lr.maxIter()).get() == 10);
97+
assert(model.fittingParamMap().get(lr.regParam()).get() == 1.0);
98+
99+
// Call fit() with new params, and check as many params as we can.
100+
LinearRegressionModel model2 =
101+
lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
102+
assert(model2.fittingParamMap().get(lr.maxIter()).get() == 5);
103+
assert(model2.fittingParamMap().get(lr.regParam()).get() == 0.1);
104+
assert(model2.getPredictionCol().equals("thePred"));
105+
}
106+
107+
@Test
108+
public void linearRegressionPredictorClassifierMethods() {
109+
LinearRegression lr = new LinearRegression();
110+
111+
// fit() vs. train()
112+
LinearRegressionModel model1 = lr.fit(dataset);
113+
LinearRegressionModel model2 = lr.train(datasetRDD);
114+
assert(model1.intercept() == model2.intercept());
115+
assert(model1.weights().equals(model2.weights()));
116+
117+
// transform() vs. predict()
118+
model1.transform(dataset).registerTempTable("transformed");
119+
JavaSchemaRDD trans = jsql.sql("SELECT prediction FROM transformed");
120+
JavaRDD<Double> preds = model1.predict(featuresRDD);
121+
for (Tuple2<Row, Double> trans_pred: trans.zip(preds).collect()) {
122+
double t = trans_pred._1().getDouble(0);
123+
double p = trans_pred._2();
124+
assert(t == p);
125+
}
126+
}
127+
}

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
6060
.setMaxIter(10)
6161
.setRegParam(1.0)
6262
val model = lr.fit(dataset)
63-
assert(model.fittingParamMap.get(lr.maxIter) === Some(10))
64-
assert(model.fittingParamMap.get(lr.regParam) === Some(1.0))
63+
assert(model.fittingParamMap.get(lr.maxIter).get === 10)
64+
assert(model.fittingParamMap.get(lr.regParam).get === 1.0)
6565

6666
// Call fit() with new params, and check as many as we can.
6767
val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred")
68-
assert(model2.fittingParamMap.get(lr.maxIter) === Some(5))
69-
assert(model2.fittingParamMap.get(lr.regParam) === Some(0.1))
68+
assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
69+
assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
7070
assert(model2.getPredictionCol == "thePred")
7171
}
7272

0 commit comments

Comments
 (0)