@@ -28,6 +28,7 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
2828
2929 @ transient var instances : Array [Instance ] = _
3030 @ transient var instancesConstantFeature : Array [Instance ] = _ // fixture used further down to build aggConstantFeature
31+ @ transient var instancesConstantFeatureFiltered : Array [Instance ] = _ // assigned below with three single-feature instances; presumably instancesConstantFeature with its constant column dropped (confirm against the full beforeAll)
3132
3233 override def beforeAll (): Unit = {
3334 super .beforeAll()
@@ -41,6 +42,11 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
4142 Instance (1.0 , 0.5 , Vectors .dense(1.0 , 1.0 )),
4243 Instance (2.0 , 0.3 , Vectors .dense(1.0 , 0.5 ))
4344 )
45+ instancesConstantFeatureFiltered = Array (
46+ Instance (0.0 , 0.1 , Vectors .dense(2.0 )),
47+ Instance (1.0 , 0.5 , Vectors .dense(1.0 )),
48+ Instance (2.0 , 0.3 , Vectors .dense(0.5 ))
49+ )
4450 }
4551
4652 /** Get summary statistics for some data and create a new LogisticAggregator. */
@@ -233,30 +239,44 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
233239 val binaryInstances = instancesConstantFeature.map { instance =>
234240 if (instance.label <= 1.0 ) instance else Instance (0.0 , instance.weight, instance.features)
235241 }
242+ val binaryInstancesFiltered = instancesConstantFeatureFiltered.map { instance =>
243+ if (instance.label <= 1.0 ) instance else Instance (0.0 , instance.weight, instance.features)
244+ }
236245 val coefArray = Array (1.0 , 2.0 , - 2.0 , 3.0 , 0.0 , - 1.0 )
246+ val coefArrayFiltered = Array (3.0 , 0.0 , - 1.0 )
237247 val interceptArray = Array (4.0 , 2.0 , - 3.0 )
238248 val aggConstantFeature = getNewAggregator(instancesConstantFeature,
239249 Vectors .dense(coefArray ++ interceptArray), fitIntercept = true , isMultinomial = true )
240- instances.foreach(aggConstantFeature.add)
250+ val aggConstantFeatureFiltered = getNewAggregator(instancesConstantFeatureFiltered,
251+ Vectors .dense(coefArrayFiltered ++ interceptArray), fitIntercept = true , isMultinomial = true )
252+
253+ instancesConstantFeature.foreach(aggConstantFeature.add)
254+ instancesConstantFeatureFiltered.foreach(aggConstantFeatureFiltered.add)
241255
242256 // constant features should not affect gradient
243- def validateGradient (grad : Vector ): Unit = {
244- assert(grad(0 ) === 0.0 )
245- grad.toArray.foreach { gradientValue =>
246- assert(! gradientValue.isNaN &&
247- gradientValue > Double .NegativeInfinity && gradientValue < Double .PositiveInfinity )
257+ def validateGradient (grad : Vector , gradFiltered : Vector , numCoefficientSets : Int ): Unit = { // compares `grad` (constant feature kept) with `gradFiltered` (aggregator built from the Filtered instances)
258+ for (i <- 0 until numCoefficientSets) { // one index per coefficient set: called with 3 for the multinomial case and 1 for the binary case
259+ assert(grad(i) === 0.0) // the first numCoefficientSets gradient entries must be exactly zero
260+ assert(grad(numCoefficientSets + i) == gradFiltered(i)) // the next numCoefficientSets entries must equal the filtered gradient. NOTE(review): exact double equality with bare `==`; entries from index 2 * numCoefficientSets onward (if any) are not checked
248261 }
249262 }
250263
251- validateGradient(aggConstantFeature.gradient)
264+ validateGradient(aggConstantFeature.gradient, aggConstantFeatureFiltered.gradient, 3 )
252265
253266 val binaryCoefArray = Array (1.0 , 2.0 )
267+ val binaryCoefArrayFiltered = Array (2.0 )
254268 val intercept = 1.0
255269 val aggConstantFeatureBinary = getNewAggregator(binaryInstances,
256270 Vectors .dense(binaryCoefArray ++ Array (intercept)), fitIntercept = true ,
257271 isMultinomial = false )
258- instances.foreach(aggConstantFeatureBinary.add)
272+ val aggConstantFeatureBinaryFiltered = getNewAggregator(binaryInstancesFiltered,
273+ Vectors .dense(binaryCoefArrayFiltered ++ Array (intercept)), fitIntercept = true ,
274+ isMultinomial = false )
275+ binaryInstances.foreach(aggConstantFeatureBinary.add)
276+ binaryInstancesFiltered.foreach(aggConstantFeatureBinaryFiltered.add)
277+
259278 // constant features should not affect gradient
260- validateGradient(aggConstantFeatureBinary.gradient)
279+ validateGradient(aggConstantFeatureBinary.gradient,
280+ aggConstantFeatureBinaryFiltered.gradient, 1 )
261281 }
262282}
0 commit comments