diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c8e9d8e2f9dd..991d0270d004 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -226,7 +226,7 @@ object EliminateSerialization extends Rule[LogicalPlan] {
       val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId)
       Project(objAttr :: Nil, s.child)
 
-    case a @ AppendColumns(_, _, _, s: SerializeFromObject)
+    case a @ AppendColumns(_, _, _, _, _, s: SerializeFromObject)
         if a.deserializer.dataType == s.inputObjAttr.dataType =>
       AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
 
@@ -235,7 +235,7 @@ object EliminateSerialization extends Rule[LogicalPlan] {
     // deserialization in condition, and push it down through `SerializeFromObject`.
     // e.g. `ds.map(...).filter(...)` can be optimized by this rule to save extra deserialization,
     // but `ds.map(...).as[AnotherType].filter(...)` can not be optimized.
-    case f @ TypedFilter(_, _, s: SerializeFromObject)
+    case f @ TypedFilter(_, _, _, _, s: SerializeFromObject)
         if f.deserializer.dataType == s.inputObjAttr.dataType =>
       s.copy(child = f.withObjectProducerChild(s.child))
 
@@ -1719,9 +1719,14 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[Logic
  */
 object CombineTypedFilters extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case t1 @ TypedFilter(_, _, t2 @ TypedFilter(_, _, child))
+    case t1 @ TypedFilter(_, _, _, _, t2 @ TypedFilter(_, _, _, _, child))
         if t1.deserializer.dataType == t2.deserializer.dataType =>
-      TypedFilter(combineFilterFunction(t2.func, t1.func), t1.deserializer, child)
+      TypedFilter(
+        combineFilterFunction(t2.func, t1.func),
+        t1.argumentClass,
+        t1.argumentSchema,
+        t1.deserializer,
+        child)
   }
 
   private def combineFilterFunction(func1: AnyRef, func2: AnyRef): Any => Boolean = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index e1890edcbb11..fefe5a3953a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -155,6 +155,8 @@ object MapElements {
     val deserialized = CatalystSerde.deserialize[T](child)
     val mapped = MapElements(
       func,
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
       CatalystSerde.generateObjAttr[U],
       deserialized)
     CatalystSerde.serialize[U](mapped)
@@ -166,12 +168,19 @@ object MapElements {
  */
 case class MapElements(
     func: AnyRef,
+    argumentClass: Class[_],
+    argumentSchema: StructType,
     outputObjAttr: Attribute,
     child: LogicalPlan) extends ObjectConsumer with ObjectProducer
 
 object TypedFilter {
   def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
-    TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child)
+    TypedFilter(
+      func,
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
+      UnresolvedDeserializer(encoderFor[T].deserializer),
+      child)
   }
 }
 
@@ -186,6 +195,8 @@ object TypedFilter {
  */
 case class TypedFilter(
     func: AnyRef,
+    argumentClass: Class[_],
+    argumentSchema: StructType,
     deserializer: Expression,
     child: LogicalPlan) extends UnaryNode {
 
@@ -213,6 +224,8 @@ object AppendColumns {
       child: LogicalPlan): AppendColumns = {
     new AppendColumns(
       func.asInstanceOf[Any => Any],
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
       UnresolvedDeserializer(encoderFor[T].deserializer),
       encoderFor[U].namedExpressions,
       child)
@@ -228,6 +241,8 @@ object AppendColumns {
  */
 case class AppendColumns(
     func: Any => Any,
+    argumentClass: Class[_],
+    argumentSchema: StructType,
     deserializer: Expression,
     serializer: Seq[NamedExpression],
     child: LogicalPlan) extends UnaryNode {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 52e19819f2f6..cb172de9d7ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -356,9 +356,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
         execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping, data, objAttr,
           planLater(child)) :: Nil
-      case logical.MapElements(f, objAttr, child) =>
+      case logical.MapElements(f, _, _, objAttr, child) =>
         execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
-      case logical.AppendColumns(f, in, out, child) =>
+      case logical.AppendColumns(f, _, _, in, out, child) =>
         execution.AppendColumnsExec(f, in, out, planLater(child)) :: Nil
       case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
         execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
index c39a78da6f9b..1dae5f6964e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.expressions.Aggregator
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 
-class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
+class TypedSumDouble[IN](val f: IN => Double) extends Aggregator[IN, Double, Double] {
   override def zero: Double = 0.0
   override def reduce(b: Double, a: IN): Double = b + f(a)
   override def merge(b1: Double, b2: Double): Double = b1 + b2
@@ -45,7 +45,7 @@ class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double]
 }
 
 
-class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
+class TypedSumLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] {
   override def zero: Long = 0L
   override def reduce(b: Long, a: IN): Long = b + f(a)
   override def merge(b1: Long, b2: Long): Long = b1 + b2
@@ -63,7 +63,7 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
 }
 
 
-class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
+class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] {
   override def zero: Long = 0
   override def reduce(b: Long, a: IN): Long = {
     if (f(a) == null) b else b + 1
@@ -82,7 +82,7 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
 }
 
 
-class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
+class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
   override def zero: (Double, Long) = (0.0, 0L)
   override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2)
   override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2
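Note (illustrative, not part of the patch above): the new argumentClass and argumentSchema fields record the runtime class and Catalyst schema of the user function's input type directly on the typed logical nodes, so downstream rules can inspect a typed operation without re-deriving this information from an encoder. Below is a minimal sketch of how such a rule could read the fields, assuming the standard Rule[LogicalPlan] API; InspectTypedFilterInput is a hypothetical name and does not exist in Spark.

import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TypedFilter}
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical rule: logs the argument class and schema carried by TypedFilter
// and returns the plan unchanged.
object InspectTypedFilterInput extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case f: TypedFilter =>
      // argumentClass: runtime class of the filter function's input type T.
      // argumentSchema: Catalyst schema of T, taken from its encoder.
      logInfo(s"TypedFilter over ${f.argumentClass.getName}: ${f.argumentSchema.simpleString}")
      f
  }
}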