Skip to content

Commit 9c2bf47

Browse files
committed
destroy featuresstd and featuresmean
1 parent 0d99795 commit 9c2bf47

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -279,6 +279,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
279279
val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean)
280280
val featuresMean = featuresSummarizer.mean.toArray
281281
val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
282+
val bcFeaturesMean = instances.context.broadcast(featuresMean)
283+
val bcFeaturesStd = instances.context.broadcast(featuresStd)
282284

283285
if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
284286
featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
@@ -294,7 +296,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
294296
val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
295297

296298
val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
297-
$(standardization), featuresStd, featuresMean, effectiveL2RegParam)
299+
$(standardization), bcFeaturesStd, bcFeaturesMean, effectiveL2RegParam)
298300

299301
val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
300302
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -339,6 +341,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
339341
throw new SparkException(msg)
340342
}
341343

344+
bcFeaturesMean.destroy(blocking = false)
345+
bcFeaturesStd.destroy(blocking = false)
346+
342347
/*
343348
The coefficients are trained in the scaled space; we're converting them back to
344349
the original space.
@@ -1009,16 +1014,14 @@ private class LeastSquaresCostFun(
10091014
labelMean: Double,
10101015
fitIntercept: Boolean,
10111016
standardization: Boolean,
1012-
featuresStd: Array[Double],
1013-
featuresMean: Array[Double],
1017+
bcFeaturesStd: Broadcast[Array[Double]],
1018+
bcFeaturesMean: Broadcast[Array[Double]],
10141019
effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
10151020

1016-
val bcFeaturesStd = instances.context.broadcast(featuresStd)
1017-
val bcFeaturesMean = instances.context.broadcast(featuresMean)
1018-
10191021
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
10201022
val coeffs = Vectors.fromBreeze(coefficients)
10211023
val bcCoeffs = instances.context.broadcast(coeffs)
1024+
val localFeaturesStd = bcFeaturesStd.value
10221025

10231026
val leastSquaresAggregator = {
10241027
val seqOp = (c: LeastSquaresAggregator, instance: Instance) => c.add(instance)
@@ -1044,13 +1047,13 @@ private class LeastSquaresCostFun(
10441047
totalGradientArray(index) += effectiveL2regParam * value
10451048
value * value
10461049
} else {
1047-
if (featuresStd(index) != 0.0) {
1050+
if (localFeaturesStd(index) != 0.0) {
10481051
// If `standardization` is false, we still standardize the data
10491052
// to improve the rate of convergence; as a result, we have to
10501053
// perform this reverse standardization by penalizing each component
10511054
// differently to get effectively the same objective function when
10521055
// the training dataset is not standardized.
1053-
val temp = value / (featuresStd(index) * featuresStd(index))
1056+
val temp = value / (localFeaturesStd(index) * localFeaturesStd(index))
10541057
totalGradientArray(index) += effectiveL2regParam * temp
10551058
value * temp
10561059
} else {

0 commit comments

Comments
 (0)