sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala (6 additions, 3 deletions)

@@ -921,12 +921,15 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1)

/**
* Prepares a planned SparkPlan for execution by inserting shuffle operations as needed.
* Prepares a planned SparkPlan for execution by inserting shuffle operations and internal
* row format conversions as needed.
*/
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches =
Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil
val batches = Seq(
Batch("Add exchange", Once, EnsureRequirements(self)),
Batch("Add row converters", Once, EnsureRowFormats)
)
}

protected[sql] def openSession(): SQLSession = {
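As a point of reference, here is a minimal sketch of applying this rule executor, mirroring the test suite at the end of this patch (it uses the test context rather than the production code path):

import org.apache.spark.sql.test.TestSQLContext

// Build a trivial physical plan and prepare it for execution: the batches run
// once each, in order -- EnsureRequirements may insert Exchange operators,
// then EnsureRowFormats inserts any needed row-format converters.
val plan: SparkPlan = ExternalSort(Nil, global = false, PhysicalRDD(Seq.empty, null))
val prepared: SparkPlan = TestSQLContext.prepareForExecution.execute(plan)
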
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -80,12 +80,36 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable
/** Specifies, for each partition, the sort order required of the input data to this operator. */
def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)

/** Specifies whether this operator outputs UnsafeRows */
def outputsUnsafeRows: Boolean = false

/** Specifies whether this operator is capable of processing UnsafeRows */
def canProcessUnsafeRows: Boolean = false

/**
* Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows
* that are not UnsafeRows).
*/
def canProcessSafeRows: Boolean = true

/**
* Returns the result of this query as an RDD[InternalRow] by delegating to doExecute
* after adding query plan information to created RDDs for visualization.
* Concrete implementations of SparkPlan should override doExecute instead.
*/
final def execute(): RDD[InternalRow] = {
if (children.nonEmpty) {
val hasUnsafeInputs = children.exists(_.outputsUnsafeRows)
val hasSafeInputs = children.exists(!_.outputsUnsafeRows)
assert(!(hasSafeInputs && hasUnsafeInputs),
"Child operators should output rows in the same format")
assert(canProcessSafeRows || canProcessUnsafeRows,
"Operator must be able to process at least one row format")
assert(!hasSafeInputs || canProcessSafeRows,
"Operator will receive safe rows as input but cannot process safe rows")
assert(!hasUnsafeInputs || canProcessUnsafeRows,
"Operator will receive unsafe rows as input but cannot process unsafe rows")
}
RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
doExecute()
}
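To make the contract of the new flags concrete, here is a hypothetical operator, not part of this patch, that consumes and produces only UnsafeRows (assuming the same imports as the surrounding execution operators):

// Hypothetical example only. Because canProcessSafeRows is false, EnsureRowFormats
// must wrap any safe-row child in ConvertToUnsafe, and the assertions added to
// execute() above will fail fast if that ever does not happen.
case class UnsafeOnlyOp(child: SparkPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output
  override def outputsUnsafeRows: Boolean = true
  override def canProcessUnsafeRows: Boolean = true
  override def canProcessSafeRows: Boolean = false
  override protected def doExecute(): RDD[InternalRow] = child.execute()
}
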
sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -64,6 +64,12 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode
}

override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows

override def canProcessUnsafeRows: Boolean = true

override def canProcessSafeRows: Boolean = true
}

/**
@@ -104,6 +110,9 @@ case class Sample(
case class Union(children: Seq[SparkPlan]) extends SparkPlan {
// TODO: attributes output by union should be distinct for nullability purposes
override def output: Seq[Attribute] = children.head.output
override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows)
override def canProcessUnsafeRows: Boolean = true
override def canProcessSafeRows: Boolean = true
protected override def doExecute(): RDD[InternalRow] =
sparkContext.union(children.map(_.execute()))
}
@@ -306,6 +315,8 @@ case class UnsafeExternalSort(
override def output: Seq[Attribute] = child.output

override def outputOrdering: Seq[SortOrder] = sortOrder

override def outputsUnsafeRows: Boolean = true
Reviewer: What about filter?

Author: Yep, meant to change this. Thanks for reminding me.

Reviewer: Should we also add an assertion to our set operations that the inputs are the same type of row?

Reviewer: Or actually, maybe we can add a general assertion to the execute method of SparkPlan?

    assert(children.map(_.outputsUnsafeRows).distinct.size <= 1)

Author: Yeah, let's add it to execute I think. I'll do this shortly.

}

@DeveloperApi
sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala (new file)
@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.rules.Rule

/**
* :: DeveloperApi ::
* Converts Java-object-based rows into [[UnsafeRow]]s.
*/
@DeveloperApi
case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
override def outputsUnsafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = false
override def canProcessSafeRows: Boolean = true
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitions { iter =>
val convertToUnsafe = UnsafeProjection.create(child.schema)
iter.map(convertToUnsafe)
}
}
}

/**
* :: DeveloperApi ::
* Converts [[UnsafeRow]]s back into Java-object-based rows.
*/
@DeveloperApi
case class ConvertToSafe(child: SparkPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
override def outputsUnsafeRows: Boolean = false
override def canProcessUnsafeRows: Boolean = true
override def canProcessSafeRows: Boolean = false
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitions { iter =>
val convertToSafe = FromUnsafeProjection(child.output.map(_.dataType))
iter.map(convertToSafe)
}
}
}

private[sql] object EnsureRowFormats extends Rule[SparkPlan] {

private def onlyHandlesSafeRows(operator: SparkPlan): Boolean =
operator.canProcessSafeRows && !operator.canProcessUnsafeRows

private def onlyHandlesUnsafeRows(operator: SparkPlan): Boolean =
operator.canProcessUnsafeRows && !operator.canProcessSafeRows

private def handlesBothSafeAndUnsafeRows(operator: SparkPlan): Boolean =
operator.canProcessSafeRows && operator.canProcessUnsafeRows

override def apply(operator: SparkPlan): SparkPlan = operator.transformUp {
case operator: SparkPlan if onlyHandlesSafeRows(operator) =>
if (operator.children.exists(_.outputsUnsafeRows)) {
operator.withNewChildren {
operator.children.map {
c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c
}
}
} else {
operator
}
case operator: SparkPlan if onlyHandlesUnsafeRows(operator) =>
if (operator.children.exists(!_.outputsUnsafeRows)) {
operator.withNewChildren {
operator.children.map {
c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
}
}
} else {
operator
}
case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) =>
if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) {
// If this operator's children produce both unsafe and safe rows, then convert everything
// to unsafe rows
operator.withNewChildren {
Reviewer: Wouldn't it make more sense to convert to unsafe instead?

Author: Yeah, I think so. I think that choosing to resolve this type of conflict in favor of UnsafeRow should be fine: if unsafe operators are disabled via a feature flag, then the plan shouldn't contain any operators which claim to output unsafe rows, so this branch will never be triggered.

I'll update this patch to change this logic.

operator.children.map {
c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
}
}
} else {
operator
}
}
}
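To illustrate the rule's effect, here is a sketch that applies EnsureRowFormats directly to a Union whose children disagree on row format (outputsSafe and outputsUnsafe are the plans built in the test suite below):

// Before: Union(ExternalSort(...), UnsafeExternalSort(...))   -- mixed formats
// After:  Union(ConvertToUnsafe(ExternalSort(...)), UnsafeExternalSort(...))
val mixed = Union(Seq(outputsSafe, outputsUnsafe))
val fixed = EnsureRowFormats(mixed)
assert(fixed.children.forall(_.outputsUnsafeRows))  // conflict resolved in favor of UnsafeRow
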
sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala (new file)
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.IsNull
import org.apache.spark.sql.test.TestSQLContext

class RowFormatConvertersSuite extends SparkPlanTest {

private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect {
case c: ConvertToUnsafe => c
case c: ConvertToSafe => c
}

private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null))
assert(!outputsSafe.outputsUnsafeRows)
private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null))
assert(outputsUnsafe.outputsUnsafeRows)

test("planner should insert unsafe->safe conversions when required") {
val plan = Limit(10, outputsUnsafe)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe])
}

test("filter can process unsafe rows") {
val plan = Filter(IsNull(null), outputsUnsafe)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
assert(getConverters(preparedPlan).isEmpty)
assert(preparedPlan.outputsUnsafeRows)
}

test("filter can process safe rows") {
val plan = Filter(IsNull(null), outputsSafe)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
assert(getConverters(preparedPlan).isEmpty)
assert(!preparedPlan.outputsUnsafeRows)
}

test("execute() fails an assertion if inputs rows are of different formats") {
val e = intercept[AssertionError] {
Union(Seq(outputsSafe, outputsUnsafe)).execute()
}
assert(e.getMessage.contains("format"))
}

test("union requires all of its input rows' formats to agree") {
val plan = Union(Seq(outputsSafe, outputsUnsafe))
assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
assert(preparedPlan.outputsUnsafeRows)
}

test("union can process safe rows") {
val plan = Union(Seq(outputsSafe, outputsSafe))
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
assert(!preparedPlan.outputsUnsafeRows)
}

test("union can process unsafe rows") {
val plan = Union(Seq(outputsUnsafe, outputsUnsafe))
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
assert(preparedPlan.outputsUnsafeRows)
}

test("round trip with ConvertToUnsafe and ConvertToSafe") {
val input = Seq(("hello", 1), ("world", 2))
checkAnswer(
TestSQLContext.createDataFrame(input),
plan => ConvertToSafe(ConvertToUnsafe(plan)),
input.map(Row.fromTuple)
)
}
}