@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
2525
2626import org .apache .spark .SparkException
2727import org .apache .spark .annotation .{Experimental , Since }
28+ import org .apache .spark .broadcast .Broadcast
2829import org .apache .spark .internal .Logging
2930import org .apache .spark .ml .feature .Instance
3031import org .apache .spark .ml .linalg ._
@@ -937,6 +938,8 @@ class BinaryLogisticRegressionSummary private[classification] (
937938 * @param fitIntercept Whether to fit an intercept term.
938939 */
939940private class LogisticAggregator (
941+ val bcCoeffs : Broadcast [Vector ],
942+ val bcFeaturesStd : Broadcast [Array [Double ]],
940943 private val numFeatures : Int ,
941944 numClasses : Int ,
942945 fitIntercept : Boolean ) extends Serializable {
@@ -952,29 +955,25 @@ private class LogisticAggregator(
952955 * of the objective function.
953956 *
954957 * @param instance The instance of data point to be added.
955- * @param coefficients The coefficients corresponding to the features.
956- * @param featuresStd The standard deviation values of the features.
957958 * @return This LogisticAggregator object.
958959 */
959- def add (
960- instance : Instance ,
961- coefficients : Vector ,
962- featuresStd : Array [Double ]): this .type = {
960+ def add (instance : Instance ): this .type = {
963961 instance match { case Instance (label, weight, features) =>
964962 require(numFeatures == features.size, s " Dimensions mismatch when adding new instance. " +
965963 s " Expecting $numFeatures but got ${features.size}. " )
966964 require(weight >= 0.0 , s " instance weight, $weight has to be >= 0.0 " )
967965
968966 if (weight == 0.0 ) return this
969967
970- val coefficientsArray = coefficients match {
968+ val coefficientsArray = bcCoeffs.value match {
971969 case dv : DenseVector => dv.values
972970 case _ =>
973971 throw new IllegalArgumentException (
974- s " coefficients only supports dense vector but got type ${coefficients .getClass}. " )
972+ s " coefficients only supports dense vector but got type ${bcCoeffs.value .getClass}. " )
975973 }
976974 val localGradientSumArray = gradientSumArray
977975
976+ val featuresStd = bcFeaturesStd.value
978977 numClasses match {
979978 case 2 =>
980979 // For Binary Logistic Regression.
@@ -1075,20 +1074,20 @@ private class LogisticCostFun(
10751074 featuresMean : Array [Double ],
10761075 regParamL2 : Double ) extends DiffFunction [BDV [Double ]] {
10771076
1077+ val bcFeaturesStd = instances.context.broadcast(featuresStd)
1078+
10781079 override def calculate (coefficients : BDV [Double ]): (Double , BDV [Double ]) = {
10791080 val numFeatures = featuresStd.length
10801081 val coeffs = Vectors .fromBreeze(coefficients)
1082+ val bcCoeffs = instances.context.broadcast(coeffs)
10811083 val n = coeffs.size
1082- val localFeaturesStd = featuresStd
1083-
10841084
10851085 val logisticAggregator = {
1086- val seqOp = (c : LogisticAggregator , instance : Instance ) =>
1087- c.add(instance, coeffs, localFeaturesStd)
1086+ val seqOp = (c : LogisticAggregator , instance : Instance ) => c.add(instance)
10881087 val combOp = (c1 : LogisticAggregator , c2 : LogisticAggregator ) => c1.merge(c2)
10891088
10901089 instances.treeAggregate(
1091- new LogisticAggregator (numFeatures, numClasses, fitIntercept)
1090+ new LogisticAggregator (bcCoeffs, bcFeaturesStd, numFeatures, numClasses, fitIntercept)
10921091 )(seqOp, combOp)
10931092 }
10941093
0 commit comments