Skip to content

Commit 0f28e5e

Browse files
committed
update logisticAggregatorSuite
1 parent 8515b20 commit 0f28e5e

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,17 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
238238
val aggConstantFeature = getNewAggregator(instancesConstantFeature,
239239
Vectors.dense(coefArray ++ interceptArray), fitIntercept = true, isMultinomial = true)
240240
instances.foreach(aggConstantFeature.add)
241+
241242
// constant features should not affect gradient
242-
assert(aggConstantFeature.gradient(0) === 0.0)
243+
def validateGradient(grad: Vector): Unit = {
244+
assert(grad(0) === 0.0)
245+
grad.toArray.foreach { gradientValue =>
246+
assert(!gradientValue.isNaN &&
247+
gradientValue > Double.NegativeInfinity && gradientValue < Double.PositiveInfinity)
248+
}
249+
}
250+
251+
validateGradient(aggConstantFeature.gradient)
243252

244253
val binaryCoefArray = Array(1.0, 2.0)
245254
val intercept = 1.0
@@ -248,6 +257,6 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
248257
isMultinomial = false)
249258
instances.foreach(aggConstantFeatureBinary.add)
250259
// constant features should not affect gradient
251-
assert(aggConstantFeatureBinary.gradient(0) === 0.0)
260+
validateGradient(aggConstantFeatureBinary.gradient)
252261
}
253262
}

0 commit comments

Comments
 (0)