Skip to content

Commit d705e87

Browse files
committed
Added LinearRegression and Regressor back from ml-api branch
1 parent 52f4fde commit d705e87

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package org.apache.spark.ml.regression
2+
3+
import org.apache.spark.annotation.AlphaComponent
4+
import org.apache.spark.ml.LabeledPoint
5+
import org.apache.spark.ml.param.{ParamMap, HasMaxIter, HasRegParam}
6+
import org.apache.spark.mllib.linalg.{BLAS, Vector}
7+
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
8+
import org.apache.spark.rdd.RDD
9+
import org.apache.spark.storage.StorageLevel
10+
11+
/**
 * :: AlphaComponent ::
 * Params for linear regression.
 *
 * Mixes the shared regularization parameter ([[HasRegParam]]) and maximum
 * iteration count ([[HasMaxIter]]) into the common regressor params.
 */
@AlphaComponent
private[regression] trait LinearRegressionParams extends RegressorParams
  with HasRegParam with HasMaxIter
18+
19+
20+
/**
 * Linear regression.
 *
 * Fits a [[LinearRegressionModel]] by delegating to the mllib
 * [[LinearRegressionWithSGD]] implementation, driven by this estimator's
 * `regParam` (default 0.1) and `maxIter` (default 100) params.
 */
class LinearRegression extends Regressor[LinearRegression, LinearRegressionModel]
  with LinearRegressionParams {

  // TODO: Extend IterativeEstimator

  setRegParam(0.1)
  setMaxIter(100)

  /** Sets the regularization parameter. Default: 0.1. */
  def setRegParam(value: Double): this.type = set(regParam, value)

  /** Sets the maximum number of optimizer iterations. Default: 100. */
  def setMaxIter(value: Int): this.type = set(maxIter, value)

  /**
   * Trains a linear regression model on the given dataset.
   *
   * @param dataset training instances. NOTE: per-instance weights are currently
   *                discarded during the conversion below.
   * @param paramMap parameter overrides applied on top of this estimator's params
   * @return the fitted [[LinearRegressionModel]]
   */
  def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): LinearRegressionModel = {
    // Convert ml.LabeledPoint to the mllib format; the weight field is dropped.
    val oldDataset = dataset.map { case LabeledPoint(label: Double, features: Vector, weight) =>
      org.apache.spark.mllib.regression.LabeledPoint(label, features)
    }
    // SGD makes multiple passes over the data, so cache it here unless the
    // caller has already persisted it (in which case leave their choice alone).
    val handlePersistence = oldDataset.getStorageLevel == StorageLevel.NONE
    if (handlePersistence) {
      oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
    }
    val lr = new LinearRegressionWithSGD()
    lr.optimizer
      .setRegParam(paramMap(regParam))
      .setNumIterations(paramMap(maxIter))
    val model = lr.run(oldDataset)
    val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept)
    // Only unpersist what we persisted ourselves.
    if (handlePersistence) {
      oldDataset.unpersist()
    }
    lrm
  }
}
54+
55+
56+
/**
 * :: AlphaComponent ::
 * Model produced by [[LinearRegression]].
 *
 * @param parent the estimator that produced this model
 * @param fittingParamMap the param map that was used during fitting
 * @param weights coefficient vector, one weight per feature
 * @param intercept bias term added to the weighted feature sum
 */
@AlphaComponent
class LinearRegressionModel private[ml] (
    override val parent: LinearRegression,
    override val fittingParamMap: ParamMap,
    val weights: Vector,
    val intercept: Double)
  extends RegressionModel[LinearRegressionModel]
  with LinearRegressionParams {

  /** Predicted label: dot(weights, features) + intercept. */
  override def predict(features: Vector): Double = {
    BLAS.dot(features, weights) + intercept
  }
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package org.apache.spark.ml.regression
2+
3+
import org.apache.spark.annotation.AlphaComponent
4+
import org.apache.spark.ml.Evaluator
5+
import org.apache.spark.ml.evaluation.RegressionEvaluator
6+
import org.apache.spark.ml.impl.estimator.{PredictionModel, HasDefaultEvaluator, Predictor,
7+
PredictorParams}
8+
import org.apache.spark.mllib.linalg.Vector
9+
10+
/**
 * :: AlphaComponent ::
 * Params shared by regressors. Currently adds nothing beyond [[PredictorParams]];
 * it exists as a common marker/extension point for regression params.
 */
@AlphaComponent
private[regression] trait RegressorParams extends PredictorParams
12+
13+
/**
 * Single-label regression.
 *
 * @tparam Learner concrete estimator type (F-bounded so setters can return the
 *                 specific subtype)
 * @tparam M concrete model type produced by this estimator
 */
abstract class Regressor[Learner <: Regressor[Learner, M], M <: RegressionModel[M]]
  extends Predictor[Learner, M]
  with RegressorParams
  with HasDefaultEvaluator {

  /** Evaluator used by default for regression models. */
  override def defaultEvaluator: Evaluator = new RegressionEvaluator
}
23+
24+
25+
/**
 * Model produced by a [[Regressor]].
 *
 * @tparam M concrete model type (F-bounded so inherited APIs can return M)
 */
private[ml] abstract class RegressionModel[M <: RegressionModel[M]]
  extends PredictionModel[M] with RegressorParams {

  /**
   * Predict real-valued label for the given features.
   *
   * @param features feature vector of a single instance
   * @return the predicted label
   */
  def predict(features: Vector): Double
}

0 commit comments

Comments
 (0)