From 5cb1a610b21e84d6a1444fb3d98d2d071f61c4b6 Mon Sep 17 00:00:00 2001
From: zsxwing
Date: Mon, 31 Aug 2015 15:43:55 +0800
Subject: [PATCH 1/4] Add SQLConf to LocalNode

---
 .../spark/sql/execution/local/LocalNode.scala | 127 +++++++++++++++++-
 .../sql/execution/local/FilterNodeSuite.scala |   4 +-
 .../sql/execution/local/LimitNodeSuite.scala  |   4 +-
 .../sql/execution/local/LocalNodeTest.scala   |  21 ++-
 .../execution/local/ProjectNodeSuite.scala    |   4 +-
 .../sql/execution/local/UnionNodeSuite.scala  |   6 +-
 6 files changed, 151 insertions(+), 15 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index 1c4469acbf264..6298dceaf931f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.sql.execution.local
 
-import org.apache.spark.sql.Row
+import org.apache.spark.Logging
+import org.apache.spark.sql.{SQLConf, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratePredicate, GenerateMutableProjection, GenerateProjection}
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types.StructType
 
@@ -29,10 +31,18 @@ import org.apache.spark.sql.types.StructType
  * Before consuming the iterator, the `open` function must be called.
  * After consuming the iterator, the `close` function must be called.
  */
-abstract class LocalNode extends TreeNode[LocalNode] {
+abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging {
+
+  val codegenEnabled: Boolean = conf.codegenEnabled
+
+  val unsafeEnabled: Boolean = conf.unsafeEnabled
+
+  private[this] def isTesting: Boolean = sys.props.contains("spark.testing")
 
   def output: Seq[Attribute]
 
+  lazy val schema: StructType = StructType.fromAttributes(output)
+
   /**
    * Initializes the iterator state. Must be called before calling `next()`.
    *
@@ -57,6 +67,18 @@ abstract class LocalNode extends TreeNode[LocalNode] {
    */
   def close(): Unit
 
+  /** Specifies whether this operator outputs UnsafeRows */
+  def outputsUnsafeRows: Boolean = false
+
+  /** Specifies whether this operator is capable of processing UnsafeRows */
+  def canProcessUnsafeRows: Boolean = false
+
+  /**
+   * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows
+   * that are not UnsafeRows).
+   */
+  def canProcessSafeRows: Boolean = true
+
   /**
    * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq.
   */
@@ -73,17 +95,112 @@ abstract class LocalNode extends TreeNode[LocalNode] {
     }
     result
   }
+
+  protected def newMutableProjection(
+      expressions: Seq[Expression],
+      inputSchema: Seq[Attribute]): () => MutableProjection = {
+    log.debug(
+      s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
+    if (codegenEnabled) {
+      try {
+        GenerateMutableProjection.generate(expressions, inputSchema)
+      } catch {
+        case e: Exception =>
+          if (isTesting) {
+            throw e
+          } else {
+            log.error("Failed to generate mutable projection, fallback to interpreted", e)
+            () => new InterpretedMutableProjection(expressions, inputSchema)
+          }
+      }
+    } else {
+      () => new InterpretedMutableProjection(expressions, inputSchema)
+    }
+  }
+
+  protected def newProjection(
+      expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
+    log.debug(
+      s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
+    if (codegenEnabled) {
+      try {
+        GenerateProjection.generate(expressions, inputSchema)
+      } catch {
+        case e: Exception =>
+          if (isTesting) {
+            throw e
+          } else {
+            log.error("Failed to generate projection, fallback to interpreted", e)
+            new InterpretedProjection(expressions, inputSchema)
+          }
+      }
+    } else {
+      new InterpretedProjection(expressions, inputSchema)
+    }
+  }
+
+  protected def newPredicate(
+      expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
+    if (codegenEnabled) {
+      try {
+        GeneratePredicate.generate(expression, inputSchema)
+      } catch {
+        case e: Exception =>
+          if (isTesting) {
+            throw e
+          } else {
+            log.error("Failed to generate predicate, fallback to interpreted", e)
+            InterpretedPredicate.create(expression, inputSchema)
+          }
+      }
+    } else {
+      InterpretedPredicate.create(expression, inputSchema)
+    }
+  }
+
+  def toIterator: Iterator[InternalRow] = new Iterator[InternalRow] {
+
+    private var currentRow: InternalRow = null
+
+    override def hasNext: Boolean = {
+      if (currentRow == null) {
+        if (LocalNode.this.next()) {
+          currentRow = fetch()
+          true
+        } else {
+          false
+        }
+      } else {
+        true
+      }
+    }
+
+    override def next(): InternalRow = {
+      val r = currentRow
+      currentRow = null
+      r
+    }
+  }
 }
 
-abstract class LeafLocalNode extends LocalNode {
+abstract class LeafLocalNode(conf: SQLConf) extends LocalNode(conf) {
   override def children: Seq[LocalNode] = Seq.empty
 }
 
-abstract class UnaryLocalNode extends LocalNode {
+abstract class UnaryLocalNode(conf: SQLConf) extends LocalNode(conf) {
 
   def child: LocalNode
 
   override def children: Seq[LocalNode] = Seq(child)
 }
+
+abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) {
+
+  def left: LocalNode
+
+  def right: LocalNode
+
+  override def children: Seq[LocalNode] = Seq(left, right)
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
index 07209f3779248..a12670e347c25 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
@@ -25,7 +25,7 @@ class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {
     val condition = (testData.col("key") % 2) === 0
     checkAnswer(
       testData,
-      node => FilterNode(condition.expr, node),
+      node => FilterNode(conf, condition.expr, node),
       testData.filter(condition).collect()
     )
   }
@@ -34,7 +34,7 @@ class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {
   test("empty") {
     val condition = (emptyTestData.col("key") % 2) === 0
     checkAnswer(
       emptyTestData,
-      node => FilterNode(condition.expr, node),
+      node => FilterNode(conf, condition.expr, node),
       emptyTestData.filter(condition).collect()
     )
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
index 523c02f4a6014..3b183902007e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
@@ -24,7 +24,7 @@ class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {
   test("basic") {
     checkAnswer(
       testData,
-      node => LimitNode(10, node),
+      node => LimitNode(conf, 10, node),
       testData.limit(10).collect()
     )
   }
@@ -32,7 +32,7 @@ class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {
   test("empty") {
     checkAnswer(
       emptyTestData,
-      node => LimitNode(10, node),
+      node => LimitNode(conf, 10, node),
       emptyTestData.limit(10).collect()
     )
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
index 95f06081bd0a8..d78c6ec456fde 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
@@ -17,14 +17,32 @@
 
 package org.apache.spark.sql.execution.local
 
+import scala.util.Try
 import scala.util.control.NonFatal
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Row, SQLConf}
 import org.apache.spark.sql.test.SQLTestUtils
 
 class LocalNodeTest extends SparkFunSuite {
 
+  protected val conf = new SQLConf
+
+  /**
+   * Sets all configurations specified in `pairs`, calls `f`, and then restores all
+   * configurations.
+   */
+  protected def withConf(pairs: (String, String)*)(f: => Unit): Unit = {
+    val (keys, values) = pairs.unzip
+    val currentValues = keys.map(key => Try(conf.getConfString(key)).toOption)
+    (keys, values).zipped.foreach(conf.setConfString)
+    try f finally {
+      keys.zip(currentValues).foreach {
+        case (key, Some(value)) => conf.setConfString(key, value)
+        case (key, None) => conf.unsetConf(key)
+      }
+    }
+  }
+
   /**
    * Runs the LocalNode and makes sure the answer matches the expected result.
    * @param input the input data to be used.
@@ -92,6 +110,7 @@ class LocalNodeTest extends SparkFunSuite {
 
   protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = {
     new SeqScanNode(
+      conf,
       df.queryExecution.sparkPlan.output,
       df.queryExecution.toRdd.map(_.copy()).collect())
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
index ffcf092e2c66a..38e0a230c46d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
@@ -26,7 +26,7 @@ class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext {
     val columns = Seq(output(1), output(0))
     checkAnswer(
       testData,
-      node => ProjectNode(columns, node),
+      node => ProjectNode(conf, columns, node),
       testData.select("value", "key").collect()
     )
   }
@@ -36,7 +36,7 @@ class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext {
     val columns = Seq(output(1), output(0))
     checkAnswer(
       emptyTestData,
-      node => ProjectNode(columns, node),
+      node => ProjectNode(conf, columns, node),
       emptyTestData.select("value", "key").collect()
     )
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
index 34670287c3e1d..eedd7320900f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
@@ -25,7 +25,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
     checkAnswer2(
       testData,
       testData,
-      (node1, node2) => UnionNode(Seq(node1, node2)),
+      (node1, node2) => UnionNode(conf, Seq(node1, node2)),
       testData.unionAll(testData).collect()
     )
   }
@@ -34,7 +34,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
     checkAnswer2(
       emptyTestData,
       emptyTestData,
-      (node1, node2) => UnionNode(Seq(node1, node2)),
+      (node1, node2) => UnionNode(conf, Seq(node1, node2)),
       emptyTestData.unionAll(emptyTestData).collect()
     )
   }
@@ -44,7 +44,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
       emptyTestData, emptyTestData, testData, emptyTestData)
     doCheckAnswer(
       dfs,
-      nodes => UnionNode(nodes),
+      nodes => UnionNode(conf, nodes),
       dfs.reduce(_.unionAll(_)).collect()
     )
   }
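Aside on the withConf helper added in this patch: it implements a save/apply/restore pattern so each test can override SQL settings without leaking them into other tests. Below is a minimal, self-contained sketch of that pattern; it uses a plain mutable.Map in place of SQLConf, and every name in it is illustrative rather than Spark API.

import scala.collection.mutable

object WithConfSketch {
  // Stand-in for SQLConf: just a mutable key-value store.
  val conf = mutable.Map[String, String]()

  def withConf(pairs: (String, String)*)(f: => Unit): Unit = {
    // Remember the previous value (or absence) of every key we are about to touch.
    val saved = pairs.map { case (k, _) => k -> conf.get(k) }
    pairs.foreach { case (k, v) => conf(k) = v } // apply the overrides
    try f finally {
      saved.foreach {
        case (k, Some(v)) => conf(k) = v // restore the old value
        case (k, None) => conf.remove(k) // the key was unset before
      }
    }
  }

  def main(args: Array[String]): Unit = {
    conf("codegen") = "true"
    withConf("codegen" -> "false") {
      assert(conf("codegen") == "false") // override visible inside the block
    }
    assert(conf("codegen") == "true") // restored on exit
  }
}

The `try f finally { ... }` shape is what guarantees restoration on both normal and exceptional exit, which is why the helper in the patch uses the same construct.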
From 9055d8afcc114734781006f019603e89c65661bc Mon Sep 17 00:00:00 2001
From: zsxwing
Date: Mon, 7 Sep 2015 20:06:04 +0800
Subject: [PATCH 2/4] Add local expand and NestedLoopJoin operators

---
 .../execution/local/ConvertToSafeNode.scala   |  46 ++++
 .../execution/local/ConvertToUnsafeNode.scala |  46 ++++
 .../sql/execution/local/ExpandNode.scala      |  62 +++++
 .../sql/execution/local/FilterNode.scala      |   4 +-
 .../spark/sql/execution/local/LimitNode.scala |   3 +-
 .../execution/local/NestedLoopJoinNode.scala  | 158 ++++++++++++
 .../sql/execution/local/ProjectNode.scala     |   4 +-
 .../sql/execution/local/SeqScanNode.scala     |   4 +-
 .../spark/sql/execution/local/UnionNode.scala |   3 +-
 .../sql/execution/local/LocalNodeTest.scala   |  27 ++-
 .../local/NestedLoopJoinNodeSuite.scala       | 225 ++++++++++++++++++
 11 files changed, 574 insertions(+), 8 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala
new file mode 100644
index 0000000000000..4bf55514cad17
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection, Projection}
+
+case class ConvertToSafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) {
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputsUnsafeRows: Boolean = false
+
+  override def canProcessUnsafeRows: Boolean = true
+
+  override def canProcessSafeRows: Boolean = false
+
+  private[this] var convertToSafe: Projection = _
+
+  override def open(): Unit = {
+    child.open()
+    convertToSafe = FromUnsafeProjection(child.output.map(_.dataType))
+  }
+
+  override def next(): Boolean = child.next()
+
+  override def fetch(): InternalRow = convertToSafe(child.fetch())
+
+  override def close(): Unit = child.close()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala
new file mode 100644
index 0000000000000..b0de56fda097a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Projection, UnsafeProjection}
+
+case class ConvertToUnsafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) {
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputsUnsafeRows: Boolean = true
+
+  override def canProcessUnsafeRows: Boolean = false
+
+  override def canProcessSafeRows: Boolean = true
+
+  private[this] var convertToUnsafe: Projection = _
+
+  override def open(): Unit = {
+    child.open()
+    convertToUnsafe = UnsafeProjection.create(child.schema)
+  }
+
+  override def next(): Boolean = child.next()
+
+  override def fetch(): InternalRow = convertToUnsafe(child.fetch())
+
+  override def close(): Unit = child.close()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala
new file mode 100644
index 0000000000000..2345b296ab30c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Projection}
+
+case class ExpandNode(
+    conf: SQLConf,
+    projections: Seq[Seq[Expression]],
+    output: Seq[Attribute],
+    child: LocalNode) extends UnaryLocalNode(conf) {
+
+  assert(projections.size > 0)
+
+  private[this] var result: InternalRow = _
+  private[this] var idx: Int = _
+  private[this] var input: InternalRow = _
+
+  private[this] var groups: Array[Projection] = _
+
+  override def open(): Unit = {
+    child.open()
+    idx = -1
+    groups = projections.map(ee => newProjection(ee, child.output)).toArray
+  }
+
+  override def next(): Boolean = {
+    idx += 1
+    if (idx < groups.length) {
+      result = groups(idx)(input)
+      true
+    } else if (child.next()) {
+      input = child.fetch()
+      idx = 0
+      result = groups(idx)(input)
+      true
+    } else {
+      false
+    }
+  }
+
+  override def fetch(): InternalRow = result
+
+  override def close(): Unit = child.close()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
index 81dd37c7da733..dd1113b6726cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
@@ -17,12 +17,14 @@
 
 package org.apache.spark.sql.execution.local
 
+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
 
-case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLocalNode {
+case class FilterNode(conf: SQLConf, condition: Expression, child: LocalNode)
+  extends UnaryLocalNode(conf) {
 
   private[this] var predicate: (InternalRow) => Boolean = _
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
index fffc52abf6dd5..401b10a5ed307 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
@@ -17,11 +17,12 @@
 
 package org.apache.spark.sql.execution.local
 
+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 
-case class LimitNode(limit: Int, child: LocalNode) extends UnaryLocalNode {
+case class LimitNode(conf: SQLConf, limit: Int, child: LocalNode) extends UnaryLocalNode(conf) {
 
   private[this] var count = 0
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala
new file mode 100644
index 0000000000000..cabe4fcfa9681
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, JoinType}
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.util.collection.BitSet
+import org.apache.spark.util.collection.CompactBuffer
+
+case class NestedLoopJoinNode(
+    conf: SQLConf,
+    left: LocalNode,
+    right: LocalNode,
+    buildSide: BuildSide,
+    joinType: JoinType,
+    condition: Option[Expression]) extends BinaryLocalNode(conf) {
+
+  override def output: Seq[Attribute] = {
+    joinType match {
+      case LeftOuter =>
+        left.output ++ right.output.map(_.withNullability(true))
+      case RightOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output
+      case FullOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+      case x =>
+        throw new IllegalArgumentException(
+          s"NestedLoopJoin should not take $x as the JoinType")
+    }
+  }
+
+  private[this] def genResultProjection: InternalRow => InternalRow = {
+    if (outputsUnsafeRows) {
+      UnsafeProjection.create(schema)
+    } else {
+      identity[InternalRow]
+    }
+  }
+
+  private[this] var currentRow: InternalRow = _
+
+  private[this] var iterator: Iterator[InternalRow] = _
+
+  override def open(): Unit = {
+    val (streamed, build) = buildSide match {
+      case BuildRight => (left, right)
+      case BuildLeft => (right, left)
+    }
+    build.open()
+    val buildRelation = new CompactBuffer[InternalRow]
+    while (build.next()) {
+      buildRelation += build.fetch().copy()
+    }
+    build.close()
+
+    val boundCondition =
+      newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+
+    val leftNulls = new GenericMutableRow(left.output.size)
+    val rightNulls = new GenericMutableRow(right.output.size)
+    val joinedRow = new JoinedRow
+    val includedBuildTuples = new BitSet(buildRelation.size)
+    val resultProj = genResultProjection
+    streamed.open()
+
+    val matchesOrStreamedRowsWithNulls = streamed.toIterator.flatMap { streamedRow =>
+      val matchedRows = new CompactBuffer[InternalRow]
+
+      var i = 0
+      var streamRowMatched = false
+
+      while (i < buildRelation.size) {
+        val buildRow = buildRelation(i)
+        buildSide match {
+          case BuildRight if boundCondition(joinedRow(streamedRow, buildRow)) =>
+            matchedRows += resultProj(joinedRow(streamedRow, buildRow)).copy()
+            streamRowMatched = true
+            includedBuildTuples.set(i)
+          case BuildLeft if boundCondition(joinedRow(buildRow, streamedRow)) =>
+            matchedRows += resultProj(joinedRow(buildRow, streamedRow)).copy()
+            streamRowMatched = true
+            includedBuildTuples.set(i)
+          case _ =>
+        }
+        i += 1
+      }
+
+      (streamRowMatched, joinType, buildSide) match {
+        case (false, LeftOuter | FullOuter, BuildRight) =>
+          matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy()
+        case (false, RightOuter | FullOuter, BuildLeft) =>
+          matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy()
+        case _ =>
+      }
+
+      matchedRows.iterator
+    }
+
+    iterator = (joinType, buildSide) match {
+      case (RightOuter | FullOuter, BuildRight) =>
+        var i = 0
+        matchesOrStreamedRowsWithNulls ++ buildRelation.filter { row =>
+          val r = !includedBuildTuples.get(i)
+          i += 1
+          r
+        }.iterator.map { buildRow =>
+          joinedRow.withLeft(leftNulls)
+          resultProj(joinedRow.withRight(buildRow))
+        }
+      case (LeftOuter | FullOuter, BuildLeft) =>
+        var i = 0
+        matchesOrStreamedRowsWithNulls ++ buildRelation.filter { row =>
+          val r = !includedBuildTuples.get(i)
+          i += 1
+          r
+        }.iterator.map { buildRow =>
+          joinedRow.withRight(rightNulls)
+          resultProj(joinedRow.withLeft(buildRow))
+        }
+      case _ => matchesOrStreamedRowsWithNulls
+    }
+  }
+
+  override def next(): Boolean = {
+    if (iterator.hasNext) {
+      currentRow = iterator.next()
+      true
+    } else {
+      false
+    }
+  }
+
+  override def fetch(): InternalRow = currentRow
+
+  override def close(): Unit = {
+    left.close()
+    right.close()
+  }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
index 9b8a4fe493026..11529d6dd9b83 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.sql.execution.local
 
+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, NamedExpression}
 
-case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) extends UnaryLocalNode {
+case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], child: LocalNode)
+  extends UnaryLocalNode(conf) {
 
   private[this] var project: UnsafeProjection = _
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
index 242cb66e07b7f..b8467f6ae58e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.execution.local
 
+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 
 /**
  * An operator that scans some local data collection in the form of Scala Seq.
  */
-case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends LeafLocalNode {
+case class SeqScanNode(conf: SQLConf, output: Seq[Attribute], data: Seq[InternalRow])
+  extends LeafLocalNode(conf) {
 
   private[this] var iterator: Iterator[InternalRow] = _
   private[this] var currentRow: InternalRow = _
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
index ba4aa7671aebd..0f2b8303e7372 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql.execution.local
 
+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 
-case class UnionNode(children: Seq[LocalNode]) extends LocalNode {
+case class UnionNode(conf: SQLConf, children: Seq[LocalNode]) extends LocalNode(conf) {
 
   override def output: Seq[Attribute] = children.head.output
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
index d78c6ec456fde..dc9c811eb3dfe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
@@ -17,14 +17,15 @@
 
 package org.apache.spark.sql.execution.local
 
+import scala.reflect.runtime.universe.TypeTag
 import scala.util.Try
 import scala.util.control.NonFatal
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, Row, SQLConf}
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLConf}
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 
-class LocalNodeTest extends SparkFunSuite {
+class LocalNodeTest extends SparkFunSuite with SharedSQLContext {
 
   protected val conf = new SQLConf
 
@@ -43,6 +44,13 @@ class LocalNodeTest extends SparkFunSuite {
     }
   }
 
+  /**
+   * Creates a DataFrame from a local Seq of Product.
+   */
+  implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = {
+    sqlContext.implicits.localSeqToDataFrameHolder(data)
+  }
+
   /**
    * Runs the LocalNode and makes sure the answer matches the expected result.
    * @param input the input data to be used.
@@ -115,6 +123,19 @@ class LocalNodeTest extends SparkFunSuite {
       df.queryExecution.toRdd.map(_.copy()).collect())
   }
 
+  protected def wrapForUnsafe(
+      f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = {
+    if (conf.unsafeEnabled) {
+      (left: LocalNode, right: LocalNode) => {
+        val _left = ConvertToUnsafeNode(conf, left)
+        val _right = ConvertToUnsafeNode(conf, right)
+        val r = f(_left, _right)
+        ConvertToSafeNode(conf, r)
+      }
+    } else {
+      f
+    }
+  }
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
new file mode 100644
index 0000000000000..bf26a0bf54e65
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
+import org.apache.spark.sql.execution.joins.BuildRight
+
+class NestedLoopJoinNodeSuite extends LocalNodeTest {
+
+  import testImplicits._
+
+  def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = {
+    test(s"$suiteName: left outer join") {
+      withConf(confPairs: _*) {
+        checkAnswer2(
+          upperCaseData,
+          lowerCaseData,
+          wrapForUnsafe(
+            (node1, node2) => NestedLoopJoinNode(
+              conf,
+              node1,
+              node2,
+              BuildRight,
+              LeftOuter,
+              Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr))
+          ),
+          upperCaseData.join(lowerCaseData, $"n" === $"N", "left").collect())
+
+        checkAnswer2(
+          upperCaseData,
+          lowerCaseData,
+          wrapForUnsafe(
+            (node1, node2) => NestedLoopJoinNode(
+              conf,
+              node1,
+              node2,
+              BuildRight,
+              LeftOuter,
+              Some(
+                (upperCaseData.col("N") === lowerCaseData.col("n") &&
+                  lowerCaseData.col("n") > 1).expr))
+          ),
+          upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left").collect())
+
+        checkAnswer2(
+          upperCaseData,
+          lowerCaseData,
+          wrapForUnsafe(
+            (node1, node2) => NestedLoopJoinNode(
+              conf,
+              node1,
+              node2,
+              BuildRight,
+              LeftOuter,
+              Some(
+                (upperCaseData.col("N") === lowerCaseData.col("n") &&
+                  upperCaseData.col("N") > 1).expr))
+          ),
+          upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left").collect())
+
+        checkAnswer2(
+          upperCaseData,
+          lowerCaseData,
+          wrapForUnsafe(
+            (node1, node2) => NestedLoopJoinNode(
+              conf,
+              node1,
+              node2,
+              BuildRight,
+              LeftOuter,
+              Some(
+                (upperCaseData.col("N") === lowerCaseData.col("n") &&
+                  lowerCaseData.col("l") > upperCaseData.col("L")).expr))
+          ),
+          upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left").collect())
+      }
+    }
+
+    test(s"$suiteName: right outer join") {
+      withConf(confPairs: _*) {
+        checkAnswer2(
+          lowerCaseData,
+          upperCaseData,
+          wrapForUnsafe(
+            (node1, node2) => NestedLoopJoinNode(
+              conf,
+              node1,
+              node2,
+              BuildRight,
+              RightOuter,
+              Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr))
+          ),
+          lowerCaseData.join(upperCaseData, $"n" === $"N", "right").collect())
+
+        checkAnswer2(
+          lowerCaseData,
+          upperCaseData,
+          wrapForUnsafe(
+            (node1, node2) => NestedLoopJoinNode(
+              conf,
+              node1,
+              node2,
+              BuildRight,
+              RightOuter,
+              Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
+                lowerCaseData.col("n") > 1).expr))
+          ),
+          lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right").collect())
+
+        checkAnswer2(
+          lowerCaseData,
+          upperCaseData,
+          wrapForUnsafe(
+            (node1, node2) => NestedLoopJoinNode(
+              conf,
+              node1,
+              node2,
+              BuildRight,
+              RightOuter,
+              Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
+                upperCaseData.col("N") > 1).expr))
+          ),
+          lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right").collect())
=== $"N" && $"N" > 1, "right").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + BuildRight, + RightOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("l") > upperCaseData.col("L")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right").collect()) + } + } + + test(s"$suiteName: full outer join") { + withConf(confPairs: _*) { + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + BuildRight, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N", "full").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + BuildRight, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("n") > 1).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "full").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + BuildRight, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + upperCaseData.col("N") > 1).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "full").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + BuildRight, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("l") > upperCaseData.col("L")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "full").collect()) + } + } + } + + joinSuite( + "general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") + joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") +} From 4b18418a06d1a34dfb1a0193545b4fd3fb3d5e8d Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 11 Sep 2015 21:34:09 +0800 Subject: [PATCH 3/4] Fix a bug and add unit tests --- .../sql/execution/local/ExpandNode.scala | 21 ++++---- .../sql/execution/local/ExpandNodeSuite.scala | 51 +++++++++++++++++++ 2 files changed, 62 insertions(+), 10 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala index 2345b296ab30c..17016f8966838 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala @@ -32,7 +32,6 @@ case class ExpandNode( private[this] var result: InternalRow = _ private[this] var idx: Int = _ private[this] var input: InternalRow = _ - private[this] var groups: Array[Projection] = _ override def open(): Unit = { @@ -42,17 +41,19 @@ case class ExpandNode( } override def next(): Boolean = { - idx += 1 - if (idx < groups.length) { - result = groups(idx)(input) - true - } else if (child.next()) { - input = child.fetch() - idx = 0 + if (idx < 0 || idx >= groups.length) { + if (child.next()) { + input = child.fetch() + result = groups(0)(input) + idx = 1 + true + } else { + false + } + } else { result = 
From 4b18418a06d1a34dfb1a0193545b4fd3fb3d5e8d Mon Sep 17 00:00:00 2001
From: zsxwing
Date: Fri, 11 Sep 2015 21:34:09 +0800
Subject: [PATCH 3/4] Fix a bug and add unit tests

---
 .../sql/execution/local/ExpandNode.scala      | 21 ++++----
 .../sql/execution/local/ExpandNodeSuite.scala | 51 +++++++++++++++++++
 2 files changed, 62 insertions(+), 10 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala
index 2345b296ab30c..17016f8966838 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala
@@ -32,7 +32,6 @@ case class ExpandNode(
 
   private[this] var result: InternalRow = _
   private[this] var idx: Int = _
   private[this] var input: InternalRow = _
-
   private[this] var groups: Array[Projection] = _
 
   override def open(): Unit = {
@@ -42,17 +41,19 @@ case class ExpandNode(
   }
 
   override def next(): Boolean = {
-    idx += 1
-    if (idx < groups.length) {
-      result = groups(idx)(input)
-      true
-    } else if (child.next()) {
-      input = child.fetch()
-      idx = 0
+    if (idx < 0 || idx >= groups.length) {
+      if (child.next()) {
+        input = child.fetch()
+        result = groups(0)(input)
+        idx = 1
+        true
+      } else {
+        false
+      }
+    } else {
       result = groups(idx)(input)
+      idx += 1
       true
-    } else {
-      false
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala
new file mode 100644
index 0000000000000..cfa7f3f6dcb97
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+class ExpandNodeSuite extends LocalNodeTest {
+
+  import testImplicits._
+
+  test("expand") {
+    val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value")
+    checkAnswer(
+      input,
+      node =>
+        ExpandNode(conf, Seq(
+          Seq(
+            input.col("key") + input.col("value"), input.col("key") - input.col("value")
+          ).map(_.expr),
+          Seq(
+            input.col("key") * input.col("value"), input.col("key") / input.col("value")
+          ).map(_.expr)
+        ), node.output, node),
+      Seq(
+        (2, 0),
+        (1, 1),
+        (4, 0),
+        (4, 1),
+        (6, 0),
+        (9, 1),
+        (8, 0),
+        (16, 1),
+        (10, 0),
+        (25, 1)
+      ).toDF().collect()
+    )
+  }
+}

From c6e80a2f0fa71dc788a754dd9d0f7e8e89bab56f Mon Sep 17 00:00:00 2001
From: zsxwing
Date: Mon, 14 Sep 2015 18:54:59 +0800
Subject: [PATCH 4/4] Address Andrew's comments

---
 .../sql/execution/local/ExpandNode.scala      | 17 ++---
 .../spark/sql/execution/local/LocalNode.scala | 11 +--
 .../execution/local/NestedLoopJoinNode.scala  | 74 +++++++++----------
 .../local/NestedLoopJoinNodeSuite.scala       | 46 ++++++++----
 4 files changed, 77 insertions(+), 71 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala
index 17016f8966838..2aff156d18b54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala
@@ -36,25 +36,22 @@ case class ExpandNode(
 
   override def open(): Unit = {
     child.open()
-    idx = -1
     groups = projections.map(ee => newProjection(ee, child.output)).toArray
+    idx = groups.length
   }
 
   override def next(): Boolean = {
-    if (idx < 0 || idx >= groups.length) {
+    if (idx >= groups.length) {
       if (child.next()) {
         input = child.fetch()
-        result = groups(0)(input)
-        idx = 1
-        true
+        idx = 0
       } else {
-        false
+        return false
       }
-    } else {
-      result = groups(idx)(input)
-      idx += 1
-      true
     }
+    result = groups(idx)(input)
+    idx += 1
+    true
   }
 
   override def fetch(): InternalRow = result
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index a523d977f7f19..9840080e16953 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -86,11 +86,6 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
    */
   final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this)
 
-  /**
-   * Returns the content through the [[Iterator]] interface.
-   */
-  final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this)
-
   /**
    * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq.
    */
@@ -109,7 +104,8 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
   }
 
   protected def newProjection(
-      expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
+      expressions: Seq[Expression],
+      inputSchema: Seq[Attribute]): Projection = {
     log.debug(
       s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
     if (codegenEnabled) {
@@ -152,7 +148,8 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
   }
 
   protected def newPredicate(
-      expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
+      expression: Expression,
+      inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
     if (codegenEnabled) {
       try {
         GeneratePredicate.generate(expression, inputSchema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala
index 62ebcdbda957f..7321fc66b4dde 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala
@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, JoinType}
 import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
-import org.apache.spark.util.collection.BitSet
-import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.util.collection.{BitSet, CompactBuffer}
 
 case class NestedLoopJoinNode(
     conf: SQLConf,
@@ -77,65 +76,64 @@ case class NestedLoopJoinNode(
     val leftNulls = new GenericMutableRow(left.output.size)
     val rightNulls = new GenericMutableRow(right.output.size)
     val joinedRow = new JoinedRow
-    val includedBuildTuples = new BitSet(buildRelation.size)
+    val matchedBuildTuples = new BitSet(buildRelation.size)
     val resultProj = genResultProjection
     streamed.open()
 
-    val matchesOrStreamedRowsWithNulls = streamed.asIterator.flatMap { streamedRow =>
+    // streamedRowMatches also contains null rows if using outer join
+    val streamedRowMatches: Iterator[InternalRow] = streamed.asIterator.flatMap { streamedRow =>
       val matchedRows = new CompactBuffer[InternalRow]
 
       var i = 0
       var streamRowMatched = false
 
+      // Scan the build relation to look for matches for each streamed row
       while (i < buildRelation.size) {
         val buildRow = buildRelation(i)
         buildSide match {
-          case BuildRight if boundCondition(joinedRow(streamedRow, buildRow)) =>
-            matchedRows += resultProj(joinedRow(streamedRow, buildRow)).copy()
-            streamRowMatched = true
-            includedBuildTuples.set(i)
-          case BuildLeft if boundCondition(joinedRow(buildRow, streamedRow)) =>
-            matchedRows += resultProj(joinedRow(buildRow, streamedRow)).copy()
-            streamRowMatched = true
-            includedBuildTuples.set(i)
-          case _ =>
+          case BuildRight => joinedRow(streamedRow, buildRow)
+          case BuildLeft => joinedRow(buildRow, streamedRow)
+        }
+        if (boundCondition(joinedRow)) {
+          matchedRows += resultProj(joinedRow).copy()
+          streamRowMatched = true
+          matchedBuildTuples.set(i)
         }
         i += 1
       }
 
-      (streamRowMatched, joinType, buildSide) match {
-        case (false, LeftOuter | FullOuter, BuildRight) =>
-          matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy()
-        case (false, RightOuter | FullOuter, BuildLeft) =>
-          matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy()
-        case _ =>
+      // If this row had no matches and we're using outer join, join it with the null rows
+      if (!streamRowMatched) {
+        (joinType, buildSide) match {
+          case (LeftOuter | FullOuter, BuildRight) =>
+            matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy()
+          case (RightOuter | FullOuter, BuildLeft) =>
+            matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy()
+          case _ =>
+        }
       }
 
       matchedRows.iterator
     }
 
+    // If we're using outer join, find rows on the build side that didn't match anything
+    // and join them with the null row
+    lazy val unmatchedBuildRows: Iterator[InternalRow] = {
+      var i = 0
+      buildRelation.filter { row =>
+        val r = !matchedBuildTuples.get(i)
+        i += 1
+        r
+      }.iterator
+    }
     iterator = (joinType, buildSide) match {
       case (RightOuter | FullOuter, BuildRight) =>
-        var i = 0
-        matchesOrStreamedRowsWithNulls ++ buildRelation.filter { row =>
-          val r = !includedBuildTuples.get(i)
-          i += 1
-          r
-        }.iterator.map { buildRow =>
-          joinedRow.withLeft(leftNulls)
-          resultProj(joinedRow.withRight(buildRow))
-        }
+        streamedRowMatches ++
+          unmatchedBuildRows.map { buildRow => resultProj(joinedRow(leftNulls, buildRow)) }
       case (LeftOuter | FullOuter, BuildLeft) =>
-        var i = 0
-        matchesOrStreamedRowsWithNulls ++ buildRelation.filter { row =>
-          val r = !includedBuildTuples.get(i)
-          i += 1
-          r
-        }.iterator.map { buildRow =>
-          joinedRow.withRight(rightNulls)
-          resultProj(joinedRow.withLeft(buildRow))
-        }
-      case _ => matchesOrStreamedRowsWithNulls
+        streamedRowMatches ++
+          unmatchedBuildRows.map { buildRow => resultProj(joinedRow(buildRow, rightNulls)) }
+      case _ => streamedRowMatches
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
index 1743a8a98b2e4..b1ef26ba82f16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql.execution.local
 
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
-import org.apache.spark.sql.execution.joins.BuildRight
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
 
 class NestedLoopJoinNodeSuite extends LocalNodeTest {
 
   import testImplicits._
 
-  def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = {
+  private def joinSuite(
+      suiteName: String, buildSide: BuildSide, confPairs: (String, String)*): Unit = {
     test(s"$suiteName: left outer join") {
       withSQLConf(confPairs: _*) {
         checkAnswer2(
@@ -36,7 +37,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
               conf,
               node1,
               node2,
-              BuildRight,
+              buildSide,
               LeftOuter,
               Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr))
           ),
@@ -50,7 +51,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
               conf,
               node1,
               node2,
-              BuildRight,
+              buildSide,
               LeftOuter,
               Some(
                 (upperCaseData.col("N") === lowerCaseData.col("n") &&
@@ -66,7 +67,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
               conf,
               node1,
               node2,
-              BuildRight,
+              buildSide,
               LeftOuter,
               Some(
                 (upperCaseData.col("N") === lowerCaseData.col("n") &&
@@ -82,7 +83,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
               conf,
               node1,
               node2,
-              BuildRight,
+              buildSide,
               LeftOuter,
               Some(
                 (upperCaseData.col("N") === lowerCaseData.col("n") &&
@@ -102,7 +103,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
               conf,
               node1,
               node2,
-              BuildRight,
+              buildSide,
               RightOuter,
               Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr))
           ),
@@ -116,7 +117,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
               conf,
               node1,
               node2,
-              BuildRight,
+              buildSide,
               RightOuter,
               Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
                 lowerCaseData.col("n") > 1).expr))
@@ -131,7 +132,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
               conf,
               node1,
               node2,
-              BuildRight,
+              buildSide,
               RightOuter,
               Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
                 upperCaseData.col("N") > 1).expr))
@@ -146,7 +147,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
               conf,
               node1,
               node2,
-              BuildRight,
+              buildSide,
               RightOuter,
               Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
                 lowerCaseData.col("l") > upperCaseData.col("L")).expr))
@@ -165,7 +166,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
               conf,
               node1,
               node2,
-              BuildRight,
+              buildSide,
               FullOuter,
               Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr))
           ),
@@ -179,7 +180,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
               conf,
               node1,
               node2,
-              BuildRight,
+              buildSide,
               FullOuter,
               Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
                 lowerCaseData.col("n") > 1).expr))
@@ -194,7 +195,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
               conf,
               node1,
               node2,
-              BuildRight,
+              buildSide,
               FullOuter,
               Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
                 upperCaseData.col("N") > 1).expr))
@@ -209,7 +210,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
               conf,
               node1,
               node2,
-              BuildRight,
+              buildSide,
               FullOuter,
               Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
                 lowerCaseData.col("l") > upperCaseData.col("L")).expr))
@@ -220,6 +221,19 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
   }
 
   joinSuite(
-    "general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
-  joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
+    "general-build-left",
+    BuildLeft,
+    SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
+  joinSuite(
+    "general-build-right",
+    BuildRight,
+    SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
+  joinSuite(
+    "tungsten-build-left",
+    BuildLeft,
+    SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
+  joinSuite(
+    "tungsten-build-right",
+    BuildRight,
+    SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
 }
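Taken together, the four patches build local operators around a single iterator contract: open() initializes state, next() advances and reports whether a row is available, fetch() returns the current row, and close() releases resources (asIterator merely adapts this to a Scala Iterator). Below is a minimal, self-contained sketch of that contract with two toy nodes; the trait and classes are simplified stand-ins (no SQLConf, no InternalRow), and only the calling convention mirrors the patches.

object LocalNodeContractSketch {

  trait Node {
    def open(): Unit    // initialize iterator state; must be called before next()
    def next(): Boolean // advance; true iff fetch() will return a valid row
    def fetch(): Int    // the current row; only valid after next() returned true
    def close(): Unit   // release resources after consumption
  }

  // Roughly analogous to SeqScanNode: scans an in-memory sequence.
  class SeqScan(data: Seq[Int]) extends Node {
    private var iter: Iterator[Int] = _
    private var current: Int = _
    def open(): Unit = { iter = data.iterator }
    def next(): Boolean = {
      if (iter.hasNext) { current = iter.next(); true } else false
    }
    def fetch(): Int = current
    def close(): Unit = { iter = null }
  }

  // Roughly analogous to FilterNode: pulls from the child until the predicate passes.
  class Filter(pred: Int => Boolean, child: Node) extends Node {
    def open(): Unit = child.open()
    def next(): Boolean = {
      var found = false
      while (!found && child.next()) { found = pred(child.fetch()) }
      found
    }
    def fetch(): Int = child.fetch()
    def close(): Unit = child.close()
  }

  def main(args: Array[String]): Unit = {
    val node = new Filter(_ % 2 == 0, new SeqScan(1 to 10))
    node.open()
    try {
      while (node.next()) println(node.fetch()) // prints 2, 4, 6, 8, 10
    } finally {
      node.close()
    }
  }
}

Splitting next() from fetch() keeps the operators in the classic Volcano iterator shape, which is what lets wrappers such as ConvertToSafeNode and ConvertToUnsafeNode compose around any child without knowing what it produces.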