Commit bb2665a

merged with elastic net pr
1 parent ecda302 commit bb2665a

File tree

3 files changed, +29 -19 lines changed

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 3 additions & 1 deletion
```diff
@@ -54,7 +54,9 @@ private[shared] object SharedParamsCodeGen {
         isValid = "ParamValidate.gtEq[Int](1)"),
       ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
       ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")),
-      ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"),
+      ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." +
+        " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.",
+        isValid = "ParamValidate.inRange[Double](0, 1)"),
       ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"),
       ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."))
```
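The diff references a `ParamValidate.inRange` factory for the [0, 1] check, but its implementation is not part of this commit. A minimal sketch of what such a range validator could look like (`ParamValidateSketch` is a hypothetical stand-in, not the actual Spark source):

```scala
// Hypothetical sketch of a range-validation factory in the spirit of the
// ParamValidate.inRange referenced above; not the actual Spark implementation.
object ParamValidateSketch {
  // Returns a predicate accepting values in the closed interval [lo, hi].
  def inRange[T](lo: T, hi: T)(implicit ord: Ordering[T]): T => Boolean =
    v => ord.gteq(v, lo) && ord.lteq(v, hi)
}

// The generated elasticNetParam would then reject values outside [0, 1]:
val isValid: Double => Boolean = ParamValidateSketch.inRange[Double](0.0, 1.0)
assert(isValid(0.5))   // a 50/50 L1/L2 mix is accepted
assert(!isValid(1.5))  // out of range
```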

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 8 additions & 12 deletions
```diff
@@ -250,37 +250,33 @@ private[ml] trait HasSeed extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param elasticNetParam.
+ * (private[ml]) Trait for shared param elasticNetParam.
  */
-@DeveloperApi
-trait HasElasticNetParam extends Params {
+private[ml] trait HasElasticNetParam extends Params {
 
   /**
-   * Param for the ElasticNet mixing parameter.
+   * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
    * @group param
    */
-  final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter")
+  final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidate.inRange[Double](0, 1))
 
   /** @group getParam */
   final def getElasticNetParam: Double = getOrDefault(elasticNetParam)
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param tol.
+ * (private[ml]) Trait for shared param convergenceTol.
  */
-@DeveloperApi
-trait HasTol extends Params {
+private[ml] trait HasConvergenceTol extends Params {
 
   /**
   * Param for the convergence tolerance for iterative algorithms.
   * @group param
   */
-  final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms")
+  final val convergenceTol: DoubleParam = new DoubleParam(this, "convergenceTol", "the convergence tolerance for iterative algorithms")
 
   /** @group getParam */
-  final def getTol: Double = getOrDefault(tol)
+  final def getConvergenceTol: Double = getOrDefault(convergenceTol)
 }
 
 /**
```
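For illustration, here is a dependency-free sketch of the shared-param pattern this file encodes; `SimpleParam` and `HasElasticNetParamSketch` are hypothetical stand-ins for Spark's `DoubleParam`/`Params` machinery, showing how the validator added above would gate a setter:

```scala
// Hypothetical, self-contained sketch of the generated shared-param pattern;
// SimpleParam stands in for Spark's DoubleParam.
final case class SimpleParam(name: String, doc: String, isValid: Double => Boolean)

trait HasElasticNetParamSketch {
  private var value: Option[Double] = None

  final val elasticNetParam = SimpleParam(
    "elasticNetParam",
    "the ElasticNet mixing parameter, in range [0, 1]",
    v => v >= 0.0 && v <= 1.0)

  // The setter enforces the validator, as Spark's param machinery does via isValid.
  final def setElasticNetParam(v: Double): this.type = {
    require(elasticNetParam.isValid(v), s"${elasticNetParam.name} out of range: $v")
    value = Some(v)
    this
  }

  final def getElasticNetParam: Double = value.getOrElse(0.0)
}
```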

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 18 additions & 6 deletions
```diff
@@ -25,7 +25,8 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction}
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param.{Params, ParamMap}
-import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
+import org.apache.spark.ml.param.shared.{HasConvergenceTol, HasElasticNetParam, HasMaxIter,
+  HasRegParam}
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS._
@@ -40,12 +41,22 @@ import org.apache.spark.Logging
  * Params for linear regression.
  */
 private[regression] trait LinearRegressionParams extends RegressorParams
-  with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol // TODO: elasticnetparam, tol
+  with HasRegParam with HasElasticNetParam with HasMaxIter with HasConvergenceTol
 
 /**
  * :: AlphaComponent ::
  *
  * Linear regression.
+ *
+ * The learning objective is to minimize the squared error, with regularization.
+ * The specific squared error loss function used is:
+ *   L = 1/2n ||A weights - y||^2^
+ *
+ * This supports multiple types of regularization:
+ *  - none (a.k.a. ordinary least squares)
+ *  - L2 (ridge regression)
+ *  - L1 (Lasso)
+ *  - L2 + L1 (elastic net)
  */
 @AlphaComponent
 class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
```
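Putting the doc comment together with the penalties it lists, the full regularized objective is the standard elastic net formulation, with λ = `regParam` and α = `elasticNetParam` (the exact scaling conventions here are an editorial assumption; the diff only states the squared loss):

```latex
\min_{w}\; \frac{1}{2n}\,\lVert A w - y \rVert_2^2
  + \lambda \left( \alpha \lVert w \rVert_1
  + \frac{1-\alpha}{2}\,\lVert w \rVert_2^2 \right)
```

So α = 0 reduces to ridge regression, α = 1 to the Lasso, and λ = 0 to ordinary least squares.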
```diff
@@ -83,8 +94,8 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegres
    * Default is 1E-6.
    * @group setParam
    */
-  def setTol(value: Double): this.type = set(tol, value)
-  setDefault(tol -> 1E-6)
+  def setTol(value: Double): this.type = set(convergenceTol, value)
+  setDefault(convergenceTol -> 1E-6)
 
   override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
     // Extract columns from data. If dataset is persisted, do not persist instances.
```
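A hedged sketch of configuring the estimator with the params this commit touches; it assumes spark.ml's usual per-param setters (e.g. `setElasticNetParam`) exist alongside the `setTol` shown above, which now backs the renamed `convergenceTol` param:

```scala
// Usage sketch: assumes the standard spark.ml setter pattern for each
// mixed-in shared param; only setTol is confirmed by this diff.
val lr = new LinearRegression()
  .setMaxIter(100)          // HasMaxIter
  .setRegParam(0.1)         // overall regularization strength (lambda)
  .setElasticNetParam(0.5)  // alpha in [0, 1]: 0.5 mixes L1 and L2 equally
  .setTol(1e-6)             // convergence tolerance, stored in convergenceTol
```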
```diff
@@ -133,9 +144,10 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegres
       featuresStd, featuresMean, effectiveL2RegParam)
 
     val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
-      new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol))
+      new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(convergenceTol))
     } else {
-      new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol))
+      new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam,
+        paramMap(convergenceTol))
     }
 
     val initialWeights = Vectors.zeros(numFeatures)
```
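The branch above picks the optimizer by whether an L1 term is present: plain L-BFGS assumes a differentiable objective, which holds with no regularization or pure L2, while OWL-QN is the L-BFGS variant built for the non-smooth L1 penalty. A minimal restatement of that decision (helper name hypothetical; `regParam` stands in for the `effectiveRegParam` computed in `train`):

```scala
// Hypothetical restatement of the optimizer selection above: the effective
// L1 strength is elasticNetParam * regParam, so either factor being zero
// leaves a smooth objective that plain L-BFGS can handle.
def usesOWLQN(elasticNetParam: Double, regParam: Double): Boolean =
  elasticNetParam * regParam != 0.0  // any L1 component requires OWL-QN
```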
