Skip to content

Commit 5f1cee6

Browse files
Nakul JindalDB Tsai
authored andcommitted
[SPARK-11332] [ML] Refactored to use ml.feature.Instance instead of WeightedLeastSquare.Instance
WeightedLeastSquares now uses the common Instance class in ml.feature instead of a private one. Author: Nakul Jindal <[email protected]> Closes #9325 from nakul02/SPARK-11332_refactor_WeightedLeastSquares_dot_Instance.
1 parent 82c1c57 commit 5f1cee6

File tree

3 files changed

+15
-24
lines changed

3 files changed

+15
-24
lines changed

mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.ml.optim
1919

2020
import org.apache.spark.Logging
21+
import org.apache.spark.ml.feature.Instance
2122
import org.apache.spark.mllib.linalg._
2223
import org.apache.spark.rdd.RDD
2324

@@ -121,16 +122,6 @@ private[ml] class WeightedLeastSquares(
121122

122123
private[ml] object WeightedLeastSquares {
123124

124-
/**
125-
* Case class for weighted observations.
126-
* @param w weight, must be positive
127-
* @param a features
128-
* @param b label
129-
*/
130-
case class Instance(w: Double, a: Vector, b: Double) {
131-
require(w >= 0.0, s"Weight cannot be negative: $w.")
132-
}
133-
134125
/**
135126
* Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]].
136127
*/
@@ -168,20 +159,20 @@ private[ml] object WeightedLeastSquares {
168159
* Adds an instance.
169160
*/
170161
def add(instance: Instance): this.type = {
171-
val Instance(w, a, b) = instance
172-
val ak = a.size
162+
val Instance(l, w, f) = instance
163+
val ak = f.size
173164
if (!initialized) {
174165
init(ak)
175166
}
176167
assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.")
177168
count += 1L
178169
wSum += w
179170
wwSum += w * w
180-
bSum += w * b
181-
bbSum += w * b * b
182-
BLAS.axpy(w, a, aSum)
183-
BLAS.axpy(w * b, a, abSum)
184-
BLAS.spr(w, a, aaSum)
171+
bSum += w * l
172+
bbSum += w * l * l
173+
BLAS.axpy(w, f, aSum)
174+
BLAS.axpy(w * l, f, abSum)
175+
BLAS.spr(w, f, aaSum)
185176
this
186177
}
187178

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,10 @@ class LinearRegression(override val uid: String)
154154
"solver is used.'")
155155
// For low dimensional data, WeightedLeastSquares is more efficiently since the
156156
// training algorithm only requires one pass through the data. (SPARK-10668)
157-
val instances: RDD[WeightedLeastSquares.Instance] = dataset.select(
157+
val instances: RDD[Instance] = dataset.select(
158158
col($(labelCol)), w, col($(featuresCol))).map {
159159
case Row(label: Double, weight: Double, features: Vector) =>
160-
WeightedLeastSquares.Instance(weight, features, label)
160+
Instance(label, weight, features)
161161
}
162162

163163
val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),

mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.ml.optim
1919

2020
import org.apache.spark.SparkFunSuite
21-
import org.apache.spark.ml.optim.WeightedLeastSquares.Instance
21+
import org.apache.spark.ml.feature.Instance
2222
import org.apache.spark.mllib.linalg.Vectors
2323
import org.apache.spark.mllib.util.MLlibTestSparkContext
2424
import org.apache.spark.mllib.util.TestingUtils._
@@ -38,10 +38,10 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
3838
w <- c(1, 2, 3, 4)
3939
*/
4040
instances = sc.parallelize(Seq(
41-
Instance(1.0, Vectors.dense(0.0, 5.0).toSparse, 17.0),
42-
Instance(2.0, Vectors.dense(1.0, 7.0), 19.0),
43-
Instance(3.0, Vectors.dense(2.0, 11.0), 23.0),
44-
Instance(4.0, Vectors.dense(3.0, 13.0), 29.0)
41+
Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
42+
Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)),
43+
Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
44+
Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
4545
), 2)
4646
}
4747

0 commit comments

Comments
 (0)