@@ -66,6 +66,33 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
6666 }
6767}
6868
69+ class ScalaAggregateFunctionWithoutInputSchema extends UserDefinedAggregateFunction {
70+
71+ def inputSchema : StructType = StructType (Nil )
72+
73+ def bufferSchema : StructType = StructType (StructField (" value" , LongType ) :: Nil )
74+
75+ def dataType : DataType = LongType
76+
77+ def deterministic : Boolean = true
78+
79+ def initialize (buffer : MutableAggregationBuffer ): Unit = {
80+ buffer.update(0 , 0L )
81+ }
82+
83+ def update (buffer : MutableAggregationBuffer , input : Row ): Unit = {
84+ buffer.update(0 , input.getAs[Seq [Row ]](0 ).map(_.getAs[Int ](" v" )).sum + buffer.getLong(0 ))
85+ }
86+
87+ def merge (buffer1 : MutableAggregationBuffer , buffer2 : Row ): Unit = {
88+ buffer1.update(0 , buffer1.getLong(0 ) + buffer2.getLong(0 ))
89+ }
90+
91+ def evaluate (buffer : Row ): Any = {
92+ buffer.getLong(0 )
93+ }
94+ }
95+
6996class LongProductSum extends UserDefinedAggregateFunction {
7097 def inputSchema : StructType = new StructType ()
7198 .add(" a" , LongType )
@@ -858,6 +885,43 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
858885 )
859886 }
860887 }
888+
889+ test(" udaf without specifying inputSchema" ) {
890+ withTempTable(" noInputSchemaUDAF" ) {
891+ sqlContext.udf.register(" noInputSchema" , new ScalaAggregateFunctionWithoutInputSchema )
892+
893+ val data =
894+ Row (1 , Seq (Row (1 ), Row (2 ), Row (3 ))) ::
895+ Row (1 , Seq (Row (4 ), Row (5 ), Row (6 ))) ::
896+ Row (2 , Seq (Row (- 10 ))) :: Nil
897+ val schema =
898+ StructType (
899+ StructField (" key" , IntegerType ) ::
900+ StructField (" myArray" ,
901+ ArrayType (StructType (StructField (" v" , IntegerType ) :: Nil ))) :: Nil )
902+ sqlContext.createDataFrame(
903+ sparkContext.parallelize(data, 2 ),
904+ schema)
905+ .registerTempTable(" noInputSchemaUDAF" )
906+
907+ checkAnswer(
908+ sqlContext.sql(
909+ """
910+ |SELECT key, noInputSchema(myArray)
911+ |FROM noInputSchemaUDAF
912+ |GROUP BY key
913+ """ .stripMargin),
914+ Row (1 , 21 ) :: Row (2 , - 10 ) :: Nil )
915+
916+ checkAnswer(
917+ sqlContext.sql(
918+ """
919+ |SELECT noInputSchema(myArray)
920+ |FROM noInputSchemaUDAF
921+ """ .stripMargin),
922+ Row (11 ) :: Nil )
923+ }
924+ }
861925}
862926
863927
0 commit comments