From b4333a3b89362952e900acd7824e5d40500e3b9f Mon Sep 17 00:00:00 2001 From: Anthony Truchet Date: Tue, 29 Nov 2016 19:20:38 +0100 Subject: [PATCH] [SPARK-18471][MLLIB] Fix huge vectors of zero send in closure in L-BFGS Introduced util tTreeAggregatoreWithZeroGenerator to avoid sending huge zero vector in L-BFGS or similar agregation, as only the size of the zero value to be generated is captured in the closure. --- .../spark/mllib/optimization/LBFGS.scala | 25 ++++--- .../TreeAggregatorWithZeroGenerator.scala | 69 +++++++++++++++++++ 2 files changed, 85 insertions(+), 9 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/TreeAggregatorWithZeroGenerator.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 900eec18489c..c0478d95cb95 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -18,14 +18,13 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable - import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.axpy +import org.apache.spark.mllib.util.TreeAggregateWithZeroGenerator import org.apache.spark.rdd.RDD /** @@ -241,16 +240,24 @@ object LBFGS extends Logging { val bcW = data.context.broadcast(w) val localGradient = gradient - val (gradientSum, lossSum) = data.treeAggregate((Vectors.zeros(n), 0.0))( - seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => - val l = localGradient.compute( - features, label, bcW.value, grad) + // Given (current accumulated gradient, current loss) and (label, features) + // tuples, updates the current gradient and current loss + val seqOp = (c: (Vector, Double), v: (Double, Vector)) => + (c, v) match { + case ((grad, loss), (label, features)) => + val l = localGradient.compute(features, label, bcW.value, grad) (grad, loss + l) - }, - combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => + } + + // Adds two (gradient, loss) tuples + val combOp = (c1: (Vector, Double), c2: (Vector, Double)) => + (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => axpy(1.0, grad2, grad1) (grad1, loss1 + loss2) - }) + } + + val (gradientSum, lossSum) = TreeAggregateWithZeroGenerator( + () => (Vectors.zeros(n), 0.0))(seqOp, combOp)(data) /** * regVal is sum of weight squares if it's L2 updater; diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/TreeAggregatorWithZeroGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/TreeAggregatorWithZeroGenerator.scala new file mode 100644 index 000000000000..7772a38b79e7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/TreeAggregatorWithZeroGenerator.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD + +object TreeAggregateWithZeroGenerator { + + /** + * Reduces the elements of this RDD in a multi-level tree pattern. + * + * treeAggregate wrapper that consumes a function to produce the zero element + * instead of the zero element itself. Useful, when the zero element is heavy + * but it's generator is 'small', e.g. Vectors.zeros(millions of elements) + * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#reduce]] + */ + def apply[U: ClassTag, T: ClassTag](zeroGenerator: () => U)( + seqOp: (U, T) => U, + combOp: (U, U) => U, + depth: Int = 2) ( + data: RDD[T]): U = { + + val lazySeqOp: (Option[U], T) => Option[U] = (acc, entry) => + if (acc.isDefined) { + Some(seqOp(acc.get, entry)) + } else { + Some(seqOp(zeroGenerator(), entry)) + } + + val lazyCombOp: (Option[U], Option[U]) => Option[U] = (acc1, acc2) => { + if (acc1.isDefined && acc2.isDefined) { + Some(combOp(acc1.get, acc2.get)) + } else if (acc1.isDefined) { + acc1 + } else if (acc2.isDefined) { + acc2 + } else { + Option.empty[U] + } + } + + val result = data.treeAggregate(Option.empty[U])(lazySeqOp, lazyCombOp, depth = depth) + + if (result.isDefined) { + result.get + } else { + zeroGenerator() + } + } +}