
Commit bca43cd

clockfly authored and cloud-fan committed
[SPARK-16898][SQL] Adds argument type information for typed logical plans like MapElements, TypedFilter, and AppendColumns
## What changes were proposed in this pull request?

This PR adds argument type information for typed logical plans like MapElements, TypedFilter, and AppendColumns, so that this information can be used in customized optimizer rules.

## How was this patch tested?

Existing tests.

Author: Sean Zhong <[email protected]>

Closes #14494 from clockfly/add_more_info_for_typed_operator.
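The new `argumentClass` and `argumentSchema` fields are intended for consumption by user-defined optimizer rules. A minimal sketch of such a rule, assuming Spark 2.x internals; the rule name and the logging are illustrative and not part of this patch:

```scala
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TypedFilter}
import org.apache.spark.sql.catalyst.rules.Rule

// Illustrative rule: report the argument type of every typed filter it visits.
// TypedFilter now carries argumentClass and argumentSchema, so a rule can
// reason about the deserialized argument without resolving the encoder itself.
object InspectTypedFilters extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case f: TypedFilter =>
      logInfo(s"TypedFilter over ${f.argumentClass.getName}, " +
        s"schema: ${f.argumentSchema.simpleString}")
      f
  }
}
```

A rule like this could be plugged into the optimizer through `spark.experimental.extraOptimizations`.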
1 parent df10658 commit bca43cd

File tree

4 files changed: +31 -11 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 9 additions & 4 deletions
@@ -214,7 +214,7 @@ object EliminateSerialization extends Rule[LogicalPlan] {
       val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId)
       Project(objAttr :: Nil, s.child)

-    case a @ AppendColumns(_, _, _, s: SerializeFromObject)
+    case a @ AppendColumns(_, _, _, _, _, s: SerializeFromObject)
         if a.deserializer.dataType == s.inputObjAttr.dataType =>
       AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)

@@ -223,7 +223,7 @@ object EliminateSerialization extends Rule[LogicalPlan] {
     // deserialization in condition, and push it down through `SerializeFromObject`.
     // e.g. `ds.map(...).filter(...)` can be optimized by this rule to save extra deserialization,
     // but `ds.map(...).as[AnotherType].filter(...)` can not be optimized.
-    case f @ TypedFilter(_, _, s: SerializeFromObject)
+    case f @ TypedFilter(_, _, _, _, s: SerializeFromObject)
         if f.deserializer.dataType == s.inputObjAttr.dataType =>
       s.copy(child = f.withObjectProducerChild(s.child))

@@ -1703,9 +1703,14 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[Logic
  */
 object CombineTypedFilters extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case t1 @ TypedFilter(_, _, t2 @ TypedFilter(_, _, child))
+    case t1 @ TypedFilter(_, _, _, _, t2 @ TypedFilter(_, _, _, _, child))
         if t1.deserializer.dataType == t2.deserializer.dataType =>
-      TypedFilter(combineFilterFunction(t2.func, t1.func), t1.deserializer, child)
+      TypedFilter(
+        combineFilterFunction(t2.func, t1.func),
+        t1.argumentClass,
+        t1.argumentSchema,
+        t1.deserializer,
+        child)
   }

   private def combineFilterFunction(func1: AnyRef, func2: AnyRef): Any => Boolean = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala

Lines changed: 16 additions & 1 deletion
@@ -155,6 +155,8 @@ object MapElements {
     val deserialized = CatalystSerde.deserialize[T](child)
     val mapped = MapElements(
       func,
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
       CatalystSerde.generateObjAttr[U],
       deserialized)
     CatalystSerde.serialize[U](mapped)

@@ -166,12 +168,19 @@ object MapElements {
  */
 case class MapElements(
     func: AnyRef,
+    argumentClass: Class[_],
+    argumentSchema: StructType,
     outputObjAttr: Attribute,
     child: LogicalPlan) extends ObjectConsumer with ObjectProducer

 object TypedFilter {
   def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
-    TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child)
+    TypedFilter(
+      func,
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
+      UnresolvedDeserializer(encoderFor[T].deserializer),
+      child)
   }
 }

@@ -186,6 +195,8 @@ object TypedFilter {
  */
 case class TypedFilter(
     func: AnyRef,
+    argumentClass: Class[_],
+    argumentSchema: StructType,
     deserializer: Expression,
     child: LogicalPlan) extends UnaryNode {

@@ -213,6 +224,8 @@ object AppendColumns {
       child: LogicalPlan): AppendColumns = {
     new AppendColumns(
       func.asInstanceOf[Any => Any],
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
       UnresolvedDeserializer(encoderFor[T].deserializer),
       encoderFor[U].namedExpressions,
       child)

@@ -228,6 +241,8 @@ object AppendColumns {
  */
 case class AppendColumns(
     func: Any => Any,
+    argumentClass: Class[_],
+    argumentSchema: StructType,
     deserializer: Expression,
     serializer: Seq[NamedExpression],
     child: LogicalPlan) extends UnaryNode {
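For reference, the typed `apply` methods above populate the two new fields from the implicit encoder. A small sketch, assuming the Spark SQL jars are on the classpath, of the values that would be captured for an example case class (`Person` is purely illustrative):

```scala
import org.apache.spark.sql.Encoders

// Hypothetical argument type used only for illustration.
case class Person(name: String, age: Int)

// The same expressions the apply methods evaluate via implicitly[Encoder[T]]:
val enc = Encoders.product[Person]
val argumentClass = enc.clsTag.runtimeClass   // class Person
val argumentSchema = enc.schema               // Catalyst schema: name string, age int
```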

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 2 deletions
@@ -356,9 +356,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
         execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping,
           data, objAttr, planLater(child)) :: Nil
-      case logical.MapElements(f, objAttr, child) =>
+      case logical.MapElements(f, _, _, objAttr, child) =>
         execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
-      case logical.AppendColumns(f, in, out, child) =>
+      case logical.AppendColumns(f, _, _, in, out, child) =>
         execution.AppendColumnsExec(f, in, out, planLater(child)) :: Nil
       case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
         execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala

Lines changed: 4 additions & 4 deletions
@@ -27,7 +27,7 @@ import org.apache.spark.sql.expressions.Aggregator
 ////////////////////////////////////////////////////////////////////////////////////////////////////


-class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
+class TypedSumDouble[IN](val f: IN => Double) extends Aggregator[IN, Double, Double] {
   override def zero: Double = 0.0
   override def reduce(b: Double, a: IN): Double = b + f(a)
   override def merge(b1: Double, b2: Double): Double = b1 + b2

@@ -45,7 +45,7 @@ class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double]
 }


-class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
+class TypedSumLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] {
   override def zero: Long = 0L
   override def reduce(b: Long, a: IN): Long = b + f(a)
   override def merge(b1: Long, b2: Long): Long = b1 + b2

@@ -63,7 +63,7 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
 }


-class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
+class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] {
   override def zero: Long = 0
   override def reduce(b: Long, a: IN): Long = {
     if (f(a) == null) b else b + 1

@@ -82,7 +82,7 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
 }


-class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
+class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
   override def zero: (Double, Long) = (0.0, 0L)
   override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2)
   override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2
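Promoting `f` from a plain constructor parameter to a `val` exposes the wrapped function as a field, presumably so that custom rules or tests can inspect typed aggregates as well. An illustrative use, assuming the class is visible from the caller:

```scala
import org.apache.spark.sql.execution.aggregate.TypedSumDouble

// With `val f`, the function passed to the aggregator can be read back.
val sumScore = new TypedSumDouble[(String, Double)](_._2)
val extracted: ((String, Double)) => Double = sumScore.f
assert(extracted(("a", 1.5)) == 1.5)
```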
