diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index d78be5a5958f9..6dd555959e2b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -132,7 +132,7 @@ class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBac
   override def copy(): InternalRow = new GenericInternalRow(values.clone())
 }
 
-class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
+class InterpretedOrdering private (ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
     this(ordering.map(BindReferences.bindReference(_, inputSchema)))
 
@@ -164,10 +164,3 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
     return 0
   }
 }
-
-object RowOrdering {
-  def forSchema(dataTypes: Seq[DataType]): RowOrdering =
-    new RowOrdering(dataTypes.zipWithIndex.map {
-      case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
-    })
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index feea4f239c04d..a5a6885fe81f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.MutablePair
+import org.apache.spark.util.{Utils, MutablePair}
 import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
 
 /**
@@ -175,8 +175,21 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
           val mutablePair = new MutablePair[InternalRow, Null]()
           iter.map(row => mutablePair.update(row.copy(), null))
         }
-        // TODO: RangePartitioner should take an Ordering.
-        implicit val ordering = new RowOrdering(sortingExpressions, child.output)
+        // This wrapper works around the fact that generated orderings are not Serializable.
+        // Normally we do not run into this problem because the code generation is performed on
+        // the executors, but Spark's RangePartitioner requires a Serializable Ordering to be
+        // created on the driver. This wrapper is an easy workaround to let us use generated
+        // orderings here without having to rewrite or modify RangePartitioner.
+        implicit val ordering = new Ordering[InternalRow] {
+          @transient var _ordering = buildOrdering()
+          override def compare(x: InternalRow, y: InternalRow): Int = _ordering.compare(x, y)
+          def buildOrdering(): Ordering[InternalRow] =
+            newOrdering(sortingExpressions, child.output)
+          private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
+            in.defaultReadObject()
+            _ordering = buildOrdering()
+          }
+        }
 
         new RangePartitioner(numPartitions, rddForSampling, ascending = true)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 632f633d82a2e..b691cad218473 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -229,11 +229,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
             throw e
           } else {
             log.error("Failed to generate ordering, fallback to interpreted", e)
-            new RowOrdering(order, inputSchema)
+            new InterpretedOrdering(order, inputSchema)
           }
       }
     } else {
-      new RowOrdering(order, inputSchema)
+      new InterpretedOrdering(order, inputSchema)
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 6e127e548a120..4d5356a48557d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -119,10 +119,8 @@ case class Window(
   // Although input rows are grouped based on windowSpec.partitionSpec, we need to
   // know when we have a new partition.
   // This is to manually construct an ordering that can be used to compare rows.
-  // TODO: We may want to have a newOrdering that takes BoundReferences.
-  // So, we can take advantave of code gen.
   private val partitionOrdering: Ordering[InternalRow] =
-    RowOrdering.forSchema(windowSpec.partitionSpec.map(_.dataType))
+    newOrdering(windowSpec.partitionSpec.map(SortOrder(_, Ascending)), child.output)
 
   // This is used to project expressions for the partition specification.
   protected val partitionGenerator =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 4c063c299ba53..5e2ec140ba4f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
-import org.apache.spark.util.{CompletionIterator, MutablePair}
+import org.apache.spark.util.{CompletionIterator, MutablePair, Utils}
 import org.apache.spark.{HashPartitioner, SparkEnv}
 
 /**
@@ -167,7 +167,24 @@ case class TakeOrderedAndProject(
 
   override def outputPartitioning: Partitioning = SinglePartition
 
-  private val ord: RowOrdering = new RowOrdering(sortOrder, child.output)
+  private val ord: Ordering[InternalRow] = {
+    // This wrapper works around the fact that generated orderings are not Serializable.
+    // Normally we do not run into this problem because the code generation is performed on
+    // the executors, but Spark's takeOrdered requires a Serializable Ordering to be
+    // created on the driver. This wrapper is an easy workaround to let us use generated
+    // orderings here without having to rewrite or modify takeOrdered.
+    val schema = child.output
+    val sortOrderCopy = sortOrder
+    new Ordering[InternalRow] {
+      @transient var _ordering = buildOrdering()
+      override def compare(x: InternalRow, y: InternalRow): Int = _ordering.compare(x, y)
+      def buildOrdering(): Ordering[InternalRow] = newOrdering(sortOrderCopy, schema)
+      private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
+        in.defaultReadObject()
+        _ordering = buildOrdering()
+      }
+    }
+  }
 
   // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable.
   @transient private val projection = projectList.map(new InterpretedProjection(_, child.output))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 981447eacad74..a44198f7e5a27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -45,8 +45,9 @@ case class SortMergeJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
-  // this is to manually construct an ordering that can be used to compare keys from both sides
-  private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType))
+  // Construct an ordering that can be used to compare keys from both sides
+  private val keyOrdering: Ordering[InternalRow] =
+    newOrdering(requiredOrders(leftKeys), left.output)
 
   override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
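A minimal, self-contained sketch of the serialization workaround used in Exchange and TakeOrderedAndProject above: the ordering object itself is Serializable (scala.math.Ordering already extends Serializable), while the comparator it delegates to is marked @transient and rebuilt in readObject after deserialization. The names LazilyRebuiltOrdering and LazyOrderingExample, and the use of Ordering.Int as a stand-in for a generated (non-Serializable) ordering, are illustrative only and not part of this patch.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Wrapper that survives Java serialization even though the ordering it wraps does not.
class LazilyRebuiltOrdering extends Ordering[Int] {
  // Stand-in for newOrdering(sortingExpressions, child.output): pretend the result is a
  // code-generated comparator that cannot itself be serialized.
  private def buildOrdering(): Ordering[Int] = Ordering.Int

  // Not written out during serialization; rebuilt on the receiving side.
  @transient private var underlying: Ordering[Int] = buildOrdering()

  override def compare(x: Int, y: Int): Int = underlying.compare(x, y)

  // Called by Java serialization when the wrapper is deserialized on an executor.
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    underlying = buildOrdering()
  }
}

object LazyOrderingExample {
  def main(args: Array[String]): Unit = {
    // Round-trip through Java serialization, as RangePartitioner or RDD.takeOrdered would do
    // when the ordering is created on the driver and shipped to executors.
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(new LazilyRebuiltOrdering)
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    val restored = in.readObject().asInstanceOf[Ordering[Int]]
    println(restored.compare(1, 2)) // -1: the ordering still works after the round trip
  }
}

The patch additionally wraps the rebuild in Utils.tryOrIOException so that failures inside readObject surface as IOExceptions, as the Java serialization contract expects; the sketch above omits that detail for brevity.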