1818use crate :: signature:: TypeSignature ;
1919use arrow:: datatypes:: {
2020 DataType , FieldRef , TimeUnit , DECIMAL128_MAX_PRECISION , DECIMAL128_MAX_SCALE ,
21- DECIMAL256_MAX_PRECISION , DECIMAL256_MAX_SCALE ,
21+ DECIMAL256_MAX_PRECISION , DECIMAL256_MAX_SCALE , DECIMAL32_MAX_PRECISION ,
22+ DECIMAL64_MAX_PRECISION ,
2223} ;
2324
2425use datafusion_common:: { internal_err, plan_err, Result } ;
@@ -150,6 +151,18 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
150151 DataType :: Int64 => Ok ( DataType :: Int64 ) ,
151152 DataType :: UInt64 => Ok ( DataType :: UInt64 ) ,
152153 DataType :: Float64 => Ok ( DataType :: Float64 ) ,
154+ DataType :: Decimal32 ( precision, scale) => {
155+ // in the spark, the result type is DECIMAL(min(38,precision+10), s)
156+ // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
157+ let new_precision = DECIMAL32_MAX_PRECISION . min ( * precision + 10 ) ;
158+ Ok ( DataType :: Decimal128 ( new_precision, * scale) )
159+ }
160+ DataType :: Decimal64 ( precision, scale) => {
161+ // in the spark, the result type is DECIMAL(min(38,precision+10), s)
162+ // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
163+ let new_precision = DECIMAL64_MAX_PRECISION . min ( * precision + 10 ) ;
164+ Ok ( DataType :: Decimal128 ( new_precision, * scale) )
165+ }
153166 DataType :: Decimal128 ( precision, scale) => {
154167 // In the spark, the result type is DECIMAL(min(38,precision+10), s)
155168 // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
@@ -222,6 +235,16 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType>
222235/// Internal sum type of an average
223236pub fn avg_sum_type ( arg_type : & DataType ) -> Result < DataType > {
224237 match arg_type {
238+ DataType :: Decimal32 ( precision, scale) => {
239+ // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
240+ let new_precision = DECIMAL32_MAX_PRECISION . min ( * precision + 10 ) ;
241+ Ok ( DataType :: Decimal32 ( new_precision, * scale) )
242+ }
243+ DataType :: Decimal64 ( precision, scale) => {
244+ // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
245+ let new_precision = DECIMAL64_MAX_PRECISION . min ( * precision + 10 ) ;
246+ Ok ( DataType :: Decimal64 ( new_precision, * scale) )
247+ }
225248 DataType :: Decimal128 ( precision, scale) => {
226249 // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
227250 let new_precision = DECIMAL128_MAX_PRECISION . min ( * precision + 10 ) ;
@@ -249,7 +272,7 @@ pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
249272 _ => matches ! (
250273 arg_type,
251274 arg_type if NUMERICS . contains( arg_type)
252- || matches!( arg_type, DataType :: Decimal128 ( _, _) | DataType :: Decimal256 ( _, _) )
275+ || matches!( arg_type, DataType :: Decimal32 ( _ , _ ) | DataType :: Decimal64 ( _ , _ ) | DataType :: Decimal128 ( _, _) | DataType :: Decimal256 ( _, _) )
253276 ) ,
254277 }
255278}
@@ -262,7 +285,7 @@ pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
262285 _ => matches ! (
263286 arg_type,
264287 arg_type if NUMERICS . contains( arg_type)
265- || matches!( arg_type, DataType :: Decimal128 ( _, _) | DataType :: Decimal256 ( _, _) )
288+ || matches!( arg_type, DataType :: Decimal32 ( _ , _ ) | DataType :: Decimal64 ( _ , _ ) | DataType :: Decimal128 ( _, _) | DataType :: Decimal256 ( _, _) )
266289 ) ,
267290 }
268291}
@@ -297,6 +320,8 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<Da
297320 // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
298321 fn coerced_type ( func_name : & str , data_type : & DataType ) -> Result < DataType > {
299322 match & data_type {
323+ DataType :: Decimal32 ( p, s) => Ok ( DataType :: Decimal32 ( * p, * s) ) ,
324+ DataType :: Decimal64 ( p, s) => Ok ( DataType :: Decimal64 ( * p, * s) ) ,
300325 DataType :: Decimal128 ( p, s) => Ok ( DataType :: Decimal128 ( * p, * s) ) ,
301326 DataType :: Decimal256 ( p, s) => Ok ( DataType :: Decimal256 ( * p, * s) ) ,
302327 d if d. is_numeric ( ) => Ok ( DataType :: Float64 ) ,
0 commit comments