Skip to content

Commit a19fea6

Browse files
committed
Add UDAF interface.
1 parent 262d4c4 commit a19fea6

File tree

6 files changed

+415
-2
lines changed

6 files changed

+415
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate2
2020
import org.apache.spark.sql.catalyst.dsl.expressions._
2121
import org.apache.spark.sql.catalyst.errors.TreeNodeException
2222
import org.apache.spark.sql.catalyst.expressions._
23-
import org.apache.spark.sql.catalyst.InternalRow
23+
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
2424
import org.apache.spark.sql.types._
25+
import org.apache.spark.sql.Row
2526

2627
/** The mode of an [[AggregateFunction]]. */
2728
private[sql] sealed trait AggregateMode
@@ -277,3 +278,78 @@ case class Average(child: Expression) extends AlgebraicAggregate {
277278
override def dataType: DataType = resultType
278279
override def children: Seq[Expression] = child :: Nil
279280
}
281+
282+
/**
 * A [[Row]]-shaped view over a slice of an underlying aggregation row.
 *
 * @param toCatalystConverters per-field converters from Scala values to Catalyst values.
 * @param toScalaConverters    per-field converters from Catalyst values to Scala values.
 * @param bufferOffset         position of this buffer's first field in the underlying row.
 */
abstract class AggregationBuffer(
    toCatalystConverters: Array[Any => Any],
    toScalaConverters: Array[Any => Any],
    bufferOffset: Int)
  extends Row {

  // One slot per buffer field; the converter arrays are parallel, so either length works.
  override def length: Int = toCatalystConverters.length

  // Absolute positions of this buffer's fields inside the underlying row:
  // field i lives at bufferOffset + i.
  protected val offsets: Array[Int] = Array.tabulate(length)(i => bufferOffset + i)
}
300+
301+
/**
 * A mutable [[AggregationBuffer]] backed by a [[MutableRow]]. Reads convert the stored
 * Catalyst values to Scala values; writes convert Scala values back to Catalyst values.
 *
 * @param underlyingBuffer the mutable row holding the actual buffer values; a var so the
 *                         same wrapper can be re-pointed at different rows between calls.
 */
class MutableAggregationBuffer(
    toCatalystConverters: Array[Any => Any],
    toScalaConverters: Array[Any => Any],
    bufferOffset: Int,
    var underlyingBuffer: MutableRow)
  extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {

  /** Returns the i-th buffer value as a Scala value. */
  override def apply(i: Int): Any = {
    if (i < 0 || i >= length) {
      throw new IllegalArgumentException(
        s"Could not access ${i}th value in this buffer because it only has $length values.")
    }
    val catalystValue = underlyingBuffer(offsets(i))
    toScalaConverters(i)(catalystValue)
  }

  /** Stores a Scala value into the i-th buffer slot, converting it to Catalyst form. */
  def update(i: Int, value: Any): Unit = {
    if (i < 0 || i >= length) {
      throw new IllegalArgumentException(
        s"Could not update ${i}th value in this buffer because it only has $length values.")
    }
    val catalystValue = toCatalystConverters(i)(value)
    underlyingBuffer.update(offsets(i), catalystValue)
  }

  /** Returns a new wrapper sharing the same converters, offset, and underlying row. */
  override def copy(): MutableAggregationBuffer =
    new MutableAggregationBuffer(
      toCatalystConverters,
      toScalaConverters,
      bufferOffset,
      underlyingBuffer)
}
332+
333+
/**
 * A read-only [[AggregationBuffer]] backed by an input [[Row]]. Reads convert the stored
 * Catalyst values to Scala values before handing them to user code.
 *
 * @param underlyingInputBuffer the row holding the actual buffer values; a var so the
 *                              same wrapper can be re-pointed at different rows between calls.
 */
class InputAggregationBuffer(
    toCatalystConverters: Array[Any => Any],
    toScalaConverters: Array[Any => Any],
    bufferOffset: Int,
    var underlyingInputBuffer: Row)
  extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {

  /** Returns the i-th buffer value as a Scala value. */
  override def apply(i: Int): Any = {
    if (i < 0 || i >= length) {
      throw new IllegalArgumentException(
        s"Could not access ${i}th value in this buffer because it only has $length values.")
    }
    val catalystValue = underlyingInputBuffer(offsets(i))
    toScalaConverters(i)(catalystValue)
  }

  /** Returns a new wrapper sharing the same converters, offset, and underlying row. */
  override def copy(): InputAggregationBuffer =
    new InputAggregationBuffer(
      toCatalystConverters,
      toScalaConverters,
      bufferOffset,
      underlyingInputBuffer)
}

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
287287
@transient
288288
val udf: UDFRegistration = new UDFRegistration(this)
289289

290+
@transient
291+
val udaf: UDAFRegistration = new UDAFRegistration(this)
292+
290293
/**
291294
* Returns true if the table is currently cached in-memory.
292295
* @group cachemgmt
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
import org.apache.spark.Logging
21+
import org.apache.spark.sql.catalyst.expressions.{Expression}
22+
import org.apache.spark.sql.execution.expressions.aggregate2.{ScalaUDAF, UserDefinedAggregateFunction}
23+
24+
25+
/**
 * Functions for registering user-defined aggregate functions (UDAFs) with a
 * [[SQLContext]] so they can be invoked by name from SQL.
 */
class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging {

  private val functionRegistry = sqlContext.functionRegistry

  /**
   * Registers `func` under `name` and returns it, so the caller can also keep a
   * handle for programmatic use.
   */
  def register(
      name: String,
      func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
    // Each invocation of the registered function wraps the UDAF in a ScalaUDAF
    // expression over the call-site arguments.
    functionRegistry.registerFunction(
      name, (children: Seq[Expression]) => ScalaUDAF(children, func))
    func
  }
}
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.expressions.aggregate2
19+
20+
import org.apache.spark.Logging
21+
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
22+
import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
23+
import org.apache.spark.sql.catalyst.expressions._
24+
import org.apache.spark.sql.catalyst.expressions.aggregate2.{InputAggregationBuffer, AggregateFunction2, MutableAggregationBuffer}
25+
import org.apache.spark.sql.types._
26+
import org.apache.spark.sql.Row
27+
28+
/**
 * The base class for implementing a user-defined aggregate function (UDAF).
 * Implementations must be serializable because instances are shipped to executors.
 */
abstract class UserDefinedAggregateFunction extends Serializable {

  /** A [[StructType]] describing the input arguments of this aggregate function. */
  def inputDataType: StructType

  /** A [[StructType]] describing the intermediate aggregation buffer of this function. */
  def bufferSchema: StructType

  /** The [[DataType]] of the value returned by [[evaluate]]. */
  def returnDataType: DataType

  /** Whether this function always returns the same output for the same input. */
  def deterministic: Boolean

  /** Sets the initial values of the given aggregation buffer. */
  def initialize(buffer: MutableAggregationBuffer): Unit

  /** Updates the buffer with a new input row. */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  /** Merges a second aggregation buffer into the first. */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  /** Computes the final result of this aggregate function from a buffer. */
  def evaluate(buffer: Row): Any
}
47+
48+
/**
 * The internal wrapper that adapts a [[UserDefinedAggregateFunction]] to the
 * [[AggregateFunction2]] interface used by the execution engine. It converts values
 * between Catalyst's internal representation and the Scala types seen by user code.
 *
 * @param children the input expressions of this aggregate function call.
 * @param udaf     the user-defined aggregate function being wrapped.
 */
case class ScalaUDAF(
    children: Seq[Expression],
    udaf: UserDefinedAggregateFunction)
  extends AggregateFunction2 with ImplicitCastInputTypes with Logging {

  require(
    children.length == udaf.inputDataType.length,
    s"$udaf only accepts ${udaf.inputDataType.length} arguments, " +
      s"but ${children.length} are provided.")

  override def nullable: Boolean = true

  override def dataType: DataType = udaf.returnDataType

  override def deterministic: Boolean = udaf.deterministic

  override val inputTypes: Seq[DataType] = udaf.inputDataType.map(_.dataType)

  override val bufferSchema: StructType = udaf.bufferSchema

  override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes

  override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())

  /** Schema of the input expressions, used to build the Scala converter for inputs. */
  val childrenSchema: StructType = {
    val inputFields = children.zipWithIndex.map {
      case (child, index) =>
        StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
    }
    StructType(inputFields)
  }

  // Projects an input InternalRow down to just this function's arguments. Falls back to
  // the interpreted projection if code generation fails.
  lazy val inputProjection = {
    val inputAttributes = childrenSchema.toAttributes
    log.debug(
      s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
    try {
      GenerateMutableProjection.generate(children, inputAttributes)()
    } catch {
      case e: Exception =>
        log.error("Failed to generate mutable projection, fallback to interpreted", e)
        new InterpretedMutableProjection(children, inputAttributes)
    }
  }

  // Converts a projected Catalyst input row to the Scala Row handed to udaf.update.
  val inputToScalaConverters: Any => Any =
    CatalystTypeConverters.createToScalaConverter(childrenSchema)

  // Per-field Scala -> Catalyst converters for buffer writes.
  val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
    CatalystTypeConverters.createToCatalystConverter(field.dataType)
  }

  // Per-field Catalyst -> Scala converters for buffer reads.
  val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
    CatalystTypeConverters.createToScalaConverter(field.dataType)
  }

  // Reusable read-only buffer wrapper; its underlying row is swapped per call.
  lazy val inputAggregateBuffer: InputAggregationBuffer =
    new InputAggregationBuffer(
      bufferValuesToCatalystConverters,
      bufferValuesToScalaConverters,
      bufferOffset,
      null)

  // Reusable mutable buffer wrapper; its underlying row is swapped per call.
  lazy val mutableAggregateBuffer: MutableAggregationBuffer =
    new MutableAggregationBuffer(
      bufferValuesToCatalystConverters,
      bufferValuesToScalaConverters,
      bufferOffset,
      null)

  override def initialize(buffer: MutableRow): Unit = {
    mutableAggregateBuffer.underlyingBuffer = buffer

    udaf.initialize(mutableAggregateBuffer)
  }

  override def update(buffer: MutableRow, input: InternalRow): Unit = {
    mutableAggregateBuffer.underlyingBuffer = buffer

    udaf.update(
      mutableAggregateBuffer,
      inputToScalaConverters(inputProjection(input)).asInstanceOf[Row])
  }

  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
    mutableAggregateBuffer.underlyingBuffer = buffer1
    inputAggregateBuffer.underlyingInputBuffer = buffer2

    // BUG FIX: this previously called udaf.update, so the user's merge logic was never
    // invoked when combining partial aggregation buffers. Delegate to udaf.merge.
    udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
  }

  override def eval(buffer: InternalRow = null): Any = {
    inputAggregateBuffer.underlyingInputBuffer = buffer

    udaf.evaluate(inputAggregateBuffer)
  }

  override def toString: String = {
    s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
  }
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package test.org.apache.spark.sql.hive.aggregate2;
19+
20+
import java.util.ArrayList;
21+
import java.util.List;
22+
23+
import org.apache.spark.sql.catalyst.expressions.aggregate2.MutableAggregationBuffer;
24+
import org.apache.spark.sql.execution.expressions.aggregate2.UserDefinedAggregateFunction;
25+
import org.apache.spark.sql.types.StructField;
26+
import org.apache.spark.sql.types.StructType;
27+
import org.apache.spark.sql.types.DataType;
28+
import org.apache.spark.sql.types.DataTypes;
29+
import org.apache.spark.sql.Row;
30+
31+
public class MyJavaUDAF extends UserDefinedAggregateFunction {
32+
33+
private StructType _inputDataType;
34+
35+
private StructType _bufferSchema;
36+
37+
private DataType _returnDataType;
38+
39+
public MyJavaUDAF() {
40+
List<StructField> inputfields = new ArrayList<StructField>();
41+
inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
42+
_inputDataType = DataTypes.createStructType(inputfields);
43+
44+
List<StructField> bufferFields = new ArrayList<StructField>();
45+
bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
46+
_bufferSchema = DataTypes.createStructType(bufferFields);
47+
48+
_returnDataType = DataTypes.DoubleType;
49+
}
50+
51+
@Override public StructType inputDataType() {
52+
return _inputDataType;
53+
}
54+
55+
@Override public StructType bufferSchema() {
56+
return _bufferSchema;
57+
}
58+
59+
@Override public DataType returnDataType() {
60+
return _returnDataType;
61+
}
62+
63+
@Override public boolean deterministic() {
64+
return true;
65+
}
66+
67+
@Override public void initialize(MutableAggregationBuffer buffer) {
68+
buffer.update(0, null);
69+
}
70+
71+
@Override public void update(MutableAggregationBuffer buffer, Row input) {
72+
if (!input.isNullAt(0)) {
73+
if (buffer.isNullAt(0)) {
74+
buffer.update(0, input.getDouble(0));
75+
} else {
76+
Double newValue = input.getDouble(0) * buffer.getDouble(0);
77+
buffer.update(0, newValue);
78+
}
79+
}
80+
}
81+
82+
@Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
83+
if (!buffer2.isNullAt(0)) {
84+
if (buffer1.isNullAt(0)) {
85+
buffer1.update(0, buffer2.getDouble(0));
86+
} else {
87+
Double newValue = buffer2.getDouble(0) * buffer1.getDouble(0);
88+
buffer1.update(0, newValue);
89+
}
90+
}
91+
}
92+
93+
@Override public Object evaluate(Row buffer) {
94+
if (buffer.isNullAt(0)) {
95+
return null;
96+
} else {
97+
return buffer.getDouble(0);
98+
}
99+
}
100+
}

0 commit comments

Comments
 (0)