@@ -205,23 +205,19 @@ class ObjectHashAggregateSuite
     // A TypedImperativeAggregate function
     val typed = percentile_approx($"c0", 0.5)
 
-    // A Hive UDAF without partial aggregation support
-    val withoutPartial = function("hive_max", $"c1")
-
     // A Spark SQL native aggregate function with partial aggregation support that can be executed
     // by the Tungsten `HashAggregateExec`
-    val withPartialUnsafe = max($"c2")
+    val withPartialUnsafe = max($"c1")
 
     // A Spark SQL native aggregate function with partial aggregation support that can only be
     // executed by the `SortAggregateExec`
-    val withPartialSafe = max($"c3")
+    val withPartialSafe = max($"c2")
 
     // A Spark SQL native distinct aggregate function
-    val withDistinct = countDistinct($"c4")
+    val withDistinct = countDistinct($"c3")
 
     val allAggs = Seq(
       "typed" -> typed,
-      "without partial" -> withoutPartial,
       "with partial + unsafe" -> withPartialUnsafe,
       "with partial + safe" -> withPartialSafe,
       "with distinct" -> withDistinct
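
With the Hive UDAF (`hive_max`) gone, every remaining aggregate supports partial aggregation, and the suite covers three execution flavors plus a distinct aggregate. A minimal sketch of the kind of query the generated tests run (hypothetical: `df` and the `id` grouping column stand in for the randomized input built below, and `percentile_approx` is the suite's own private helper, not a method of `org.apache.spark.sql.functions` in this Spark line):

    // One aggregation mixing all four flavors, keyed by a synthetic group column.
    val aggDf = df
      .groupBy($"id" % 4 as "group")
      .agg(
        percentile_approx($"c0", 0.5), // TypedImperativeAggregate -> ObjectHashAggregateExec
        max($"c1"),                    // fixed-length input, HashAggregateExec-friendly
        max($"c2"),                    // var-length ordered input, needs SortAggregateExec
        countDistinct($"c3"))          // distinct aggregate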
@@ -276,10 +272,9 @@ class ObjectHashAggregateSuite
       // Generates a random schema for the randomized data generator
       val schema = new StructType()
         .add("c0", numericTypes(random.nextInt(numericTypes.length)), nullable = true)
-        .add("c1", orderedTypes(random.nextInt(orderedTypes.length)), nullable = true)
-        .add("c2", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true)
-        .add("c3", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
-        .add("c4", allTypes(random.nextInt(allTypes.length)), nullable = true)
+        .add("c1", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true)
+        .add("c2", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
+        .add("c3", allTypes(random.nextInt(allTypes.length)), nullable = true)
 
       logInfo(
         s"""Using the following random schema to generate all the randomized aggregation tests:
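
The `orderedTypes` bucket existed only to feed the removed Hive UDAF, so the columns renumber to map one-to-one onto the aggregates above: `c0` -> typed, `c1` -> partial + unsafe, `c2` -> partial + safe, `c3` -> distinct. For orientation, a sketch of what the buckets could plausibly contain — the real definitions live elsewhere in the suite, and these element lists are assumptions:

    import org.apache.spark.sql.types._

    // Assumed bucket contents, for illustration only.
    val numericTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
    // Fixed-length types fit HashAggregateExec's mutable UnsafeRow aggregation buffers.
    val fixedLengthTypes = numericTypes ++ Seq(BooleanType, DateType, TimestampType)
    // Ordered but variable-length: max() is defined, yet the buffer can't be mutated in place.
    val varLenOrderedTypes = Seq(StringType, BinaryType)
    val allTypes = fixedLengthTypes ++ varLenOrderedTypes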
@@ -325,69 +320,67 @@ class ObjectHashAggregateSuite
 
         // Currently Spark SQL doesn't support evaluating distinct aggregate function together
         // with aggregate functions without partial aggregation support.
-        if (!(aggs.contains(withoutPartial) && aggs.contains(withDistinct))) {
-          test(
-            s"randomized aggregation test - " +
-              s"${names.mkString("[", ", ", "]")} - " +
-              s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
-              s"with ${if (emptyInput) "empty" else "non-empty"} input"
-          ) {
-            var expected: Seq[Row] = null
-            var actual1: Seq[Row] = null
-            var actual2: Seq[Row] = null
-
-            // Disables `ObjectHashAggregateExec` to obtain a standard answer
-            withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
-              val aggDf = doAggregation(df)
-
-              if (aggs.intersect(Seq(withoutPartial, withPartialSafe, typed)).nonEmpty) {
-                assert(containsSortAggregateExec(aggDf))
-                assert(!containsObjectHashAggregateExec(aggDf))
-                assert(!containsHashAggregateExec(aggDf))
-              } else {
-                assert(!containsSortAggregateExec(aggDf))
-                assert(!containsObjectHashAggregateExec(aggDf))
-                assert(containsHashAggregateExec(aggDf))
-              }
-
-              expected = aggDf.collect().toSeq
+        test(
+          s"randomized aggregation test - " +
+            s"${names.mkString("[", ", ", "]")} - " +
+            s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
+            s"with ${if (emptyInput) "empty" else "non-empty"} input"
+        ) {
+          var expected: Seq[Row] = null
+          var actual1: Seq[Row] = null
+          var actual2: Seq[Row] = null
+
+          // Disables `ObjectHashAggregateExec` to obtain a standard answer
+          withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
+            val aggDf = doAggregation(df)
+
+            if (aggs.intersect(Seq(withPartialSafe, typed)).nonEmpty) {
+              assert(containsSortAggregateExec(aggDf))
+              assert(!containsObjectHashAggregateExec(aggDf))
+              assert(!containsHashAggregateExec(aggDf))
+            } else {
+              assert(!containsSortAggregateExec(aggDf))
+              assert(!containsObjectHashAggregateExec(aggDf))
+              assert(containsHashAggregateExec(aggDf))
+            }
+
+            expected = aggDf.collect().toSeq
+          }
+
+          // Enables `ObjectHashAggregateExec`
+          withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+            val aggDf = doAggregation(df)
+
+            if (aggs.contains(typed)) {
+              assert(!containsSortAggregateExec(aggDf))
+              assert(containsObjectHashAggregateExec(aggDf))
+              assert(!containsHashAggregateExec(aggDf))
+            } else if (aggs.contains(withPartialSafe)) {
+              assert(containsSortAggregateExec(aggDf))
+              assert(!containsObjectHashAggregateExec(aggDf))
+              assert(!containsHashAggregateExec(aggDf))
+            } else {
+              assert(!containsSortAggregateExec(aggDf))
+              assert(!containsObjectHashAggregateExec(aggDf))
+              assert(containsHashAggregateExec(aggDf))
             }
 
-            // Enables `ObjectHashAggregateExec`
-            withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
-              val aggDf = doAggregation(df)
-
-              if (aggs.contains(typed) && !aggs.contains(withoutPartial)) {
-                assert(!containsSortAggregateExec(aggDf))
-                assert(containsObjectHashAggregateExec(aggDf))
-                assert(!containsHashAggregateExec(aggDf))
-              } else if (aggs.intersect(Seq(withoutPartial, withPartialSafe)).nonEmpty) {
-                assert(containsSortAggregateExec(aggDf))
-                assert(!containsObjectHashAggregateExec(aggDf))
-                assert(!containsHashAggregateExec(aggDf))
-              } else {
-                assert(!containsSortAggregateExec(aggDf))
-                assert(!containsObjectHashAggregateExec(aggDf))
-                assert(containsHashAggregateExec(aggDf))
-              }
-
-              // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
-              // big enough) to obtain a result to be checked.
-              withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
-                actual1 = aggDf.collect().toSeq
-              }
-
-              // Enables sort-based aggregation fallback to obtain another result to be checked.
-              withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
-                // Here we are not reusing `aggDf` because the physical plan in `aggDf` is
-                // cached and won't be re-planned using the new fallback threshold.
-                actual2 = doAggregation(df).collect().toSeq
-              }
+            // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
+            // big enough) to obtain a result to be checked.
+            withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
+              actual1 = aggDf.collect().toSeq
             }
 
-            doubleSafeCheckRows(actual1, expected, 1e-4)
-            doubleSafeCheckRows(actual2, expected, 1e-4)
+            // Enables sort-based aggregation fallback to obtain another result to be checked.
+            withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
+              // Here we are not reusing `aggDf` because the physical plan in `aggDf` is
+              // cached and won't be re-planned using the new fallback threshold.
+              actual2 = doAggregation(df).collect().toSeq
+            }
           }
+
+          doubleSafeCheckRows(actual1, expected, 1e-4)
+          doubleSafeCheckRows(actual2, expected, 1e-4)
         }
       }
     }
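
The asserts above inspect the physical plan through the suite's `contains*AggregateExec` helpers, which are defined outside this hunk; they plausibly reduce to a pattern like the following (an assumption about their shape, shown for one of the three):

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec

    // Walk the executed physical plan and look for an ObjectHashAggregateExec node.
    private def containsObjectHashAggregateExec(df: DataFrame): Boolean =
      df.queryExecution.executedPlan.collect {
        case agg: ObjectHashAggregateExec => agg
      }.nonEmpty

The fallback threshold toggled at the end bounds how many distinct grouping keys `ObjectHashAggregateExec` buffers in its in-memory hash map before it degrades to sort-based aggregation: with only 50 generated rows, a threshold of 100 keeps the hash path throughout, while 3 forces the fallback, and both runs must agree with the baseline within the 1e-4 tolerance.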
@@ -424,10 +417,6 @@ class ObjectHashAggregateSuite
     }
   }
 
-  private def function(name: String, args: Column*): Column = {
-    Column(UnresolvedFunction(FunctionIdentifier(name), args.map(_.expr), isDistinct = false))
-  }
-
   test("SPARK-18403 Fix unsafe data false sharing issue in ObjectHashAggregateExec") {
     // SPARK-18403: An unsafe data false sharing issue may trigger OOM / SIGSEGV when evaluating
     // certain aggregate functions. To reproduce this issue, the following conditions must be