Skip to content

Commit 417aa1e

Browse files
committed
update
1 parent 3861273 commit 417aa1e

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
2525

2626
import org.apache.spark.SparkException
2727
import org.apache.spark.annotation.{Experimental, Since}
28+
import org.apache.spark.broadcast.Broadcast
2829
import org.apache.spark.internal.Logging
2930
import org.apache.spark.ml.feature.Instance
3031
import org.apache.spark.ml.linalg._
@@ -937,6 +938,8 @@ class BinaryLogisticRegressionSummary private[classification] (
937938
* @param fitIntercept Whether to fit an intercept term.
938939
*/
939940
private class LogisticAggregator(
941+
val bcCoeffs: Broadcast[Vector],
942+
val bcFeaturesStd: Broadcast[Array[Double]],
940943
private val numFeatures: Int,
941944
numClasses: Int,
942945
fitIntercept: Boolean) extends Serializable {
@@ -952,29 +955,25 @@ private class LogisticAggregator(
952955
* of the objective function.
953956
*
954957
* @param instance The instance of data point to be added.
955-
* @param coefficients The coefficients corresponding to the features.
956-
* @param featuresStd The standard deviation values of the features.
957958
* @return This LogisticAggregator object.
958959
*/
959-
def add(
960-
instance: Instance,
961-
coefficients: Vector,
962-
featuresStd: Array[Double]): this.type = {
960+
def add(instance: Instance): this.type = {
963961
instance match { case Instance(label, weight, features) =>
964962
require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
965963
s" Expecting $numFeatures but got ${features.size}.")
966964
require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
967965

968966
if (weight == 0.0) return this
969967

970-
val coefficientsArray = coefficients match {
968+
val coefficientsArray = bcCoeffs.value match {
971969
case dv: DenseVector => dv.values
972970
case _ =>
973971
throw new IllegalArgumentException(
974-
s"coefficients only supports dense vector but got type ${coefficients.getClass}.")
972+
s"coefficients only supports dense vector but got type ${bcCoeffs.value.getClass}.")
975973
}
976974
val localGradientSumArray = gradientSumArray
977975

976+
val featuresStd = bcFeaturesStd.value
978977
numClasses match {
979978
case 2 =>
980979
// For Binary Logistic Regression.
@@ -1075,20 +1074,20 @@ private class LogisticCostFun(
10751074
featuresMean: Array[Double],
10761075
regParamL2: Double) extends DiffFunction[BDV[Double]] {
10771076

1077+
val bcFeaturesStd = instances.context.broadcast(featuresStd)
1078+
10781079
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
10791080
val numFeatures = featuresStd.length
10801081
val coeffs = Vectors.fromBreeze(coefficients)
1082+
val bcCoeffs = instances.context.broadcast(coeffs)
10811083
val n = coeffs.size
1082-
val localFeaturesStd = featuresStd
1083-
10841084

10851085
val logisticAggregator = {
1086-
val seqOp = (c: LogisticAggregator, instance: Instance) =>
1087-
c.add(instance, coeffs, localFeaturesStd)
1086+
val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
10881087
val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
10891088

10901089
instances.treeAggregate(
1091-
new LogisticAggregator(numFeatures, numClasses, fitIntercept)
1090+
new LogisticAggregator(bcCoeffs, bcFeaturesStd, numFeatures, numClasses, fitIntercept)
10921091
)(seqOp, combOp)
10931092
}
10941093

0 commit comments

Comments
 (0)