
Commit f515f94

witgo authored and mengxr committed
[SPARK-4526][MLLIB] GradientDescent gets a wrong gradient value according to the gradient formula.

This is caused by the miniBatchSize parameter: the number of records `RDD.sample` returns is not fixed. cc mengxr

Author: GuoQiang Li <[email protected]>

Closes #3399 from witgo/GradientDescent and squashes the following commits:

13cb228 [GuoQiang Li] review commit
668ab66 [GuoQiang Li] Double to Long
b6aa11a [GuoQiang Li] Check miniBatchSize is greater than 0
0b5c3e3 [GuoQiang Li] Minor fix
12e7424 [GuoQiang Li] GradientDescent gets a wrong gradient value according to the gradient formula, which is caused by the miniBatchSize parameter.
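For context (this sketch is not part of the commit): `RDD.sample(false, fraction, seed)` performs an independent Bernoulli draw per record, so the number of records it returns varies around `numExamples * miniBatchFraction` rather than equalling it. Below is a minimal standalone Scala illustration of that variability; the object name, data size, and `local[2]` master are assumptions made for the example.

import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical demo (not from the patch): the same fraction with different seeds
// yields slightly different sample sizes, all scattered around the expected count.
object SampleSizeVaries {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sample-size-demo").setMaster("local[2]"))
    val data = sc.parallelize(1 to 100000, 8)
    val fraction = 0.01
    val counts = (1 to 5).map(seed => data.sample(false, fraction, seed).count())
    // Expected size is 100000 * 0.01 = 1000, but each draw differs slightly.
    println(s"expected = ${100000 * fraction}, actual = ${counts.mkString(", ")}")
    sc.stop()
  }
}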
1 parent 89f9122 commit f515f94

File tree: 1 file changed, +26 -19 lines

mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala

Lines changed: 26 additions & 19 deletions
@@ -160,14 +160,15 @@ object GradientDescent extends Logging {
     val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
 
     val numExamples = data.count()
-    val miniBatchSize = numExamples * miniBatchFraction
 
     // if no data, return initial weights to avoid NaNs
     if (numExamples == 0) {
-
-      logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
+      logWarning("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
       return (initialWeights, stochasticLossHistory.toArray)
+    }
 
+    if (numExamples * miniBatchFraction < 1) {
+      logWarning("The miniBatchFraction is too small")
     }
 
     // Initialize weights as a column vector
@@ -185,25 +186,31 @@ object GradientDescent extends Logging {
       val bcWeights = data.context.broadcast(weights)
       // Sample a subset (fraction miniBatchFraction) of the total data
      // compute and sum up the subgradients on this subset (this is one map-reduce)
-      val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
-        .treeAggregate((BDV.zeros[Double](n), 0.0))(
-          seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
-            val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad))
-            (grad, loss + l)
+      val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i)
+        .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))(
+          seqOp = (c, v) => {
+            // c: (grad, loss, count), v: (label, features)
+            val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1))
+            (c._1, c._2 + l, c._3 + 1)
           },
-          combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
-            (grad1 += grad2, loss1 + loss2)
+          combOp = (c1, c2) => {
+            // c: (grad, loss, count)
+            (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
           })
 
-      /**
-       * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
-       * and regVal is the regularization value computed in the previous iteration as well.
-       */
-      stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
-      val update = updater.compute(
-        weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam)
-      weights = update._1
-      regVal = update._2
+      if (miniBatchSize > 0) {
+        /**
+         * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
+         * and regVal is the regularization value computed in the previous iteration as well.
+         */
+        stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
+        val update = updater.compute(
+          weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam)
+        weights = update._1
+        regVal = update._2
+      } else {
+        logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero")
+      }
     }
 
     logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
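Outside the diff, the pattern the patch adopts can be shown in a simplified, self-contained form: thread a record counter through `treeAggregate` over the sampled RDD and divide by that observed count, applying the update only when the count is positive. The names below (`MiniBatchAverageSketch`, the toy data, the local master) are illustrative assumptions and not MLlib code.

import org.apache.spark.{SparkConf, SparkContext}

// Illustrative sketch (not part of MLlib): aggregate (sum, count) over the
// sampled records so the divisor is the actual mini-batch size.
object MiniBatchAverageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("mini-batch-avg").setMaster("local[2]"))
    val data = sc.parallelize((1 to 10000).map(_.toDouble), 4)
    val miniBatchFraction = 0.05

    val (sum, count) = data.sample(false, miniBatchFraction, 42L)
      .treeAggregate((0.0, 0L))(
        seqOp = (c, v) => (c._1 + v, c._2 + 1),               // fold one record into (sum, count)
        combOp = (c1, c2) => (c1._1 + c2._1, c1._2 + c2._2))  // merge partial (sum, count) pairs

    if (count > 0) {
      // Divide by the observed count, not by 10000 * miniBatchFraction.
      println(s"mini-batch mean = ${sum / count} over $count sampled records")
    } else {
      println("sampled batch is empty; skip this iteration's update")
    }
    sc.stop()
  }
}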
