Skip to content

Commit ef8fdea

Browse files
committed
dimension corrections
1 parent 86505b7 commit ef8fdea

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,8 @@ private class LogisticAggregator(
950950
private var lossSum = 0.0
951951

952952
private val dim = numFeatures
953-
private val gradientSumArray = Array.ofDim[Double](dim)
953+
private val gradientSumArray =
954+
Array.ofDim[Double](if (fitIntercept) numFeatures + 1 else numFeatures)
954955

955956
/**
956957
* Add a new training instance to this LogisticAggregator, and update the loss and gradient
@@ -1092,7 +1093,7 @@ private class LogisticCostFun(
10921093
val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
10931094

10941095
instances.treeAggregate(
1095-
new LogisticAggregator(n, numClasses, fitIntercept)
1096+
new LogisticAggregator(numFeatures, numClasses, fitIntercept)
10961097
)(seqOp, combOp)
10971098
}
10981099

mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ class LogisticRegressionWithLBFGS
421421
private def run(input: RDD[LabeledPoint], initialWeights: Vector, userSuppliedWeights: Boolean):
422422
LogisticRegressionModel = {
423423
// ml's Logistic regression only supports binary classification currently.
424-
if (numOfLinearPredictor == 1 && false) {
424+
if (numOfLinearPredictor == 1) {
425425
def runWithMlLogisticRegression(elasticNetParam: Double) = {
426426
// Prepare the ml LogisticRegression based on our settings
427427
val lr = new org.apache.spark.ml.classification.LogisticRegression()

0 commit comments

Comments
 (0)