@@ -262,7 +262,9 @@ class SortBasedAggregator(
// Firstly, update the aggregation buffer with input rows.
while (hasNextInput &&
groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) {
processRow(result.aggregationBuffer, inputIterator.getValue)
// Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be
// overwritten when `inputIterator` steps forward, we need to do a deep copy here.
processRow(result.aggregationBuffer, inputIterator.getValue.copy())
Contributor:
So the problem is that we somehow cache the input row during processRow?

Contributor:
I think it's caused by MutableProjection? MutableProjection may keep a "pointer" that points to a memory region of an unsafe row. Maybe we can fix this bug with #15082?

Contributor:
Never mind, #15082 needs a significant refactor; we should get this fix into 2.1 first.

hasNextInput = inputIterator.next()
}

@@ -271,7 +273,12 @@
// be called after calling processRow.
while (hasNextAggBuffer &&
groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) {
mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue)
mergeAggregationBuffers(
result.aggregationBuffer,
// Since `initialAggBufferIterator.getValue` is an `UnsafeRow` whose underlying buffer will be
// overwritten when `initialAggBufferIterator` steps forward, we need to do a deep copy here.
initialAggBufferIterator.getValue.copy()
)
hasNextAggBuffer = initialAggBufferIterator.next()
}
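
Not part of the PR itself, but a minimal sketch of the aliasing behavior discussed in the thread above, assuming Spark's catalyst `UnsafeProjection`/`InternalRow` APIs (the object name and the integer values are made up for illustration). A single projection reuses one `UnsafeRow` and its byte buffer across calls, much like `inputIterator`/`initialAggBufferIterator` overwrite their buffers when they step forward, so holding on to a returned row without `.copy()` means it silently changes underneath you:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{DataType, IntegerType}

object UnsafeRowReuseSketch {
  def main(args: Array[String]): Unit = {
    // One UnsafeProjection instance writes every result into the same reused
    // UnsafeRow / byte buffer.
    val project = UnsafeProjection.create(Array[DataType](IntegerType))
    val inputs = Seq(InternalRow(1), InternalRow(2), InternalRow(3))

    // Keeping the returned rows without copying: every reference aliases the
    // shared buffer, so they all show whatever was projected last.
    val aliased = inputs.map(row => project(row))

    // Deep-copying detaches each row from the shared buffer.
    val copied = inputs.map(row => project(row).copy())

    println(aliased.map(_.getInt(0))) // typically List(3, 3, 3): stale, overwritten data
    println(copied.map(_.getInt(0)))  // List(1, 2, 3)
  }
}

This is why the fix copies the row before handing it to `processRow` / `mergeAggregationBuffers`: those callbacks may keep a reference (for example through a `MutableProjection` bound to the row's memory) that outlives the current iterator position.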

@@ -205,23 +205,19 @@ class ObjectHashAggregateSuite
// A TypedImperativeAggregate function
val typed = percentile_approx($"c0", 0.5)

// A Hive UDAF without partial aggregation support
val withoutPartial = function("hive_max", $"c1")

// A Spark SQL native aggregate function with partial aggregation support that can be executed
// by the Tungsten `HashAggregateExec`
val withPartialUnsafe = max($"c2")
val withPartialUnsafe = max($"c1")

// A Spark SQL native aggregate function with partial aggregation support that can only be
// executed by the Tungsten `HashAggregateExec`
val withPartialSafe = max($"c3")
val withPartialSafe = max($"c2")

// A Spark SQL native distinct aggregate function
val withDistinct = countDistinct($"c4")
val withDistinct = countDistinct($"c3")

val allAggs = Seq(
"typed" -> typed,
"without partial" -> withoutPartial,
"with partial + unsafe" -> withPartialUnsafe,
"with partial + safe" -> withPartialSafe,
"with distinct" -> withDistinct
@@ -276,10 +272,9 @@ class ObjectHashAggregateSuite
// Generates a random schema for the randomized data generator
val schema = new StructType()
.add("c0", numericTypes(random.nextInt(numericTypes.length)), nullable = true)
.add("c1", orderedTypes(random.nextInt(orderedTypes.length)), nullable = true)
.add("c2", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true)
.add("c3", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
.add("c4", allTypes(random.nextInt(allTypes.length)), nullable = true)
.add("c1", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true)
.add("c2", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
.add("c3", allTypes(random.nextInt(allTypes.length)), nullable = true)

logInfo(
s"""Using the following random schema to generate all the randomized aggregation tests:
@@ -325,70 +320,67 @@ class ObjectHashAggregateSuite

// Currently Spark SQL doesn't support evaluating distinct aggregate function together
// with aggregate functions without partial aggregation support.
if (!(aggs.contains(withoutPartial) && aggs.contains(withDistinct))) {
// TODO Re-enables them after fixing SPARK-18403
ignore(
s"randomized aggregation test - " +
s"${names.mkString("[", ", ", "]")} - " +
s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
s"with ${if (emptyInput) "empty" else "non-empty"} input"
) {
var expected: Seq[Row] = null
var actual1: Seq[Row] = null
var actual2: Seq[Row] = null

// Disables `ObjectHashAggregateExec` to obtain a standard answer
withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
val aggDf = doAggregation(df)

if (aggs.intersect(Seq(withoutPartial, withPartialSafe, typed)).nonEmpty) {
assert(containsSortAggregateExec(aggDf))
assert(!containsObjectHashAggregateExec(aggDf))
assert(!containsHashAggregateExec(aggDf))
} else {
assert(!containsSortAggregateExec(aggDf))
assert(!containsObjectHashAggregateExec(aggDf))
assert(containsHashAggregateExec(aggDf))
}

expected = aggDf.collect().toSeq
test(
s"randomized aggregation test - " +
s"${names.mkString("[", ", ", "]")} - " +
s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
s"with ${if (emptyInput) "empty" else "non-empty"} input"
) {
var expected: Seq[Row] = null
var actual1: Seq[Row] = null
var actual2: Seq[Row] = null

// Disables `ObjectHashAggregateExec` to obtain a standard answer
withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
val aggDf = doAggregation(df)

if (aggs.intersect(Seq(withPartialSafe, typed)).nonEmpty) {
assert(containsSortAggregateExec(aggDf))
assert(!containsObjectHashAggregateExec(aggDf))
assert(!containsHashAggregateExec(aggDf))
} else {
assert(!containsSortAggregateExec(aggDf))
assert(!containsObjectHashAggregateExec(aggDf))
assert(containsHashAggregateExec(aggDf))
}

// Enables `ObjectHashAggregateExec`
withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
val aggDf = doAggregation(df)

if (aggs.contains(typed) && !aggs.contains(withoutPartial)) {
assert(!containsSortAggregateExec(aggDf))
assert(containsObjectHashAggregateExec(aggDf))
assert(!containsHashAggregateExec(aggDf))
} else if (aggs.intersect(Seq(withoutPartial, withPartialSafe)).nonEmpty) {
assert(containsSortAggregateExec(aggDf))
assert(!containsObjectHashAggregateExec(aggDf))
assert(!containsHashAggregateExec(aggDf))
} else {
assert(!containsSortAggregateExec(aggDf))
assert(!containsObjectHashAggregateExec(aggDf))
assert(containsHashAggregateExec(aggDf))
}

// Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
// big enough) to obtain a result to be checked.
withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
actual1 = aggDf.collect().toSeq
}

// Enables sort-based aggregation fallback to obtain another result to be checked.
withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
// Here we are not reusing `aggDf` because the physical plan in `aggDf` is
// cached and won't be re-planned using the new fallback threshold.
actual2 = doAggregation(df).collect().toSeq
}
expected = aggDf.collect().toSeq
}

// Enables `ObjectHashAggregateExec`
withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
val aggDf = doAggregation(df)

if (aggs.contains(typed)) {
assert(!containsSortAggregateExec(aggDf))
assert(containsObjectHashAggregateExec(aggDf))
assert(!containsHashAggregateExec(aggDf))
} else if (aggs.contains(withPartialSafe)) {
assert(containsSortAggregateExec(aggDf))
assert(!containsObjectHashAggregateExec(aggDf))
assert(!containsHashAggregateExec(aggDf))
} else {
assert(!containsSortAggregateExec(aggDf))
assert(!containsObjectHashAggregateExec(aggDf))
assert(containsHashAggregateExec(aggDf))
}

doubleSafeCheckRows(actual1, expected, 1e-4)
doubleSafeCheckRows(actual2, expected, 1e-4)
// Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
// big enough) to obtain a result to be checked.
withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
actual1 = aggDf.collect().toSeq
}

// Enables sort-based aggregation fallback to obtain another result to be checked.
withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
// Here we are not reusing `aggDf` because the physical plan in `aggDf` is
// cached and won't be re-planned using the new fallback threshold.
actual2 = doAggregation(df).collect().toSeq
}
}

doubleSafeCheckRows(actual1, expected, 1e-4)
doubleSafeCheckRows(actual2, expected, 1e-4)
Contributor (Author):
All the changes made above in this file resolve a logical conflict with PR #15703. After merging #15703 we no longer have any aggregate functions that don't support partial aggregation, so the tests must be updated to reflect that.

}
}
}
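
The assertions above rely on containsSortAggregateExec, containsObjectHashAggregateExec, and containsHashAggregateExec, which are defined elsewhere in this suite and not shown in the diff. A rough sketch of what such helpers are assumed to look like (as methods inside the suite), pattern-matching on the executed physical plan:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}

// Assumed shape of the plan-inspection helpers: collect matching physical
// aggregate operators from the executed plan and check for their presence.
def containsSortAggregateExec(df: DataFrame): Boolean =
  df.queryExecution.executedPlan.collect { case e: SortAggregateExec => e }.nonEmpty

def containsObjectHashAggregateExec(df: DataFrame): Boolean =
  df.queryExecution.executedPlan.collect { case e: ObjectHashAggregateExec => e }.nonEmpty

def containsHashAggregateExec(df: DataFrame): Boolean =
  df.queryExecution.executedPlan.collect { case e: HashAggregateExec => e }.nonEmpty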
@@ -425,7 +417,35 @@ class ObjectHashAggregateSuite
}
}

private def function(name: String, args: Column*): Column = {
Column(UnresolvedFunction(FunctionIdentifier(name), args.map(_.expr), isDistinct = false))
test("SPARK-18403 Fix unsafe data false sharing issue in ObjectHashAggregateExec") {
// SPARK-18403: An unsafe data false sharing issue may trigger OOM / SIGSEGV when evaluating
// certain aggregate functions. To reproduce this issue, the following conditions must be
// met:
//
// 1. The aggregation must be evaluated using `ObjectHashAggregateExec`;
// 2. There must be an input column whose data type involves `ArrayType` or `MapType`;
// 3. Sort-based aggregation fallback must be triggered during evaluation.
withSQLConf(
SQLConf.USE_OBJECT_HASH_AGG.key -> "true",
SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1"
Contributor:
Not related to this PR, but the config name looks weird; how about OBJECT_AGG_FALLBACK_TO_SORT_THRESHOLD?

) {
checkAnswer(
Seq
.fill(2)(Tuple1(Array.empty[Int]))
.toDF("c0")
.groupBy(lit(1))
.agg(typed_count($"c0"), max($"c0")),
Row(1, 2, Array.empty[Int])
)

checkAnswer(
Seq
.fill(2)(Tuple1(Map.empty[Int, Int]))
.toDF("c0")
.groupBy(lit(1))
.agg(typed_count($"c0"), first($"c0")),
Row(1, 2, Map.empty[Int, Int])
)
}
}
}