Skip to content

Commit 1447c23

Browse files
author
DB Tsai
committed
Fixed updater bug
1 parent e248328 commit 1447c23

File tree

1 file changed

+3
-2
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/optimization

1 file changed

+3
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ class L1Updater extends Updater {
112112
val thisIterStepSize = stepSize / math.sqrt(iter)
113113
// Take gradient step
114114
val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
115+
val norm = brzNorm(brzWeights, 1.0)
115116
brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
116117
// Apply proximal operator (soft thresholding)
117118
val shrinkageVal = regParam * thisIterStepSize
@@ -122,7 +123,7 @@ class L1Updater extends Updater {
122123
i += 1
123124
}
124125

125-
(Vectors.fromBreeze(brzWeights), brzNorm(brzWeights, 1.0) * regParam)
126+
(Vectors.fromBreeze(brzWeights), regParam * norm)
126127
}
127128
}
128129

@@ -146,9 +147,9 @@ class SquaredL2Updater extends Updater {
146147
// w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient
147148
val thisIterStepSize = stepSize / math.sqrt(iter)
148149
val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
150+
val norm = brzNorm(brzWeights, 2.0)
149151
brzWeights :*= (1.0 - thisIterStepSize * regParam)
150152
brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
151-
val norm = brzNorm(brzWeights, 2.0)
152153

153154
(Vectors.fromBreeze(brzWeights), 0.5 * regParam * norm * norm)
154155
}

0 commit comments

Comments
 (0)