From 1447c234092339f67d1887bfc75731665264b770 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Fri, 29 Aug 2014 14:13:11 -0700 Subject: [PATCH] Fixed updater bug --- .../scala/org/apache/spark/mllib/optimization/Updater.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index 3ed3a5b9b384..fe745dfb3aec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -112,6 +112,7 @@ class L1Updater extends Updater { val thisIterStepSize = stepSize / math.sqrt(iter) // Take gradient step val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + val norm = brzNorm(brzWeights, 1.0) brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) // Apply proximal operator (soft thresholding) val shrinkageVal = regParam * thisIterStepSize @@ -122,7 +123,7 @@ class L1Updater extends Updater { i += 1 } - (Vectors.fromBreeze(brzWeights), brzNorm(brzWeights, 1.0) * regParam) + (Vectors.fromBreeze(brzWeights), regParam * norm) } } @@ -146,9 +147,9 @@ class SquaredL2Updater extends Updater { // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient val thisIterStepSize = stepSize / math.sqrt(iter) val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + val norm = brzNorm(brzWeights, 2.0) brzWeights :*= (1.0 - thisIterStepSize * regParam) brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) - val norm = brzNorm(brzWeights, 2.0) (Vectors.fromBreeze(brzWeights), 0.5 * regParam * norm * norm) }