Skip to content

Commit 52beb20

Browse files
DB Tsaimengxr
authored andcommitted
[SPARK-2477][MLlib] Using appendBias for adding intercept in GeneralizedLinearAlgorithm
Instead of using prependOne currently in GeneralizedLinearAlgorithm, we would like to use appendBias for 1) keeping the indices of original training set unchanged by adding the intercept into the last element of vector and 2) using the same public API for consistently adding intercept. Author: DB Tsai <[email protected]> Closes apache#1410 from dbtsai/SPARK-2477_intercept_with_appendBias and squashes the following commits: 011432c [DB Tsai] From Alpine Data Labs
1 parent dd95aba commit 52beb20

File tree

1 file changed

+5
-16
lines changed

1 file changed

+5
-16
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@
1717

1818
package org.apache.spark.mllib.regression
1919

20-
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
21-
2220
import org.apache.spark.annotation.DeveloperApi
2321
import org.apache.spark.{Logging, SparkException}
2422
import org.apache.spark.rdd.RDD
2523
import org.apache.spark.mllib.optimization._
2624
import org.apache.spark.mllib.linalg.{Vectors, Vector}
25+
import org.apache.spark.mllib.util.MLUtils._
2726

2827
/**
2928
* :: DeveloperApi ::
@@ -124,16 +123,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
124123
run(input, initialWeights)
125124
}
126125

127-
/** Prepends one to the input vector. */
128-
private def prependOne(vector: Vector): Vector = {
129-
val vector1 = vector.toBreeze match {
130-
case dv: BDV[Double] => BDV.vertcat(BDV.ones[Double](1), dv)
131-
case sv: BSV[Double] => BSV.vertcat(new BSV[Double](Array(0), Array(1.0), 1), sv)
132-
case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
133-
}
134-
Vectors.fromBreeze(vector1)
135-
}
136-
137126
/**
138127
* Run the algorithm with the configured parameters on an input RDD
139128
* of LabeledPoint entries starting from the initial weights provided.
@@ -147,23 +136,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
147136

148137
// Prepend an extra variable consisting of all 1.0's for the intercept.
149138
val data = if (addIntercept) {
150-
input.map(labeledPoint => (labeledPoint.label, prependOne(labeledPoint.features)))
139+
input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features)))
151140
} else {
152141
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
153142
}
154143

155144
val initialWeightsWithIntercept = if (addIntercept) {
156-
prependOne(initialWeights)
145+
appendBias(initialWeights)
157146
} else {
158147
initialWeights
159148
}
160149

161150
val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
162151

163-
val intercept = if (addIntercept) weightsWithIntercept(0) else 0.0
152+
val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0
164153
val weights =
165154
if (addIntercept) {
166-
Vectors.dense(weightsWithIntercept.toArray.slice(1, weightsWithIntercept.size))
155+
Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1))
167156
} else {
168157
weightsWithIntercept
169158
}

0 commit comments

Comments
 (0)