1717
1818package org .apache .spark .mllib .regression
1919
20- import breeze .linalg .{DenseVector => BDV , SparseVector => BSV }
21-
2220import org .apache .spark .annotation .DeveloperApi
2321import org .apache .spark .{Logging , SparkException }
2422import org .apache .spark .rdd .RDD
2523import org .apache .spark .mllib .optimization ._
2624import org .apache .spark .mllib .linalg .{Vectors , Vector }
25+ import org .apache .spark .mllib .util .MLUtils ._
2726
2827/**
2928 * :: DeveloperApi ::
@@ -124,16 +123,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
124123 run(input, initialWeights)
125124 }
126125
127- /** Prepends one to the input vector. */
128- private def prependOne (vector : Vector ): Vector = {
129- val vector1 = vector.toBreeze match {
130- case dv : BDV [Double ] => BDV .vertcat(BDV .ones[Double ](1 ), dv)
131- case sv : BSV [Double ] => BSV .vertcat(new BSV [Double ](Array (0 ), Array (1.0 ), 1 ), sv)
132- case v : Any => throw new IllegalArgumentException (" Do not support vector type " + v.getClass)
133- }
134- Vectors .fromBreeze(vector1)
135- }
136-
137126 /**
138127 * Run the algorithm with the configured parameters on an input RDD
139128 * of LabeledPoint entries starting from the initial weights provided.
@@ -147,23 +136,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
147136
148137 // Prepend an extra variable consisting of all 1.0's for the intercept.
149138 val data = if (addIntercept) {
150- input.map(labeledPoint => (labeledPoint.label, prependOne (labeledPoint.features)))
139+ input.map(labeledPoint => (labeledPoint.label, appendBias (labeledPoint.features)))
151140 } else {
152141 input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
153142 }
154143
155144 val initialWeightsWithIntercept = if (addIntercept) {
156- prependOne (initialWeights)
145+ appendBias (initialWeights)
157146 } else {
158147 initialWeights
159148 }
160149
161150 val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
162151
163- val intercept = if (addIntercept) weightsWithIntercept(0 ) else 0.0
152+ val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1 ) else 0.0
164153 val weights =
165154 if (addIntercept) {
166- Vectors .dense(weightsWithIntercept.toArray.slice(1 , weightsWithIntercept.size))
155+ Vectors .dense(weightsWithIntercept.toArray.slice(0 , weightsWithIntercept.size - 1 ))
167156 } else {
168157 weightsWithIntercept
169158 }
0 commit comments