File tree Expand file tree Collapse file tree 1 file changed +11
-2
lines changed
mllib/src/test/scala/org/apache/spark/ml/optim/aggregator Expand file tree Collapse file tree 1 file changed +11
-2
lines changed Original file line number Diff line number Diff line change @@ -238,8 +238,17 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
238238 val aggConstantFeature = getNewAggregator(instancesConstantFeature,
239239 Vectors .dense(coefArray ++ interceptArray), fitIntercept = true , isMultinomial = true )
240240 instances.foreach(aggConstantFeature.add)
241+
241242 // constant features should not affect gradient
242- assert(aggConstantFeature.gradient(0 ) === 0.0 )
243+ def validateGradient (grad : Vector ): Unit = {
244+ assert(grad(0 ) === 0.0 )
245+ grad.toArray.foreach { gradientValue =>
246+ assert(! gradientValue.isNaN &&
247+ gradientValue > Double .NegativeInfinity && gradientValue < Double .PositiveInfinity )
248+ }
249+ }
250+
251+ validateGradient(aggConstantFeature.gradient)
243252
244253 val binaryCoefArray = Array (1.0 , 2.0 )
245254 val intercept = 1.0
@@ -248,6 +257,6 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
248257 isMultinomial = false )
249258 instances.foreach(aggConstantFeatureBinary.add)
250259 // constant features should not affect gradient
251- assert (aggConstantFeatureBinary.gradient( 0 ) === 0.0 )
260+ validateGradient (aggConstantFeatureBinary.gradient)
252261 }
253262}
You can’t perform that action at this time.
0 commit comments