|
| 1 | +package org.apache.spark.ml.regression |
| 2 | + |
| 3 | +import org.apache.spark.annotation.AlphaComponent |
| 4 | +import org.apache.spark.ml.LabeledPoint |
| 5 | +import org.apache.spark.ml.param.{ParamMap, HasMaxIter, HasRegParam} |
| 6 | +import org.apache.spark.mllib.linalg.{BLAS, Vector} |
| 7 | +import org.apache.spark.mllib.regression.LinearRegressionWithSGD |
| 8 | +import org.apache.spark.rdd.RDD |
| 9 | +import org.apache.spark.storage.StorageLevel |
| 10 | + |
| 11 | +/** |
| 12 | + * :: AlphaComponent :: |
| 13 | + * Params for linear regression. |
| 14 | + */ |
| 15 | +@AlphaComponent |
| 16 | +private[regression] trait LinearRegressionParams extends RegressorParams |
| 17 | + with HasRegParam with HasMaxIter |
| 18 | + |
| 19 | + |
| 20 | +/** |
| 21 | + * Logistic regression. |
| 22 | + */ |
| 23 | +class LinearRegression extends Regressor[LinearRegression, LinearRegressionModel] |
| 24 | + with LinearRegressionParams { |
| 25 | + |
| 26 | + // TODO: Extend IterativeEstimator |
| 27 | + |
| 28 | + setRegParam(0.1) |
| 29 | + setMaxIter(100) |
| 30 | + |
| 31 | + def setRegParam(value: Double): this.type = set(regParam, value) |
| 32 | + def setMaxIter(value: Int): this.type = set(maxIter, value) |
| 33 | + |
| 34 | + def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): LinearRegressionModel = { |
| 35 | + val oldDataset = dataset.map { case LabeledPoint(label: Double, features: Vector, weight) => |
| 36 | + org.apache.spark.mllib.regression.LabeledPoint(label, features) |
| 37 | + } |
| 38 | + val handlePersistence = oldDataset.getStorageLevel == StorageLevel.NONE |
| 39 | + if (handlePersistence) { |
| 40 | + oldDataset.persist(StorageLevel.MEMORY_AND_DISK) |
| 41 | + } |
| 42 | + val lr = new LinearRegressionWithSGD() |
| 43 | + lr.optimizer |
| 44 | + .setRegParam(paramMap(regParam)) |
| 45 | + .setNumIterations(paramMap(maxIter)) |
| 46 | + val model = lr.run(oldDataset) |
| 47 | + val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept) |
| 48 | + if (handlePersistence) { |
| 49 | + oldDataset.unpersist() |
| 50 | + } |
| 51 | + lrm |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | + |
| 56 | +/** |
| 57 | + * :: AlphaComponent :: |
| 58 | + * Model produced by [[LinearRegression]]. |
| 59 | + */ |
| 60 | +@AlphaComponent |
| 61 | +class LinearRegressionModel private[ml] ( |
| 62 | + override val parent: LinearRegression, |
| 63 | + override val fittingParamMap: ParamMap, |
| 64 | + val weights: Vector, |
| 65 | + val intercept: Double) |
| 66 | + extends RegressionModel[LinearRegressionModel] |
| 67 | + with LinearRegressionParams { |
| 68 | + |
| 69 | + override def predict(features: Vector): Double = { |
| 70 | + BLAS.dot(features, weights) + intercept |
| 71 | + } |
| 72 | +} |
0 commit comments