diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala index 1dae5f6964e56..307321cb5b9df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala @@ -17,72 +17,70 @@ package org.apache.spark.sql.execution.aggregate +import scala.reflect.runtime.universe.TypeTag + import org.apache.spark.api.java.function.MapFunction import org.apache.spark.sql.{Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator + //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines internal implementations for aggregators. //////////////////////////////////////////////////////////////////////////////////////////////////// +abstract class TypedAggregator[IN, BUF: TypeTag, OUT: TypeTag, JAVA] + extends Aggregator[IN, BUF, OUT] { + + def bufferEncoder: Encoder[BUF] = ExpressionEncoder[BUF]() + def outputEncoder: Encoder[OUT] = ExpressionEncoder[OUT]() + + def toColumnJava: TypedColumn[IN, JAVA] = { + toColumn.asInstanceOf[TypedColumn[IN, JAVA]] + } +} + +class TypedSumDouble[IN](val f: IN => Double) + extends TypedAggregator[IN, Double, Double, java.lang.Double] { -class TypedSumDouble[IN](val f: IN => Double) extends Aggregator[IN, Double, Double] { override def zero: Double = 0.0 override def reduce(b: Double, a: IN): Double = b + f(a) override def merge(b1: Double, b2: Double): Double = b1 + b2 override def finish(reduction: Double): Double = reduction - override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]() - override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]() - - // Java api support + // Java constructor def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double]) - - def toColumnJava: TypedColumn[IN, java.lang.Double] = { - toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]] - } } -class TypedSumLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] { +class TypedSumLong[IN](val f: IN => Long) + extends TypedAggregator[IN, Long, Long, java.lang.Long] { + override def zero: Long = 0L override def reduce(b: Long, a: IN): Long = b + f(a) override def merge(b1: Long, b2: Long): Long = b1 + b2 override def finish(reduction: Long): Long = reduction - override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]() - override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]() - - // Java api support + // Java constructor def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long]) - - def toColumnJava: TypedColumn[IN, java.lang.Long] = { - toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]] - } } +class TypedCount[IN](val f: IN => Any) + extends TypedAggregator[IN, Long, Long, java.lang.Long] { -class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] { override def zero: Long = 0 - override def reduce(b: Long, a: IN): Long = { - if (f(a) == null) b else b + 1 - } + override def reduce(b: Long, a: IN): Long = if (f(a) == null) b else b + 1 override def merge(b1: Long, b2: Long): Long = b1 + b2 override def finish(reduction: Long): Long = reduction - override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]() - override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]() - - // Java api support + // Java constructor def this(f: MapFunction[IN, Object]) = this(x => f.call(x)) - def toColumnJava: TypedColumn[IN, java.lang.Long] = { - toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]] - } + } +class TypedAverage[IN](val f: IN => Double) + extends TypedAggregator[IN, (Double, Long), Double, java.lang.Double] { -class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long), Double] { override def zero: (Double, Long) = (0.0, 0L) override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2) override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2 @@ -90,12 +88,6 @@ class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long (b1._1 + b2._1, b1._2 + b2._2) } - override def bufferEncoder: Encoder[(Double, Long)] = ExpressionEncoder[(Double, Long)]() - override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]() - // Java api support def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double]) - def toColumnJava: TypedColumn[IN, java.lang.Double] = { - toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]] - } }