Skip to content

Commit 011432c

Browse files
author
DB Tsai
committed
From Alpine Data Labs
1 parent 38ccd6e commit 011432c

File tree

1 file changed

+5
-16
lines changed

1 file changed

+5
-16
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@
1717

1818
package org.apache.spark.mllib.regression
1919

20-
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
21-
2220
import org.apache.spark.annotation.DeveloperApi
2321
import org.apache.spark.{Logging, SparkException}
2422
import org.apache.spark.rdd.RDD
2523
import org.apache.spark.mllib.optimization._
2624
import org.apache.spark.mllib.linalg.{Vectors, Vector}
25+
import org.apache.spark.mllib.util.MLUtils._
2726

2827
/**
2928
* :: DeveloperApi ::
@@ -124,16 +123,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
124123
run(input, initialWeights)
125124
}
126125

127-
/** Prepends one to the input vector. */
128-
private def prependOne(vector: Vector): Vector = {
129-
val vector1 = vector.toBreeze match {
130-
case dv: BDV[Double] => BDV.vertcat(BDV.ones[Double](1), dv)
131-
case sv: BSV[Double] => BSV.vertcat(new BSV[Double](Array(0), Array(1.0), 1), sv)
132-
case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
133-
}
134-
Vectors.fromBreeze(vector1)
135-
}
136-
137126
/**
138127
* Run the algorithm with the configured parameters on an input RDD
139128
* of LabeledPoint entries starting from the initial weights provided.
@@ -147,23 +136,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
147136

148137
// Prepend an extra variable consisting of all 1.0's for the intercept.
149138
val data = if (addIntercept) {
150-
input.map(labeledPoint => (labeledPoint.label, prependOne(labeledPoint.features)))
139+
input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features)))
151140
} else {
152141
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
153142
}
154143

155144
val initialWeightsWithIntercept = if (addIntercept) {
156-
prependOne(initialWeights)
145+
appendBias(initialWeights)
157146
} else {
158147
initialWeights
159148
}
160149

161150
val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
162151

163-
val intercept = if (addIntercept) weightsWithIntercept(0) else 0.0
152+
val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0
164153
val weights =
165154
if (addIntercept) {
166-
Vectors.dense(weightsWithIntercept.toArray.slice(1, weightsWithIntercept.size))
155+
Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1))
167156
} else {
168157
weightsWithIntercept
169158
}

0 commit comments

Comments
 (0)