Closed
@@ -151,11 +151,12 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
}

// Setup unique distinct aggregate children.
- val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
- val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
- val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
+ val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
+ val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
Contributor Author

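By calling toMap, the old code potentially broke the alignment between distinctAggChildren and distinctAggChildAttrs: a Map's values are not guaranteed to iterate in the order of the sequence the Map was built from.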
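
The alignment concern can be illustrated outside Spark. The following is a minimal sketch, assuming a hypothetical `pair` helper as a stand-in for `expressionAttributePair` and plain strings in place of expressions and attributes. It shows the shape of the fix: keep the pairs in a `Seq` so the attribute order follows the child order, and build a `Map` only for lookups.

```scala
object AlignmentSketch {
  // Hypothetical stand-in for expressionAttributePair: child -> derived attribute name.
  def pair(child: String): (String, String) = child -> s"attr_$child"

  val children: Seq[String] = Seq("e", "d", "c", "b", "a", "f")

  // Old shape: round-trip through a Map. The order of `values` depends on the
  // Map implementation and is not guaranteed to follow `children`.
  val viaMap: Seq[String] = children.map(pair).toMap.values.toSeq

  // New shape: keep the pairs in a Seq, so position i of `aligned`
  // always belongs to position i of `children`.
  val pairs: Seq[(String, String)] = children.distinct.map(pair)
  val aligned: Seq[String] = pairs.map(_._2)

  // A Map is still convenient for lookups by child; order does not matter here.
  val lookup: Map[String, String] = pairs.toMap
}
```

Both variants contain the same attributes; only the `Seq`-based one promises that they line up with `children` position by position.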

+ val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)

// Setup expand & aggregate operators for distinct aggregate expressions.
+ val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
case ((group, expressions), i) =>
val id = Literal(i + 1)
@@ -170,7 +171,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
val operators = expressions.map { e =>
val af = e.aggregateFunction
val naf = patchAggregateFunctionChildren(af) { x =>
- evalWithinGroup(id, distinctAggChildAttrMap(x))
+ evalWithinGroup(id, distinctAggChildAttrLookup(x))
}
(e, e.copy(aggregateFunction = naf, isDistinct = false))
}
@@ -66,6 +66,36 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
}
}

+class LongProductSum extends UserDefinedAggregateFunction {
+  def inputSchema: StructType = new StructType()
+    .add("a", LongType)
+    .add("b", LongType)
+
+  def bufferSchema: StructType = new StructType()
+    .add("product", LongType)
+
+  def dataType: DataType = LongType
+
+  def deterministic: Boolean = true
+
+  def initialize(buffer: MutableAggregationBuffer): Unit = {
+    buffer(0) = 0L
+  }
+
+  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+    if (!(input.isNullAt(0) || input.isNullAt(1))) {
+      buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1)
+    }
+  }
+
+  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
+  }
+
+  def evaluate(buffer: Row): Any =
+    buffer.getLong(0)
+}
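
The update/merge logic of `LongProductSum` can be checked without a Spark session. This is a rough Spark-free model, assuming `Option[Long]` in place of nullable columns and a plain `Long` in place of the aggregation buffer; `ProductSumModel` and its members are illustrative names, not part of the PR.

```scala
object ProductSumModel {
  // One input row: two nullable longs, modelled as Options.
  type Input = (Option[Long], Option[Long])

  // Mirrors initialize: the buffer starts at zero.
  val initial: Long = 0L

  // Mirrors update: rows with a null in either column leave the buffer unchanged.
  def update(buffer: Long, input: Input): Long = input match {
    case (Some(a), Some(b)) => buffer + a * b
    case _                  => buffer
  }

  // Mirrors merge: partial sums from two buffers are added.
  def merge(buffer1: Long, buffer2: Long): Long = buffer1 + buffer2

  // Aggregate each partition separately, then merge the partial buffers.
  def aggregate(partitions: Seq[Seq[Input]]): Long =
    partitions.map(_.foldLeft(initial)(update)).foldLeft(initial)(merge)
}
```

In this model, splitting the same rows across partitions does not change the result, which is the property `merge` has to preserve.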

abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._

@@ -110,6 +140,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
// Register UDAFs
sqlContext.udf.register("mydoublesum", new MyDoubleSum)
sqlContext.udf.register("mydoubleavg", new MyDoubleAvg)
+ sqlContext.udf.register("longProductSum", new LongProductSum)
}

override def afterAll(): Unit = {
@@ -545,19 +576,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
| count(distinct value2),
| sum(distinct value2),
| count(distinct value1, value2),
+ | longProductSum(distinct value1, value2),
Contributor

Actually, what are the semantics of this? If we have [value1=null, value2=null], [value1=null, value2=1], [value1=1, value2=null], and [value1=1, value2=1], we have 4 distinct input rows, right?

Contributor

To be more specific, my question is not about the semantics of longProductSum, since an aggregate function can define its own semantics for handling nulls. My question is about our rewriter. Since we use an aggregate to do the distinct, it seems that for the case above we will get 4 distinct input rows, which makes sense.
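
The four-rows case can be modelled in plain Scala. This is a small sketch, assuming `None` stands in for SQL NULL and that de-duplicating whole tuples approximates grouping on (value1, value2), where NULLs compare as equal for grouping purposes; `DistinctNullRows` is an illustrative name.

```scala
object DistinctNullRows {
  // None models SQL NULL; each tuple is one (value1, value2) input row.
  val input: Seq[(Option[Int], Option[Int])] = Seq(
    (None, None),
    (None, Some(1)),
    (Some(1), None),
    (Some(1), Some(1)),
    (None, Some(1)) // a repeat, to show that duplicates collapse
  )

  // Grouping on the full tuple keeps one row per distinct combination,
  // including the combinations that contain nulls.
  val distinctRows: Seq[(Option[Int], Option[Int])] = input.distinct
}
```

Under these assumptions `distinctRows` holds the four combinations from the comment above; whether a null-containing row then contributes anything is up to the aggregate function.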

Contributor Author

Yes and no.

The input we care about consists only of these tuples: [value1=null, value2=null], [value1=null, value2=1], [value1=1, value2=null], and [value1=1, value2=1]

However, in the current implementation a distinct aggregate sees more input than that: it also sees records from other groups, with the values in those records nulled out. The assumption here is that an AggregateFunction is not changed by an all-NULL update. The only problematic case I can think of is FIRST(DISTINCT ...), which shouldn't be used like that anyway.

We could solve this by wrapping AggregateFunctions with an operator that only updates if the group id is correct.
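
The assumption and the proposed guard can be sketched without Spark. The model below is illustrative only: `Expanded`, `productSum`, `first`, and `guarded` are hypothetical names, `None` stands in for a nulled-out value, and filtering by group id approximates "only update if the group id is correct".

```scala
object NulledUpdateSketch {
  // One expanded row: a group id plus two inputs, nulled out (None) for other groups.
  final case class Expanded(gid: Int, value1: Option[Long], value2: Option[Long])

  // Null-skipping aggregate in the spirit of longProductSum: an all-null row is a no-op.
  def productSum(rows: Seq[Expanded]): Long =
    rows.foldLeft(0L) { (acc, r) =>
      (r.value1, r.value2) match {
        case (Some(a), Some(b)) => acc + a * b
        case _                  => acc
      }
    }

  // FIRST-like aggregate: takes value1 of whatever row arrives first, null or not.
  def first(rows: Seq[Expanded]): Option[Long] =
    rows.headOption.flatMap(_.value1)

  // Hypothetical guard: the wrapped aggregate only sees rows of its own group id.
  def guarded[A](gid: Int)(agg: Seq[Expanded] => A): Seq[Expanded] => A =
    rows => agg(rows.filter(_.gid == gid))

  val rows: Seq[Expanded] = Seq(
    Expanded(2, None, None),         // record from another group, nulled out
    Expanded(1, Some(3L), Some(4L)),
    Expanded(1, Some(5L), Some(6L))
  )
}
```

In this model `productSum(rows)` is 42 with or without the guard, while `first(rows)` is `None` unguarded and `Some(3)` under `guarded(1)`, which is the FIRST(DISTINCT ...) problem in miniature.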

Contributor

Yeah, right. I was thinking only about the inputs of this particular distinct agg function (records with other gid values are ignored).

| count(value1),
| sum(value1),
| count(value2),
| sum(value2),
+ | longProductSum(value1, value2),
| count(*),
| count(1)
|FROM agg2
|GROUP BY key
""".stripMargin),
- Row(null, 3, 30, 3, 60, 3, 3, 30, 3, 60, 4, 4) ::
- Row(1, 2, 40, 3, -10, 3, 3, 70, 3, -10, 3, 3) ::
- Row(2, 2, 0, 1, 1, 1, 3, 1, 3, 3, 4, 4) ::
- Row(3, 0, null, 1, 3, 0, 0, null, 1, 3, 2, 2) :: Nil)
+ Row(null, 3, 30, 3, 60, 3, -4700, 3, 30, 3, 60, -4700, 4, 4) ::
+ Row(1, 2, 40, 3, -10, 3, -100, 3, 70, 3, -10, -100, 3, 3) ::
+ Row(2, 2, 0, 1, 1, 1, 1, 3, 1, 3, 3, 2, 4, 4) ::
+ Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
}

test("test count") {