Skip to content

Commit 1f4ba14

Browse files
committed
update testcase
1 parent 0f28e5e commit 1f4ba14

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,9 +1457,9 @@ class LogisticRegressionSuite
14571457
*/
14581458

14591459
val coefficientsR = new DenseMatrix(3, 2, Array(
1460-
0.1881871, -0.0,
1460+
0.1881871, 0.0,
14611461
-0.02412645, 0.0,
1462-
-0.1640607, -0.0), isTransposed = true)
1462+
-0.1640607, 0.0), isTransposed = true)
14631463
val interceptsR = Vectors.dense(0.2658824, 0.53604701, -0.8019294)
14641464

14651465
model.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps))

mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
2828

2929
@transient var instances: Array[Instance] = _
3030
@transient var instancesConstantFeature: Array[Instance] = _
31+
@transient var instancesConstantFeatureFiltered: Array[Instance] = _
3132

3233
override def beforeAll(): Unit = {
3334
super.beforeAll()
@@ -41,6 +42,11 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
4142
Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)),
4243
Instance(2.0, 0.3, Vectors.dense(1.0, 0.5))
4344
)
45+
instancesConstantFeatureFiltered = Array(
46+
Instance(0.0, 0.1, Vectors.dense(2.0)),
47+
Instance(1.0, 0.5, Vectors.dense(1.0)),
48+
Instance(2.0, 0.3, Vectors.dense(0.5))
49+
)
4450
}
4551

4652
/** Get summary statistics for some data and create a new LogisticAggregator. */
@@ -233,30 +239,44 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
233239
val binaryInstances = instancesConstantFeature.map { instance =>
234240
if (instance.label <= 1.0) instance else Instance(0.0, instance.weight, instance.features)
235241
}
242+
val binaryInstancesFiltered = instancesConstantFeatureFiltered.map { instance =>
243+
if (instance.label <= 1.0) instance else Instance(0.0, instance.weight, instance.features)
244+
}
236245
val coefArray = Array(1.0, 2.0, -2.0, 3.0, 0.0, -1.0)
246+
val coefArrayFiltered = Array(3.0, 0.0, -1.0)
237247
val interceptArray = Array(4.0, 2.0, -3.0)
238248
val aggConstantFeature = getNewAggregator(instancesConstantFeature,
239249
Vectors.dense(coefArray ++ interceptArray), fitIntercept = true, isMultinomial = true)
240-
instances.foreach(aggConstantFeature.add)
250+
val aggConstantFeatureFiltered = getNewAggregator(instancesConstantFeatureFiltered,
251+
Vectors.dense(coefArrayFiltered ++ interceptArray), fitIntercept = true, isMultinomial = true)
252+
253+
instancesConstantFeature.foreach(aggConstantFeature.add)
254+
instancesConstantFeatureFiltered.foreach(aggConstantFeatureFiltered.add)
241255

242256
// constant features should not affect gradient
243-
def validateGradient(grad: Vector): Unit = {
244-
assert(grad(0) === 0.0)
245-
grad.toArray.foreach { gradientValue =>
246-
assert(!gradientValue.isNaN &&
247-
gradientValue > Double.NegativeInfinity && gradientValue < Double.PositiveInfinity)
257+
def validateGradient(grad: Vector, gradFiltered: Vector, numCoefficientSets: Int): Unit = {
258+
for (i <- 0 until numCoefficientSets) {
259+
assert(grad(i) === 0.0)
260+
assert(grad(numCoefficientSets + i) == gradFiltered(i))
248261
}
249262
}
250263

251-
validateGradient(aggConstantFeature.gradient)
264+
validateGradient(aggConstantFeature.gradient, aggConstantFeatureFiltered.gradient, 3)
252265

253266
val binaryCoefArray = Array(1.0, 2.0)
267+
val binaryCoefArrayFiltered = Array(2.0)
254268
val intercept = 1.0
255269
val aggConstantFeatureBinary = getNewAggregator(binaryInstances,
256270
Vectors.dense(binaryCoefArray ++ Array(intercept)), fitIntercept = true,
257271
isMultinomial = false)
258-
instances.foreach(aggConstantFeatureBinary.add)
272+
val aggConstantFeatureBinaryFiltered = getNewAggregator(binaryInstancesFiltered,
273+
Vectors.dense(binaryCoefArrayFiltered ++ Array(intercept)), fitIntercept = true,
274+
isMultinomial = false)
275+
binaryInstances.foreach(aggConstantFeatureBinary.add)
276+
binaryInstancesFiltered.foreach(aggConstantFeatureBinaryFiltered.add)
277+
259278
// constant features should not affect gradient
260-
validateGradient(aggConstantFeatureBinary.gradient)
279+
validateGradient(aggConstantFeatureBinary.gradient,
280+
aggConstantFeatureBinaryFiltered.gradient, 1)
261281
}
262282
}

0 commit comments

Comments
 (0)