
Commit bccab92

add task size test to LBFGS
1 parent 02103ba commit bccab92

File tree

4 files changed: +81, -17 lines

- mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
- mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
- mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
- mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala

mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

Lines changed: 14 additions & 11 deletions
@@ -195,18 +195,21 @@ object LBFGS extends Logging {
 
     override def calculate(weights: BDV[Double]) = {
       // Have a local copy to avoid the serialization of CostFun object which is not serializable.
-      val localData = data
       val localGradient = gradient
-
-      val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
-        seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
-          val l = localGradient.compute(
-            features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
-          (grad, loss + l)
-        },
-        combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
-          (grad1 += grad2, loss1 + loss2)
-        })
+      val n = weights.length
+      val bcWeights = data.context.broadcast(weights)
+
+      val (gradientSum, lossSum) = data.mapPartitions { iter =>
+        val cumGrad = Vectors.dense(new Array[Double](n))
+        val thisWeights = Vectors.fromBreeze(bcWeights.value)
+        var loss = 0.0
+        iter.foreach { case (label, features) =>
+          loss += localGradient.compute(features, label, thisWeights, cumGrad)
+        }
+        Iterator((cumGrad.toBreeze.asInstanceOf[BDV[Double]], loss))
+      }.reduce { case ((grad1, loss1), (grad2, loss2)) =>
+        (grad1 += grad2, loss1 + loss2)
+      }
 
       /**
        * regVal is sum of weight squares if it's L2 updater;
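
The new calculate replaces the aggregate over a closure-captured weight vector with a mapPartitions/reduce pass that reads the weights through a broadcast variable, so each task ships a small Broadcast handle instead of the full vector. Below is a minimal sketch of that pattern, not taken from this commit (the object name, sizes, and values are hypothetical):

    import org.apache.spark.{SparkConf, SparkContext}

    // Minimal sketch of the broadcast pattern adopted above: ship a large,
    // read-only array to each executor once via a broadcast variable instead
    // of capturing it in every task closure.
    object BroadcastSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("sketch"))
        val weights = Array.fill(200000)(0.5) // hypothetical large read-only vector
        val bcWeights = sc.broadcast(weights) // sent to each executor once
        val data = sc.parallelize(0 until 10)
        // The task closure captures only the lightweight Broadcast handle;
        // bcWeights.value resolves to the executor-local copy of the array.
        val products = data.map(i => bcWeights.value(i) * i).collect()
        println(products.mkString(", "))
        sc.stop()
      }
    }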

mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala

Lines changed: 24 additions & 4 deletions
@@ -17,12 +17,13 @@
 
 package org.apache.spark.mllib.optimization
 
-import org.scalatest.FunSuite
-import org.scalatest.Matchers
+import java.util.Random
+
+import org.scalatest.{FunSuite, Matchers}
 
-import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
 
 class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
 
@@ -230,3 +231,22 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
       "The weight differences between LBFGS and GD should be within 2%.")
   }
 }
+
+class LBFGSTaskSuite extends FunSuite with LocalClusterSparkContext {
+
+  test("task size should be small") {
+    val m = 10
+    val n = 200000
+    val examples = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map(i => (1.0, Vectors.dense(Array.fill(n)(random.nextDouble))))
+    }.cache()
+    val lbfgs = new LBFGS(new LogisticGradient, new SquaredL2Updater)
+      .setNumCorrections(1)
+      .setConvergenceTol(1e-12)
+      .setMaxNumIterations(1)
+      .setRegParam(1.0)
+    val random = new Random(0)
+    val weights = lbfgs.optimize(examples, Vectors.dense(Array.fill(n)(random.nextDouble)))
+  }
+}
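
LocalClusterSparkContext (below) pins spark.akka.frameSize to 1 MB, and the test's dimensions are chosen so that a task that serializes a vector directly cannot fit in a frame: 200000 doubles at 8 bytes each is about 1.6 MB. A back-of-the-envelope check, assuming 8 bytes per Double and ignoring serialization overhead:

    // Rough arithmetic behind n = 200000 versus the 1 MB frame size.
    val n = 200000
    val approxVectorBytes = n * 8L             // ~1.6 MB per dense vector
    val frameSizeBytes = 1L * 1024 * 1024      // spark.akka.frameSize = 1 (MB)
    assert(approxVectorBytes > frameSizeBytes) // direct serialization would overflow
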
mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import org.scalatest.{Suite, BeforeAndAfterAll}
+
+import org.apache.spark.{SparkConf, SparkContext}
+
+trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite =>
+  @transient var sc: SparkContext = _
+
+  override def beforeAll() {
+    val conf = new SparkConf()
+      .setMaster("local-cluster[2, 1, 512]")
+      .setAppName("test-cluster")
+      .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data
+    sc = new SparkContext(conf)
+    super.beforeAll()
+  }
+
+  override def afterAll() {
+    if (sc != null) {
+      sc.stop()
+    }
+    super.afterAll()
+  }
+}
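
Any suite that wants to catch oversized task closures can mix in the trait the same way LBFGSTaskSuite does. A hypothetical example (MyAlgorithmTaskSuite is not part of this commit):

    import org.scalatest.FunSuite
    import org.apache.spark.mllib.util.LocalClusterSparkContext

    class MyAlgorithmTaskSuite extends FunSuite with LocalClusterSparkContext {
      test("task size should be small") {
        // `sc` comes from the trait: a 2-worker local cluster with a 1 MB
        // Akka frame, so jobs with bloated task closures fail fast here.
        val rdd = sc.parallelize(1 to 100, 2)
        assert(rdd.map(_ * 2).reduce(_ + _) == 10100)
      }
    }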

mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala

Lines changed: 1 addition & 2 deletions
@@ -27,9 +27,8 @@ trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
 
   override def beforeAll() {
     val conf = new SparkConf()
-      .setMaster("local-cluster[2, 1, 512]")
+      .setMaster("local")
       .setAppName("test")
-      .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data
    sc = new SparkContext(conf)
     super.beforeAll()
   }
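
With the local-cluster master and the 1 MB frame size moved into LocalClusterSparkContext, LocalSparkContext reverts to plain local mode: the bulk of the MLlib suites keep running in a single JVM, and only the dedicated task-size tests pay the cost of spinning up a local cluster.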
