Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ object EliminateSerialization extends Rule[LogicalPlan] {
val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId)
Project(objAttr :: Nil, s.child)

case a @ AppendColumns(_, _, _, s: SerializeFromObject)
case a @ AppendColumns(_, _, _, _, _, s: SerializeFromObject)
if a.deserializer.dataType == s.inputObjAttr.dataType =>
AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)

Expand All @@ -235,7 +235,7 @@ object EliminateSerialization extends Rule[LogicalPlan] {
// deserialization in condition, and push it down through `SerializeFromObject`.
// e.g. `ds.map(...).filter(...)` can be optimized by this rule to save extra deserialization,
// but `ds.map(...).as[AnotherType].filter(...)` can not be optimized.
case f @ TypedFilter(_, _, s: SerializeFromObject)
case f @ TypedFilter(_, _, _, _, s: SerializeFromObject)
if f.deserializer.dataType == s.inputObjAttr.dataType =>
s.copy(child = f.withObjectProducerChild(s.child))

Expand Down Expand Up @@ -1719,9 +1719,14 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[Logic
*/
object CombineTypedFilters extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case t1 @ TypedFilter(_, _, t2 @ TypedFilter(_, _, child))
case t1 @ TypedFilter(_, _, _, _, t2 @ TypedFilter(_, _, _, _, child))
if t1.deserializer.dataType == t2.deserializer.dataType =>
TypedFilter(combineFilterFunction(t2.func, t1.func), t1.deserializer, child)
TypedFilter(
combineFilterFunction(t2.func, t1.func),
t1.argumentClass,
t1.argumentSchema,
t1.deserializer,
child)
}

private def combineFilterFunction(func1: AnyRef, func2: AnyRef): Any => Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ object MapElements {
val deserialized = CatalystSerde.deserialize[T](child)
val mapped = MapElements(
func,
implicitly[Encoder[T]].clsTag.runtimeClass,
implicitly[Encoder[T]].schema,
CatalystSerde.generateObjAttr[U],
deserialized)
CatalystSerde.serialize[U](mapped)
Expand All @@ -166,12 +168,19 @@ object MapElements {
*/
case class MapElements(
func: AnyRef,
// Class of the function's argument; the construction site visible in `object MapElements`
// passes `implicitly[Encoder[T]].clsTag.runtimeClass`, where `T` is the type the child is
// deserialized to.
argumentClass: Class[_],
// Schema of the function's argument; the same construction site passes
// `implicitly[Encoder[T]].schema`.
argumentSchema: StructType,
outputObjAttr: Attribute,
child: LogicalPlan) extends ObjectConsumer with ObjectProducer

object TypedFilter {
def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child)
TypedFilter(
func,
implicitly[Encoder[T]].clsTag.runtimeClass,
implicitly[Encoder[T]].schema,
UnresolvedDeserializer(encoderFor[T].deserializer),
child)
}
}

Expand All @@ -186,6 +195,8 @@ object TypedFilter {
*/
case class TypedFilter(
func: AnyRef,
argumentClass: Class[_],
argumentSchema: StructType,
deserializer: Expression,
child: LogicalPlan) extends UnaryNode {

Expand Down Expand Up @@ -213,6 +224,8 @@ object AppendColumns {
child: LogicalPlan): AppendColumns = {
new AppendColumns(
func.asInstanceOf[Any => Any],
implicitly[Encoder[T]].clsTag.runtimeClass,
implicitly[Encoder[T]].schema,
UnresolvedDeserializer(encoderFor[T].deserializer),
encoderFor[U].namedExpressions,
child)
Expand All @@ -228,6 +241,8 @@ object AppendColumns {
*/
case class AppendColumns(
func: Any => Any,
argumentClass: Class[_],
argumentSchema: StructType,
deserializer: Expression,
serializer: Seq[NamedExpression],
child: LogicalPlan) extends UnaryNode {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping,
data, objAttr, planLater(child)) :: Nil
case logical.MapElements(f, objAttr, child) =>
case logical.MapElements(f, _, _, objAttr, child) =>
execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
case logical.AppendColumns(f, in, out, child) =>
case logical.AppendColumns(f, _, _, in, out, child) =>
execution.AppendColumnsExec(f, in, out, planLater(child)) :: Nil
case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.expressions.Aggregator
////////////////////////////////////////////////////////////////////////////////////////////////////


class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
class TypedSumDouble[IN](val f: IN => Double) extends Aggregator[IN, Double, Double] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why expose `f`?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TypedFilter and MapElements also expose the function closure. With `f` exposed, we can apply the possible optimization described in https://issues.apache.org/jira/browse/SPARK-16898.

override def zero: Double = 0.0
override def reduce(b: Double, a: IN): Double = b + f(a)
override def merge(b1: Double, b2: Double): Double = b1 + b2
Expand All @@ -45,7 +45,7 @@ class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double]
}


class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
class TypedSumLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] {
override def zero: Long = 0L
override def reduce(b: Long, a: IN): Long = b + f(a)
override def merge(b1: Long, b2: Long): Long = b1 + b2
Expand All @@ -63,7 +63,7 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
}


class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] {
override def zero: Long = 0
override def reduce(b: Long, a: IN): Long = {
if (f(a) == null) b else b + 1
Expand All @@ -82,7 +82,7 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
}


class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
override def zero: (Double, Long) = (0.0, 0L)
override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2)
override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2
Expand Down