Skip to content

Commit 1f0a469

Browse files
sethah authored and mengxr committed
[SPARK-16008][ML] Remove unnecessary serialization in logistic regression
JIRA: [SPARK-16008](https://issues.apache.org/jira/browse/SPARK-16008) ## What changes were proposed in this pull request? `LogisticAggregator` stores references to two arrays of dimension `numFeatures` which are serialized before the combine op, unnecessarily. This results in the shuffle write being ~3x (for multiclass logistic regression, this number will go up) larger than it should be (in MLlib, for instance, it is 3x smaller). This patch modifies `LogisticAggregator.add` to accept the two arrays as method parameters which avoids the serialization. ## How was this patch tested? I tested this locally and verified the serialization reduction. ![image](https://cloud.githubusercontent.com/assets/7275795/16140387/d2974bac-3404-11e6-94f9-268860c931a2.png) Additionally, I ran some tests of a 4 node cluster (4x48 cores, 4x128 GB RAM). Data set size of 2M rows and 10k features showed >2x iteration speedup. Author: sethah <[email protected]> Closes #13729 from sethah/lr_improvement.
1 parent 34d6c4c commit 1f0a469

File tree

1 file changed

+29
-28
lines changed

1 file changed

+29
-28
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -937,50 +937,47 @@ class BinaryLogisticRegressionSummary private[classification] (
937937
* Two LogisticAggregator can be merged together to have a summary of loss and gradient of
938938
* the corresponding joint dataset.
939939
*
940-
* @param coefficients The coefficients corresponding to the features.
941940
* @param numClasses the number of possible outcomes for k classes classification problem in
942941
* Multinomial Logistic Regression.
943942
* @param fitIntercept Whether to fit an intercept term.
944-
* @param featuresStd The standard deviation values of the features.
945-
* @param featuresMean The mean values of the features.
946943
*/
947944
private class LogisticAggregator(
948-
coefficients: Vector,
945+
private val numFeatures: Int,
949946
numClasses: Int,
950-
fitIntercept: Boolean,
951-
featuresStd: Array[Double],
952-
featuresMean: Array[Double]) extends Serializable {
947+
fitIntercept: Boolean) extends Serializable {
953948

954949
private var weightSum = 0.0
955950
private var lossSum = 0.0
956951

957-
private val coefficientsArray = coefficients match {
958-
case dv: DenseVector => dv.values
959-
case _ =>
960-
throw new IllegalArgumentException(
961-
s"coefficients only supports dense vector but got type ${coefficients.getClass}.")
962-
}
963-
964-
private val dim = if (fitIntercept) coefficientsArray.length - 1 else coefficientsArray.length
965-
966-
private val gradientSumArray = Array.ofDim[Double](coefficientsArray.length)
952+
private val gradientSumArray =
953+
Array.ofDim[Double](if (fitIntercept) numFeatures + 1 else numFeatures)
967954

968955
/**
969956
* Add a new training instance to this LogisticAggregator, and update the loss and gradient
970957
* of the objective function.
971958
*
972959
* @param instance The instance of data point to be added.
960+
* @param coefficients The coefficients corresponding to the features.
961+
* @param featuresStd The standard deviation values of the features.
973962
* @return This LogisticAggregator object.
974963
*/
975-
def add(instance: Instance): this.type = {
964+
def add(
965+
instance: Instance,
966+
coefficients: Vector,
967+
featuresStd: Array[Double]): this.type = {
976968
instance match { case Instance(label, weight, features) =>
977-
require(dim == features.size, s"Dimensions mismatch when adding new instance." +
978-
s" Expecting $dim but got ${features.size}.")
969+
require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
970+
s" Expecting $numFeatures but got ${features.size}.")
979971
require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
980972

981973
if (weight == 0.0) return this
982974

983-
val localCoefficientsArray = coefficientsArray
975+
val coefficientsArray = coefficients match {
976+
case dv: DenseVector => dv.values
977+
case _ =>
978+
throw new IllegalArgumentException(
979+
s"coefficients only supports dense vector but got type ${coefficients.getClass}.")
980+
}
984981
val localGradientSumArray = gradientSumArray
985982

986983
numClasses match {
@@ -990,11 +987,11 @@ private class LogisticAggregator(
990987
var sum = 0.0
991988
features.foreachActive { (index, value) =>
992989
if (featuresStd(index) != 0.0 && value != 0.0) {
993-
sum += localCoefficientsArray(index) * (value / featuresStd(index))
990+
sum += coefficientsArray(index) * (value / featuresStd(index))
994991
}
995992
}
996993
sum + {
997-
if (fitIntercept) localCoefficientsArray(dim) else 0.0
994+
if (fitIntercept) coefficientsArray(numFeatures) else 0.0
998995
}
999996
}
1000997

@@ -1007,7 +1004,7 @@ private class LogisticAggregator(
10071004
}
10081005

10091006
if (fitIntercept) {
1010-
localGradientSumArray(dim) += multiplier
1007+
localGradientSumArray(numFeatures) += multiplier
10111008
}
10121009

10131010
if (label > 0) {
@@ -1034,8 +1031,8 @@ private class LogisticAggregator(
10341031
* @return This LogisticAggregator object.
10351032
*/
10361033
def merge(other: LogisticAggregator): this.type = {
1037-
require(dim == other.dim, s"Dimensions mismatch when merging with another " +
1038-
s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.")
1034+
require(numFeatures == other.numFeatures, s"Dimensions mismatch when merging with another " +
1035+
s"LeastSquaresAggregator. Expecting $numFeatures but got ${other.numFeatures}.")
10391036

10401037
if (other.weightSum != 0.0) {
10411038
weightSum += other.weightSum
@@ -1086,13 +1083,17 @@ private class LogisticCostFun(
10861083
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
10871084
val numFeatures = featuresStd.length
10881085
val coeffs = Vectors.fromBreeze(coefficients)
1086+
val n = coeffs.size
1087+
val localFeaturesStd = featuresStd
1088+
10891089

10901090
val logisticAggregator = {
1091-
val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
1091+
val seqOp = (c: LogisticAggregator, instance: Instance) =>
1092+
c.add(instance, coeffs, localFeaturesStd)
10921093
val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
10931094

10941095
instances.treeAggregate(
1095-
new LogisticAggregator(coeffs, numClasses, fitIntercept, featuresStd, featuresMean)
1096+
new LogisticAggregator(numFeatures, numClasses, fitIntercept)
10961097
)(seqOp, combOp)
10971098
}
10981099

0 commit comments

Comments
 (0)