@@ -1599,6 +1599,8 @@ object CleanupAliases extends Rule[LogicalPlan] {
// Operators that operate on objects should only have expressions from encoders, which should
// never have extra aliases.
case o: ObjectOperator => o
case d: DeserializeToObject => d
case s: SerializeFromObject => s

case other =>
var stop = false
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}

import scala.language.implicitConversions

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -166,6 +167,14 @@ package object dsl {
case target => UnresolvedStar(Option(target))
}

def callFunction[T, U](
func: T => U,
returnType: DataType,
argument: Expression): Expression = {
val function = Literal.create(func, ObjectType(classOf[T => U]))
Invoke(function, "apply", returnType, argument :: Nil)
}

implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
// TODO more implicit class for literal?
implicit class DslString(val s: String) extends ImplicitOperators {
@@ -270,6 +279,16 @@ package object dsl {

def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)

def filter[T : Encoder](func: T => Boolean): LogicalPlan = {
val deserialized = logicalPlan.deserialize[T]
val condition = expressions.callFunction(func, BooleanType, deserialized.output.head)
Filter(condition, deserialized).serialize[T]
}

def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan)

def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan)
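For orientation, a short sketch of how these new dsl helpers compose (it mirrors the TypedFilterOptimizationSuite added later in this patch; the relation, predicate, and `enc` are only illustrative):

implicit val enc = ExpressionEncoder[(Int, Int)]()
val input = LocalRelation('_1.int, '_2.int)
val f = (i: (Int, Int)) => i._1 > 0

// `filter` builds the full serde "sandwich" around the predicate:
//   serialize[(Int, Int)](Filter(callFunction(f, BooleanType, obj), deserialize[(Int, Int)](input)))
val typedFilter = input.filter(f)

// `callFunction` by itself wraps the closure in an object-typed Literal and calls its `apply` via Invoke:
val condition = callFunction(f, BooleanType, 'obj)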

def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan)

def join(
@@ -93,6 +93,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
EliminateSerialization) ::
Batch("Decimal Optimizations", FixedPoint(100),
DecimalAggregates) ::
Batch("Typed Filter Optimization", FixedPoint(100),
EmbedSerializerInFilter) ::
Batch("LocalRelation", FixedPoint(100),
ConvertToLocalRelation) ::
Batch("Subquery", Once,
@@ -147,12 +149,18 @@ object EliminateSerialization extends Rule[LogicalPlan] {
child = childWithoutSerialization)

case m @ MapElements(_, deserializer, _, child: ObjectOperator)
if !deserializer.isInstanceOf[Attribute] &&
deserializer.dataType == child.outputObject.dataType =>
if !deserializer.isInstanceOf[Attribute] &&
deserializer.dataType == child.outputObject.dataType =>
val childWithoutSerialization = child.withObjectOutput
m.copy(
deserializer = childWithoutSerialization.output.head,
child = childWithoutSerialization)

case d @ DeserializeToObject(_, s: SerializeFromObject)
if d.outputObjectType == s.inputObjectType =>
// Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
Project(objAttr :: Nil, s.child)
}
}

@@ -1329,3 +1337,30 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
}
}
}

/**
* A typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
* [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
* the deserializer in the filter condition to save the extra serialization step at the end.
*/
object EmbedSerializerInFilter extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) =>
val numObjects = condition.collect {
case a: Attribute if a == d.output.head => a
}.length

if (numObjects > 1) {
// If the filter condition references the object more than once, we should not embed the
// deserializer in it, as the deserialization would then happen multiple times and slow down
// execution.
// TODO: we can still embed it if we can make sure subexpression elimination works here.
s
} else {
val newCondition = condition transform {
case a: Attribute if a == d.output.head => d.deserializer.child
}
Filter(newCondition, d.child)
}
}
}
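For a rough picture of what this rule does to a single typed filter, here is a hedged sketch reusing the dsl and encoder setup from the new test suite (not additional patch code):

implicit val enc = ExpressionEncoder[(Int, Int)]()
val input = LocalRelation('_1.int, '_2.int)
val f = (i: (Int, Int)) => i._1 > 0

// Analyzed plan before the rule:
//   SerializeFromObject(serializer,
//     Filter(callFunction(f, BooleanType, obj),
//       DeserializeToObject(obj, input)))
val query = input.filter(f).analyze

// After the rule, both serde operators are gone and the deserializer is inlined in the condition:
//   Filter(callFunction(f, BooleanType, <deserializer for (Int, Int)>), input)
val optimized = EmbedSerializerInFilter(query)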
@@ -21,7 +21,42 @@ import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{ObjectType, StructType}
import org.apache.spark.sql.types.{DataType, ObjectType, StructType}

object CatalystSerde {
def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
DeserializeToObject(Alias(deserializer, "obj")(), child)
}

def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
SerializeFromObject(encoderFor[T].namedExpressions, child)
}
}

/**
* Takes the input row from its child and turns it into an object using the given deserializer expression.
* The output of this operator is a single-field safe row containing the deserialized object.
*/
case class DeserializeToObject(
deserializer: Alias,
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = deserializer.toAttribute :: Nil

def outputObjectType: DataType = deserializer.dataType
}

/**
* Takes the input object from its child and turns it into an unsafe row using the given serializer
* expression. The output of its child must be a single-field row containing the input object.
*/
case class SerializeFromObject(
serializer: Seq[NamedExpression],
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = serializer.map(_.toAttribute)

def inputObjectType: DataType = child.output.head.dataType
}
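For orientation, a minimal sketch of how CatalystSerde and these two operators are meant to be composed; it mirrors the new Dataset.filter implementation later in this patch, and `typedWhere` is a hypothetical helper rather than part of the change:

// Hypothetical helper: deserialize rows to an object, filter on the object, serialize back to rows.
def typedWhere[T : Encoder](child: LogicalPlan, func: T => Boolean): LogicalPlan = {
  val deserialized = CatalystSerde.deserialize[T](child)                  // rows -> single "obj" column
  val function = Literal.create(func, ObjectType(classOf[T => Boolean]))  // the closure as an object literal
  val condition = Invoke(function, "apply", BooleanType, deserialized.output)
  CatalystSerde.serialize[T](Filter(condition, deserialized))             // object column -> row columns
}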

/**
* A trait for logical operators that apply user defined functions to domain objects.
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.BooleanType

class TypedFilterOptimizationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("EliminateSerialization", FixedPoint(50),
EliminateSerialization) ::
Batch("EmbedSerializerInFilter", FixedPoint(50),
EmbedSerializerInFilter) :: Nil
}

implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()

test("back to back filter") {
val input = LocalRelation('_1.int, '_2.int)
val f1 = (i: (Int, Int)) => i._1 > 0
val f2 = (i: (Int, Int)) => i._2 > 0

val query = input.filter(f1).filter(f2).analyze

val optimized = Optimize.execute(query)

val expected = input.deserialize[(Int, Int)]
.where(callFunction(f1, BooleanType, 'obj))
.select('obj.as("obj"))
.where(callFunction(f2, BooleanType, 'obj))
.serialize[(Int, Int)].analyze

comparePlans(optimized, expected)
}

test("embed deserializer in filter condition if there is only one filter") {
val input = LocalRelation('_1.int, '_2.int)
val f = (i: (Int, Int)) => i._1 > 0

val query = input.filter(f).analyze

val optimized = Optimize.execute(query)

val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer)
val condition = callFunction(f, BooleanType, deserializer)
val expected = input.where(condition).analyze

comparePlans(optimized, expected)
}
}
16 changes: 14 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1879,7 +1879,13 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
def filter(func: T => Boolean): Dataset[T] = {
val deserialized = CatalystSerde.deserialize[T](logicalPlan)
val function = Literal.create(func, ObjectType(classOf[T => Boolean]))
val condition = Invoke(function, "apply", BooleanType, deserialized.output)
val filter = Filter(condition, deserialized)
withTypedPlan(CatalystSerde.serialize[T](filter))
}

/**
* :: Experimental ::
@@ -1890,7 +1896,13 @@
* @since 1.6.0
*/
@Experimental
def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
def filter(func: FilterFunction[T]): Dataset[T] = {
val deserialized = CatalystSerde.deserialize[T](logicalPlan)
val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]]))
val condition = Invoke(function, "call", BooleanType, deserialized.output)
val filter = Filter(condition, deserialized)
withTypedPlan(CatalystSerde.serialize[T](filter))
}
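The user-facing API is unchanged; only the planned shape differs. A hedged usage sketch (`sqlContext.implicits` and `toDS` are the usual Spark SQL implicits, assumed to be available here):

import sqlContext.implicits._

val ds = Seq((1, 2), (-1, 3)).toDS()

// Previously planned as MapPartitions(_.filter(func)); with this patch it becomes
//   SerializeFromObject(Filter(Invoke(func, "apply", BooleanType, obj), DeserializeToObject(obj, child)))
// which the optimizer can then simplify, e.g. via EmbedSerializerInFilter.
val positives = ds.filter(_._1 > 0)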

/**
* :: Experimental ::
@@ -345,6 +345,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
throw new IllegalStateException(
"logical intersect operator should have been replaced by semi-join in the optimizer")

case logical.DeserializeToObject(deserializer, child) =>
execution.DeserializeToObject(deserializer, planLater(child)) :: Nil
case logical.SerializeFromObject(serializer, child) =>
execution.SerializeFromObject(serializer, planLater(child)) :: Nil
case logical.MapPartitions(f, in, out, child) =>
execution.MapPartitions(f, in, out, planLater(child)) :: Nil
case logical.MapElements(f, in, out, child) =>
@@ -27,6 +27,73 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.ObjectType

/**
* Takes the input row from its child and turns it into an object using the given deserializer expression.
* The output of this operator is a single-field safe row containing the deserialized object.
*/
case class DeserializeToObject(
deserializer: Alias,
child: SparkPlan) extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = deserializer.toAttribute :: Nil

override def upstreams(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].upstreams()
}

protected override def doProduce(ctx: CodegenContext): String = {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val bound = ExpressionCanonicalizer.execute(
BindReferences.bindReference(deserializer, child.output))
ctx.currentVars = input
val resultVars = bound.gen(ctx) :: Nil
consume(ctx, resultVars)
}

override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output)
iter.map(projection)
}
}
}

/**
* Takes the input object from its child and turns it into an unsafe row using the given serializer
* expression. The output of its child must be a single-field row containing the input object.
*/
case class SerializeFromObject(
serializer: Seq[NamedExpression],
child: SparkPlan) extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = serializer.map(_.toAttribute)

override def upstreams(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].upstreams()
}

protected override def doProduce(ctx: CodegenContext): String = {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val bound = serializer.map { expr =>
ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output))
}
ctx.currentVars = input
val resultVars = bound.map(_.gen(ctx))
consume(ctx, resultVars)
}

override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
val projection = UnsafeProjection.create(serializer)
iter.map(projection)
}
}
}

/**
* Helper functions for physical operators that work with user defined objects.
*/