
Commit 0e554cd

add optimize rules
1 parent 98744f0 commit 0e554cd

9 files changed: +245, -11 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 0 deletions
@@ -1599,6 +1599,8 @@ object CleanupAliases extends Rule[LogicalPlan] {
       // Operators that operate on objects should only have expressions from encoders, which should
       // never have extra aliases.
       case o: ObjectOperator => o
+      case d: DeserializeToObject => d
+      case s: SerializeFromObject => s
 
       case other =>
         var stop = false

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 19 additions & 0 deletions
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
 
 import scala.language.implicitConversions
 
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -166,6 +167,14 @@ package object dsl {
       case target => UnresolvedStar(Option(target))
     }
 
+    def callFunction[T, U](
+        func: T => U,
+        returnType: DataType,
+        argument: Expression): Expression = {
+      val function = Literal.create(func, ObjectType(classOf[T => U]))
+      Invoke(function, "apply", returnType, argument :: Nil)
+    }
+
     implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
     // TODO more implicit class for literal?
     implicit class DslString(val s: String) extends ImplicitOperators {
@@ -270,6 +279,16 @@ package object dsl {
 
     def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
 
+    def filter[T : Encoder](func: T => Boolean): LogicalPlan = {
+      val deserialized = logicalPlan.deserialize[T]
+      val condition = expressions.callFunction(func, BooleanType, deserialized.output.head)
+      Filter(condition, deserialized).serialize[T]
+    }
+
+    def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan)
+
+    def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan)
+
     def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan)
 
     def join(
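For readers skimming the diff, the new DSL pieces are easiest to see in combination. The fragment below is a minimal sketch of how they might be used inside a catalyst PlanTest; the imports and the implicit encoder are assumptions here, mirroring the new TypedFilterOptimizationSuite further down, and `relation` / `isPositive` are illustrative names, not from the commit.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.BooleanType

implicit val tupleEncoder: ExpressionEncoder[(Int, Int)] = ExpressionEncoder[(Int, Int)]()

val relation = LocalRelation('_1.int, '_2.int)
val isPositive = (i: (Int, Int)) => i._1 > 0

// `filter` wraps the typed predicate in DeserializeToObject / SerializeFromObject,
// matching what Dataset.filter now produces.
val typedFilterPlan = relation.filter(isPositive)

// `callFunction` alone builds just the Invoke-based condition over an object attribute.
val condition = callFunction(isPositive, BooleanType, relation.deserialize[(Int, Int)].output.head)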

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 37 additions & 2 deletions
@@ -93,6 +93,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
       EliminateSerialization) ::
     Batch("Decimal Optimizations", FixedPoint(100),
       DecimalAggregates) ::
+    Batch("Typed Filter Optimization", FixedPoint(100),
+      EmbedSerializerInFilter) ::
     Batch("LocalRelation", FixedPoint(100),
       ConvertToLocalRelation) ::
     Batch("Subquery", Once,
@@ -147,12 +149,18 @@ object EliminateSerialization extends Rule[LogicalPlan] {
         child = childWithoutSerialization)
 
     case m @ MapElements(_, deserializer, _, child: ObjectOperator)
-      if !deserializer.isInstanceOf[Attribute] &&
-        deserializer.dataType == child.outputObject.dataType =>
+        if !deserializer.isInstanceOf[Attribute] &&
+          deserializer.dataType == child.outputObject.dataType =>
       val childWithoutSerialization = child.withObjectOutput
       m.copy(
         deserializer = childWithoutSerialization.output.head,
        child = childWithoutSerialization)
+
+    case d @ DeserializeToObject(_, s: SerializeFromObject)
+        if d.outputObjectType == s.inputObjectType =>
+      // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
+      val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
+      Project(objAttr :: Nil, s.child)
   }
 }
 
@@ -1329,3 +1337,30 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * A typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
+ * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
+ * the deserializer in the filter condition to save the extra serialization at the end.
+ */
+object EmbedSerializerInFilter extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) =>
+      val numObjects = condition.collect {
+        case a: Attribute if a == d.output.head => a
+      }.length
+
+      if (numObjects > 1) {
+        // If the filter condition references the object more than once, we should not embed the
+        // deserializer in it, as the deserialization will happen many times and slow down the
+        // execution.
+        // TODO: we can still embed it if we can make sure subexpression elimination works here.
+        s
+      } else {
+        val newCondition = condition transform {
+          case a: Attribute if a == d.output.head => d.deserializer.child
+        }
+        Filter(newCondition, d.child)
+      }
+  }
+}
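To see what the new batch buys, here is a hedged sketch of applying EmbedSerializerInFilter by hand. It mirrors the single-filter test in TypedFilterOptimizationSuite below and assumes the catalyst DSL imports plus an implicit ExpressionEncoder[(Int, Int)]; none of the value names come from the commit itself.

// Analyzed typed filter: SerializeFromObject <- Filter <- DeserializeToObject <- LocalRelation.
val typedQuery = LocalRelation('_1.int, '_2.int).filter((i: (Int, Int)) => i._1 > 0).analyze

// The rule fires because the condition references the object attribute only once.
val embedded = EmbedSerializerInFilter(typedQuery)
// `embedded` is now a plain Filter over the LocalRelation: the serde pair is gone and the
// deserializer expression has been substituted into the Invoke condition.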

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala

Lines changed: 28 additions & 1 deletion
@@ -21,7 +21,34 @@ import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{ObjectType, StructType}
+import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
+
+object CatalystSerde {
+  def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
+    val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
+    DeserializeToObject(Alias(deserializer, "obj")(), child)
+  }
+
+  def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
+    SerializeFromObject(encoderFor[T].namedExpressions, child)
+  }
+}
+
+case class DeserializeToObject(
+    deserializer: Alias,
+    child: LogicalPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = deserializer.toAttribute :: Nil
+
+  def outputObjectType: DataType = deserializer.dataType
+}
+
+case class SerializeFromObject(
+    serializer: Seq[NamedExpression],
+    child: LogicalPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  def inputObjectType: DataType = child.output.head.dataType
+}
 
 /**
  * A trait for logical operators that apply user defined functions to domain objects.
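CatalystSerde is just a small factory for the two new logical nodes. A minimal sketch of using it, assuming the catalyst DSL imports and an implicit ExpressionEncoder for the element type; all names below are illustrative, not from the commit.

implicit val enc: ExpressionEncoder[(Int, Int)] = ExpressionEncoder[(Int, Int)]()
val child = LocalRelation('_1.int, '_2.int)

// Wrap the row-based plan so downstream operators can work on (Int, Int) objects;
// the single output attribute is named "obj".
val asObjects = CatalystSerde.deserialize[(Int, Int)](child)

// Flatten the object back into columns using the encoder's serializer expressions.
val asColumns = CatalystSerde.serialize[(Int, Int)](asObjects)

// After analysis, outputObjectType / inputObjectType let the optimizer check that adjacent
// serde nodes agree on the object type before collapsing or embedding them.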
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.BooleanType
+
+class TypedFilterOptimizationSuite extends PlanTest {
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("EliminateSerialization", FixedPoint(50),
+        EliminateSerialization) ::
+      Batch("EmbedSerializerInFilter", FixedPoint(50),
+        EmbedSerializerInFilter) :: Nil
+  }
+
+  implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
+
+  test("back to back filter") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f1 = (i: (Int, Int)) => i._1 > 0
+    val f2 = (i: (Int, Int)) => i._2 > 0
+
+    val query = input.filter(f1).filter(f2).analyze
+
+    val optimized = Optimize.execute(query)
+
+    val expected = input.deserialize[(Int, Int)]
+      .where(callFunction(f1, BooleanType, 'obj))
+      .select('obj.as("obj"))
+      .where(callFunction(f2, BooleanType, 'obj))
+      .serialize[(Int, Int)].analyze
+
+    comparePlans(optimized, expected)
+  }
+
+  test("embed deserializer in filter condition if there is only one filter") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f = (i: (Int, Int)) => i._1 > 0
+
+    val query = input.filter(f).analyze
+
+    val optimized = Optimize.execute(query)
+
+    val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer)
+    val condition = callFunction(f, BooleanType, deserializer)
+    val expected = input.where(condition).analyze
+
+    comparePlans(optimized, expected)
+  }
+}

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 8 additions & 6 deletions
@@ -1880,10 +1880,11 @@ class Dataset[T] private[sql](
    */
   @Experimental
   def filter(func: T => Boolean): Dataset[T] = {
-    val deserializer = unresolvedTEncoder.deserializer
+    val deserialized = CatalystSerde.deserialize[T](logicalPlan)
     val function = Literal.create(func, ObjectType(classOf[T => Boolean]))
-    val condition = Invoke(function, "apply", BooleanType, deserializer :: Nil)
-    withTypedPlan(Filter(condition, logicalPlan))
+    val condition = Invoke(function, "apply", BooleanType, deserialized.output)
+    val filter = Filter(condition, deserialized)
+    withTypedPlan(CatalystSerde.serialize[T](filter))
   }
 
   /**
@@ -1896,10 +1897,11 @@ class Dataset[T] private[sql](
    */
   @Experimental
   def filter(func: FilterFunction[T]): Dataset[T] = {
-    val deserializer = unresolvedTEncoder.deserializer
+    val deserialized = CatalystSerde.deserialize[T](logicalPlan)
     val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]]))
-    val condition = Invoke(function, "call", BooleanType, deserializer :: Nil)
-    withTypedPlan(Filter(condition, logicalPlan))
+    val condition = Invoke(function, "call", BooleanType, deserialized.output)
+    val filter = Filter(condition, deserialized)
+    withTypedPlan(CatalystSerde.serialize[T](filter))
  }
 
   /**
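Nothing changes in the user-facing API, but a typed filter's logical plan now carries the serde pair explicitly, which is exactly what the new optimizer rules pattern-match on. A rough way to observe this from a REPL, illustrative only and not part of the commit:

// Assumes a SQLContext named sqlContext, as in the test suites.
val ds = sqlContext.range(10).filter(_ % 2 == 0)

// The analyzed plan should read SerializeFromObject <- Filter <- DeserializeToObject <- Range;
// after optimization a single typed filter has the deserializer embedded in its condition instead.
println(ds.queryExecution.analyzed)
println(ds.queryExecution.optimizedPlan)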

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 4 additions & 0 deletions
@@ -339,6 +339,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         throw new IllegalStateException(
           "logical intersect operator should have been replaced by semi-join in the optimizer")
 
+      case logical.DeserializeToObject(deserializer, child) =>
+        execution.DeserializeToObject(deserializer, planLater(child)) :: Nil
+      case logical.SerializeFromObject(serializer, child) =>
+        execution.SerializeFromObject(serializer, planLater(child)) :: Nil
       case logical.MapPartitions(f, in, out, child) =>
        execution.MapPartitions(f, in, out, planLater(child)) :: Nil
       case logical.MapElements(f, in, out, child) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala

Lines changed: 63 additions & 0 deletions
@@ -27,6 +27,69 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.ObjectType
 
+case class DeserializeToObject(
+    deserializer: Alias,
+    child: SparkPlan) extends UnaryNode with CodegenSupport {
+  override def output: Seq[Attribute] = deserializer.toAttribute :: Nil
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val bound = ExpressionCanonicalizer.execute(
+      BindReferences.bindReference(deserializer, child.output))
+    ctx.currentVars = input
+    val resultVars = bound.gen(ctx) :: Nil
+    s"""
+      ${consume(ctx, resultVars)}
+    """
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsInternal { iter =>
+      val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output)
+      iter.map(projection)
+    }
+  }
+}
+
+case class SerializeFromObject(
+    serializer: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryNode with CodegenSupport {
+  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val bound = serializer.map { expr =>
+      ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output))
+    }
+    ctx.currentVars = input
+    val resultVars = bound.map(_.gen(ctx))
+    s"""
+      ${consume(ctx, resultVars)}
+    """
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsInternal { iter =>
+      val projection = UnsafeProjection.create(serializer)
+      iter.map(projection)
+    }
+  }
+}
+
 /**
  * Helper functions for physical operators that work with user defined objects.
  */
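Both physical operators participate in whole-stage codegen through doProduce/doConsume and keep a row-at-a-time fallback in doExecute. A back-to-back typed filter is the easiest way to see them survive into the executed plan, since in that case the serde pair is not embedded away. The sketch below assumes a SharedSQLContext-style sqlContext and is essentially a condensed form of the codegen test in the last file of this commit:

val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
val plan = ds.queryExecution.executedPlan

// SerializeFromObject (with a DeserializeToObject further down) remains in the physical plan
// and is picked up by WholeStageCodegen.
assert(plan.find(_.isInstanceOf[SerializeFromObject]).isDefined)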

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 10 additions & 2 deletions
@@ -17,8 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.api.java.function.MapFunction
-import org.apache.spark.sql.{Encoders, Row}
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.execution.joins.BroadcastHashJoin
 import org.apache.spark.sql.functions.{avg, broadcast, col, max}
@@ -91,4 +90,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
       p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Filter]).isDefined)
     assert(ds.collect() === Array(0, 2, 4, 6, 8))
   }
+
+  test("back-to-back typed filter should be included in WholeStageCodegen") {
+    val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
+    val plan = ds.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined)
+    assert(ds.collect() === Array(0, 6))
+  }
 }
