Skip to content

Commit adbe50a

Browse files
committed
* fixed LinearRegression train() to use embedded paramMap
* added Predictor.predict(RDD[Vector]) method * updated Linear/LogisticRegressionSuites
1 parent 58802e3 commit adbe50a

File tree

4 files changed

+147
-3
lines changed

4 files changed

+147
-3
lines changed

mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ private[ml] abstract class PredictionModel[M <: PredictionModel[M]]
137137
}
138138

139139
/**
140+
* Strongly typed version of [[transform()]].
140141
* Default implementation using single-instance predict().
141142
*
142143
* Developers should override this for efficiency. E.g., this does not broadcast the model.
@@ -147,6 +148,9 @@ private[ml] abstract class PredictionModel[M <: PredictionModel[M]]
147148
dataset.map(tmpModel.predict)
148149
}
149150

151+
/** Strongly typed version of [[transform()]]. */
152+
def predict(dataset: RDD[Vector]): RDD[Double] = predict(dataset, new ParamMap)
153+
150154
/**
151155
* Predict label for the given features.
152156
*/

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class LinearRegression extends Regressor[LinearRegression, LinearRegressionModel
5353
* These values override any specified in this Estimator's embedded ParamMap.
5454
*/
5555
override def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): LinearRegressionModel = {
56+
val map = this.paramMap ++ paramMap
5657
val oldDataset = dataset.map { case LabeledPoint(label: Double, features: Vector, weight) =>
5758
org.apache.spark.mllib.regression.LabeledPoint(label, features)
5859
}
@@ -62,10 +63,10 @@ class LinearRegression extends Regressor[LinearRegression, LinearRegressionModel
6263
}
6364
val lr = new LinearRegressionWithSGD()
6465
lr.optimizer
65-
.setRegParam(paramMap(regParam))
66-
.setNumIterations(paramMap(maxIter))
66+
.setRegParam(map(regParam))
67+
.setNumIterations(map(maxIter))
6768
val model = lr.run(oldDataset)
68-
val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept)
69+
val lrm = new LinearRegressionModel(this, map, model.weights, model.intercept)
6970
if (handlePersistence) {
7071
oldDataset.unpersist()
7172
}

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,19 @@ package org.apache.spark.ml.classification
1919

2020
import org.scalatest.FunSuite
2121

22+
import org.apache.spark.ml.LabeledPoint
2223
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
24+
import org.apache.spark.mllib.linalg.Vector
2325
import org.apache.spark.mllib.util.MLlibTestSparkContext
26+
import org.apache.spark.mllib.util.TestingUtils._
2427
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
2528

2629

2730
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
2831

2932
@transient var sqlContext: SQLContext = _
3033
@transient var dataset: DataFrame = _
34+
private val eps: Double = 1e-5
3135

3236
override def beforeAll(): Unit = {
3337
super.beforeAll()
@@ -38,6 +42,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
3842

3943
test("logistic regression: default params") {
4044
val lr = new LogisticRegression
45+
assert(lr.getLabelCol == "label")
4146
val model = lr.fit(dataset)
4247
model.transform(dataset)
4348
.select("label", "prediction")
@@ -96,4 +101,43 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
96101
assert(model2.getThreshold === 0.4)
97102
assert(model2.getScoreCol == "theProb")
98103
}
104+
105+
test("logistic regression: Predictor, Classifier methods") {
106+
val sqlContext = this.sqlContext
107+
import sqlContext._
108+
val lr = new LogisticRegression
109+
110+
// fit() vs. train()
111+
val model1 = lr.fit(dataset)
112+
val rdd = dataset.select('label, 'features).map { case Row(label: Double, features: Vector) =>
113+
LabeledPoint(label, features)
114+
}
115+
val features = rdd.map(_.features)
116+
val model2 = lr.train(rdd)
117+
assert(model1.intercept == model2.intercept)
118+
assert(model1.weights.equals(model2.weights))
119+
assert(model1.numClasses == model2.numClasses)
120+
assert(model1.numClasses === 2)
121+
122+
// transform() vs. predict()
123+
val trans = model1.transform(dataset).select('prediction)
124+
val preds = model1.predict(rdd.map(_.features))
125+
trans.zip(preds).collect().foreach { case (Row(pred1: Double), pred2: Double) =>
126+
assert(pred1 == pred2)
127+
}
128+
129+
// Check various types of predictions.
130+
val allPredictions = features.map { f =>
131+
(model1.predictRaw(f), model1.predictProbabilities(f), model1.predict(f))
132+
}.collect()
133+
val threshold = model1.getThreshold
134+
allPredictions.foreach { case (raw: Vector, prob: Vector, pred: Double) =>
135+
val computeProbFromRaw: (Double => Double) = (m) => 1.0 / (1.0 + math.exp(-m))
136+
raw.toArray.map(computeProbFromRaw).zip(prob.toArray).foreach { case (r, p) =>
137+
assert(r ~== p relTol eps)
138+
}
139+
val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
140+
assert(pred == predFromProb)
141+
}
142+
}
99143
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.regression
19+
20+
import org.scalatest.FunSuite
21+
22+
import org.apache.spark.ml.LabeledPoint
23+
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
24+
import org.apache.spark.mllib.linalg.Vector
25+
import org.apache.spark.mllib.util.MLlibTestSparkContext
26+
import org.apache.spark.mllib.util.TestingUtils._
27+
import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}
28+
29+
class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
30+
31+
@transient var sqlContext: SQLContext = _
32+
@transient var dataset: SchemaRDD = _
33+
34+
override def beforeAll(): Unit = {
35+
super.beforeAll()
36+
sqlContext = new SQLContext(sc)
37+
dataset = sqlContext.createSchemaRDD(
38+
sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2))
39+
}
40+
41+
test("linear regression: default params") {
42+
val sqlContext = this.sqlContext
43+
import sqlContext._
44+
val lr = new LinearRegression
45+
assert(lr.getLabelCol == "label")
46+
val model = lr.fit(dataset)
47+
model.transform(dataset)
48+
.select('label, 'prediction)
49+
.collect()
50+
// Check defaults
51+
assert(model.getFeaturesCol == "features")
52+
assert(model.getPredictionCol == "prediction")
53+
}
54+
55+
test("linear regression with setters") {
56+
// Set params, train, and check as many as we can.
57+
val sqlContext = this.sqlContext
58+
import sqlContext._
59+
val lr = new LinearRegression()
60+
.setMaxIter(10)
61+
.setRegParam(1.0)
62+
val model = lr.fit(dataset)
63+
assert(model.fittingParamMap.get(lr.maxIter) === Some(10))
64+
assert(model.fittingParamMap.get(lr.regParam) === Some(1.0))
65+
66+
// Call fit() with new params, and check as many as we can.
67+
val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred")
68+
assert(model2.fittingParamMap.get(lr.maxIter) === Some(5))
69+
assert(model2.fittingParamMap.get(lr.regParam) === Some(0.1))
70+
assert(model2.getPredictionCol == "thePred")
71+
}
72+
73+
test("linear regression: Predictor, Regressor methods") {
74+
val sqlContext = this.sqlContext
75+
import sqlContext._
76+
val lr = new LinearRegression
77+
78+
// fit() vs. train()
79+
val model1 = lr.fit(dataset)
80+
val rdd = dataset.select('label, 'features).map { case Row(label: Double, features: Vector) =>
81+
LabeledPoint(label, features)
82+
}
83+
val features = rdd.map(_.features)
84+
val model2 = lr.train(rdd)
85+
assert(model1.intercept == model2.intercept)
86+
assert(model1.weights.equals(model2.weights))
87+
88+
// transform() vs. predict()
89+
val trans = model1.transform(dataset).select('prediction)
90+
val preds = model1.predict(rdd.map(_.features))
91+
trans.zip(preds).collect().foreach { case (Row(pred1: Double), pred2: Double) =>
92+
assert(pred1 == pred2)
93+
}
94+
}
95+
}

0 commit comments

Comments
 (0)