Skip to content

Commit 0772834

Browse files
committed
nit
nit nit
1 parent 9275258 commit 0772834

File tree

2 files changed

+13
-17
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -242,19 +242,18 @@ class LinearSVC @Since("2.2.0") (
242242
Note that the intercept in scaled space and original space is the same;
243243
as a result, no scaling is needed.
244244
*/
245-
val state = if ($(blockSize) == 1) {
245+
val rawCoefficients = if ($(blockSize) == 1) {
246246
trainOnRows(instances, featuresStd, regularization, optimizer)
247247
} else {
248248
trainOnBlocks(instances, featuresStd, regularization, optimizer)
249249
}
250250
if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist()
251251

252-
if (state == null) {
252+
if (rawCoefficients == null) {
253253
val msg = s"${optimizer.getClass.getName} failed."
254254
instr.logError(msg)
255255
throw new SparkException(msg)
256256
}
257-
val rawCoefficients = state.x.toArray
258257

259258
val coefficientArray = Array.tabulate(numFeatures) { i =>
260259
if (featuresStd(i) != 0.0) rawCoefficients(i) / featuresStd(i) else 0.0
@@ -267,7 +266,7 @@ class LinearSVC @Since("2.2.0") (
267266
instances: RDD[Instance],
268267
featuresStd: Array[Double],
269268
regularization: Option[L2Regularization],
270-
optimizer: BreezeOWLQN[Int, BDV[Double]]): optimizer.State = {
269+
optimizer: BreezeOWLQN[Int, BDV[Double]]): Array[Double] = {
271270
val numFeatures = featuresStd.length
272271
val numFeaturesPlusIntercept = if ($(fitIntercept)) numFeatures + 1 else numFeatures
273272

@@ -287,14 +286,14 @@ class LinearSVC @Since("2.2.0") (
287286
}
288287
bcFeaturesStd.destroy()
289288

290-
state
289+
if (state == null) null else state.x.toArray
291290
}
292291

293292
private def trainOnBlocks(
294293
instances: RDD[Instance],
295294
featuresStd: Array[Double],
296295
regularization: Option[L2Regularization],
297-
optimizer: BreezeOWLQN[Int, BDV[Double]]): optimizer.State = {
296+
optimizer: BreezeOWLQN[Int, BDV[Double]]): Array[Double] = {
298297
val numFeatures = featuresStd.length
299298
val numFeaturesPlusIntercept = if ($(fitIntercept)) numFeatures + 1 else numFeatures
300299

@@ -331,9 +330,8 @@ class LinearSVC @Since("2.2.0") (
331330
blocks.unpersist()
332331
bcFeaturesStd.destroy()
333332

334-
state
333+
if (state == null) null else state.x.toArray
335334
}
336-
337335
}
338336

339337
@Since("2.2.0")

mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ private[ml] class BlockHingeAggregator(
138138
}
139139

140140
@transient private lazy val intercept =
141-
if (fitIntercept) coefficientsArray(numFeatures) else 0.0
141+
if (fitIntercept) coefficientsArray.last else 0.0
142142

143143
@transient private lazy val linearGradSumVec =
144144
if (fitIntercept) Vectors.zeros(numFeatures).toDense else null
@@ -161,15 +161,13 @@ private[ml] class BlockHingeAggregator(
161161

162162
if (block.weightIter.forall(_ == 0)) return this
163163
val size = block.size
164-
val localGradientSumArray = gradientSumArray
165164

166165
// vec/arr here represents dotProducts
167166
val vec = if (size == blockSize) auxiliaryVec else Vectors.zeros(size).toDense
168167
val arr = vec.values
169168

170169
if (fitIntercept && intercept != 0) {
171-
var i = 0
172-
while (i < size) { arr(i) = intercept; i += 1 }
170+
java.util.Arrays.fill(arr, intercept)
173171
BLAS.gemv(1.0, block.matrix, linear, 1.0, vec)
174172
} else {
175173
BLAS.gemv(1.0, block.matrix, linear, 0.0, vec)
@@ -206,16 +204,16 @@ private[ml] class BlockHingeAggregator(
206204
block.matrix match {
207205
case dm: DenseMatrix =>
208206
BLAS.nativeBLAS.dgemv("N", dm.numCols, dm.numRows, 1.0, dm.values, dm.numCols,
209-
arr, 1, 1.0, localGradientSumArray, 1)
210-
if (fitIntercept) localGradientSumArray(numFeatures) += arr.sum
207+
arr, 1, 1.0, gradientSumArray, 1)
208+
if (fitIntercept) gradientSumArray(numFeatures) += arr.sum
211209

212210
case sm: SparseMatrix if fitIntercept =>
213211
BLAS.gemv(1.0, sm.transpose, vec, 0.0, linearGradSumVec)
214-
linearGradSumVec.foreachNonZero { (i, v) => localGradientSumArray(i) += v }
215-
localGradientSumArray(numFeatures) += arr.sum
212+
BLAS.nativeBLAS.daxpy(numFeatures, 1.0, linearGradSumVec.values, 1, gradientSumArray, 1)
213+
gradientSumArray(numFeatures) += arr.sum
216214

217215
case sm: SparseMatrix if !fitIntercept =>
218-
val gradSumVec = new DenseVector(localGradientSumArray)
216+
val gradSumVec = new DenseVector(gradientSumArray)
219217
BLAS.gemv(1.0, sm.transpose, vec, 1.0, gradSumVec)
220218
}
221219

0 commit comments

Comments
 (0)