ExpandNode.scala (new file)
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Projection}

case class ExpandNode(
    conf: SQLConf,
    projections: Seq[Seq[Expression]],
    output: Seq[Attribute],
    child: LocalNode) extends UnaryLocalNode(conf) {

  assert(projections.size > 0)

  private[this] var result: InternalRow = _
  private[this] var idx: Int = _
  private[this] var input: InternalRow = _
  private[this] var groups: Array[Projection] = _

  override def open(): Unit = {
    child.open()
    groups = projections.map(ee => newProjection(ee, child.output)).toArray
    idx = groups.length
  }

  override def next(): Boolean = {
    if (idx >= groups.length) {
      if (child.next()) {
        input = child.fetch()
        idx = 0
      } else {
        return false
      }
    }
Contributor: how about:

  if (idx < 0 || idx >= groups.length) {
    if (child.next()) {
      input = child.fetch()
      idx = 0
    } else {
      return false
    }
  }
  result = groups(idx)(input)
  idx += 1
  true

A little less duplication.

    result = groups(idx)(input)
    idx += 1
    true
  }

  override def fetch(): InternalRow = result

  override def close(): Unit = child.close()
}
LocalNode.scala
@@ -23,7 +23,7 @@ import org.apache.spark.Logging
 import org.apache.spark.sql.{SQLConf, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types.StructType
@@ -69,6 +69,18 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
    */
   def close(): Unit
 
+  /** Specifies whether this operator outputs UnsafeRows */
+  def outputsUnsafeRows: Boolean = false
+
+  /** Specifies whether this operator is capable of processing UnsafeRows */
+  def canProcessUnsafeRows: Boolean = false
+
+  /**
+   * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows
+   * that are not UnsafeRows).
+   */
+  def canProcessSafeRows: Boolean = true
Contributor: how many of these can be protected?

Member (author): These methods will be used outside of LocalNode, for example when building a LocalNode tree, so I didn't make them protected.
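A sketch of how such tree building might consult these flags to insert the converter nodes used later in this PR; the helper name and signature below are illustrative, not part of the patch:

// Illustrative only: wrap `child` so its row format matches what `parent` can process.
def ensureRowFormat(conf: SQLConf, parent: LocalNode, child: LocalNode): LocalNode = {
  if (child.outputsUnsafeRows && !parent.canProcessUnsafeRows) {
    ConvertToSafeNode(conf, child)
  } else if (!child.outputsUnsafeRows && !parent.canProcessSafeRows) {
    ConvertToUnsafeNode(conf, child)
  } else {
    child
  }
}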


   /**
    * Returns the content through the [[Iterator]] interface.
    */
@@ -91,6 +103,28 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
     result
   }

+  protected def newProjection(
+      expressions: Seq[Expression],
+      inputSchema: Seq[Attribute]): Projection = {
+    log.debug(
+      s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
+    if (codegenEnabled) {
+      try {
+        GenerateProjection.generate(expressions, inputSchema)
+      } catch {
+        case NonFatal(e) =>
+          if (isTesting) {
+            throw e
+          } else {
+            log.error("Failed to generate projection, fallback to interpret", e)
+            new InterpretedProjection(expressions, inputSchema)
+          }
+      }
+    } else {
+      new InterpretedProjection(expressions, inputSchema)
+    }
+  }
+
   protected def newMutableProjection(
       expressions: Seq[Expression],
       inputSchema: Seq[Attribute]): () => MutableProjection = {
@@ -113,6 +147,25 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
     }
   }
 
+  protected def newPredicate(
+      expression: Expression,
+      inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
+    if (codegenEnabled) {
+      try {
+        GeneratePredicate.generate(expression, inputSchema)
+      } catch {
+        case NonFatal(e) =>
+          if (isTesting) {
+            throw e
+          } else {
+            log.error("Failed to generate predicate, fallback to interpreted", e)
+            InterpretedPredicate.create(expression, inputSchema)
+          }
+      }
+    } else {
+      InterpretedPredicate.create(expression, inputSchema)
+    }
+  }
 }
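As a rough illustration of what these helpers give a node implementation (to be read as running inside a LocalNode subclass, since both methods are protected; the attributes and rows are made up, and imports from catalyst.expressions and sql.types are assumed):

// Illustrative only: the same call sites work whether codegen succeeds or falls back.
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val pred = newPredicate(GreaterThan(a, b), Seq(a, b))
pred(InternalRow(3, 2))            // true; compiled when codegen is enabled, else interpreted
val proj = newProjection(Seq(Add(a, b)), Seq(a, b))
proj(InternalRow(3, 2))            // one-column row containing 5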
NestedLoopJoinNode.scala (new file)
@@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, JoinType}
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.util.collection.{BitSet, CompactBuffer}

case class NestedLoopJoinNode(
    conf: SQLConf,
    left: LocalNode,
    right: LocalNode,
    buildSide: BuildSide,
    joinType: JoinType,
    condition: Option[Expression]) extends BinaryLocalNode(conf) {

  override def output: Seq[Attribute] = {
    joinType match {
      case LeftOuter =>
        left.output ++ right.output.map(_.withNullability(true))
      case RightOuter =>
        left.output.map(_.withNullability(true)) ++ right.output
      case FullOuter =>
        left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
      case x =>
        throw new IllegalArgumentException(
          s"NestedLoopJoin should not take $x as the JoinType")
    }
  }

  private[this] def genResultProjection: InternalRow => InternalRow = {
    if (outputsUnsafeRows) {
      UnsafeProjection.create(schema)
    } else {
      identity[InternalRow]
    }
  }

  private[this] var currentRow: InternalRow = _

  private[this] var iterator: Iterator[InternalRow] = _

  override def open(): Unit = {
    val (streamed, build) = buildSide match {
      case BuildRight => (left, right)
      case BuildLeft => (right, left)
    }
    build.open()
    val buildRelation = new CompactBuffer[InternalRow]
    while (build.next()) {
      buildRelation += build.fetch().copy()
    }
    build.close()

    val boundCondition =
      newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)

    val leftNulls = new GenericMutableRow(left.output.size)
    val rightNulls = new GenericMutableRow(right.output.size)
    val joinedRow = new JoinedRow
    val matchedBuildTuples = new BitSet(buildRelation.size)
    val resultProj = genResultProjection
    streamed.open()

    // streamedRowMatches also contains null rows if using outer join
    val streamedRowMatches: Iterator[InternalRow] = streamed.asIterator.flatMap { streamedRow =>
      val matchedRows = new CompactBuffer[InternalRow]

      var i = 0
      var streamRowMatched = false

      // Scan the build relation to look for matches for each streamed row
      while (i < buildRelation.size) {
Contributor: suggest adding a comment here:

  // Scan the build relation to look for matches for each streamed row
        val buildRow = buildRelation(i)
        buildSide match {
          case BuildRight => joinedRow(streamedRow, buildRow)
          case BuildLeft => joinedRow(buildRow, streamedRow)
        }
        if (boundCondition(joinedRow)) {
          matchedRows += resultProj(joinedRow).copy()
          streamRowMatched = true
          matchedBuildTuples.set(i)
        }
Contributor: Reduce duplicate code:

  buildSide match {
    case BuildRight => joinedRow(streamedRow, buildRow)
    case BuildLeft => joinedRow(buildRow, streamedRow)
    case _ =>
  }
  if (boundCondition(joinedRow)) {
    matchedRows += resultProj(joinedRow).copy()
    streamedRowMatched = true
    matchedBuildTuples.set(i)
  }

Also, joinedRow(x, y) mutates the row itself, so you can just use joinedRow directly when doing the projection.
        i += 1
      }

      // If this row had no matches and we're using outer join, join it with the null rows
      if (!streamRowMatched) {
        (joinType, buildSide) match {
          case (LeftOuter | FullOuter, BuildRight) =>
            matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy()
          case (RightOuter | FullOuter, BuildLeft) =>
            matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy()
          case _ =>
        }
      }
Contributor: it would be clearer if this looked like:

  // If this row had no matches and we're using outer join, join it with the null row
  if (!streamRowMatched) {
    (joinType, buildSide) match {
      ...
    }
  }

      matchedRows.iterator
    }

    // If we're using outer join, find rows on the build side that didn't match anything
    // and join them with the null row
    lazy val unmatchedBuildRows: Iterator[InternalRow] = {
      var i = 0
      buildRelation.filter { row =>
        val r = !matchedBuildTuples.get(i)
        i += 1
        r
      }.iterator
    }
    iterator = (joinType, buildSide) match {
      case (RightOuter | FullOuter, BuildRight) =>
        streamedRowMatches ++
          unmatchedBuildRows.map { buildRow => resultProj(joinedRow(leftNulls, buildRow)) }
      case (LeftOuter | FullOuter, BuildLeft) =>
        streamedRowMatches ++
          unmatchedBuildRows.map { buildRow => resultProj(joinedRow(buildRow, rightNulls)) }
      case _ => streamedRowMatches
    }
Contributor: How about the following to reduce duplicate code? It should be functionally the same.

  // If we're using outer join, find rows on the build side that didn't match anything
  // and join them with the null row
  lazy val unmatchedBuildRows: Iterator[InternalRow] = {
    var i = 0
    buildRelation.filter { row =>
      val r = !includedBuildTuples.get(i)
      i += 1
      r
    }.iterator
  }
  val additionalRows: Iterator[InternalRow] = (joinType, buildSide) match {
    case (RightOuter | FullOuter, BuildRight) =>
      unmatchedBuildRows.map { resultProj(joinedRow(leftNulls, _)) } // copy?
    case (LeftOuter | FullOuter, BuildLeft) =>
      unmatchedBuildRows.map { resultProj(joinedRow(_, rightNulls)) }
    case _ =>
      Iterator.empty[InternalRow]
  }

  iterator = streamedRowMatches ++ additionalRows

Member (author): I made a small change to your suggestion to avoid using Iterator.empty[InternalRow].

Contributor: ok, that's fine.

  }

  override def next(): Boolean = {
    if (iterator.hasNext) {
      currentRow = iterator.next()
      true
    } else {
      false
    }
  }

  override def fetch(): InternalRow = currentRow

  override def close(): Unit = {
    left.close()
    right.close()
  }
}
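The BitSet bookkeeping above is the crux of the outer-join support. The same idea in self-contained plain Scala (no Spark types; purely illustrative, not part of the patch):

// Match streamed rows against a buffered build side, remember which build
// rows matched, then null-extend the leftovers, as in open() above.
def fullOuterNestedLoop[A, B](
    streamed: Seq[A], build: Seq[B], cond: (A, B) => Boolean): Seq[(Option[A], Option[B])] = {
  val matchedBuild = scala.collection.mutable.BitSet.empty
  val matches = streamed.flatMap { s =>
    val ms = build.zipWithIndex.collect { case (b, i) if cond(s, b) =>
      matchedBuild += i
      (Some(s), Some(b))
    }
    // a streamed row with no match is joined with the "null" build row
    if (ms.nonEmpty) ms else Seq((Some(s), None))
  }
  // build rows that never matched are joined with the "null" streamed row
  matches ++ build.zipWithIndex.collect {
    case (b, i) if !matchedBuild.contains(i) => (None, Some(b))
  }
}

fullOuterNestedLoop(Seq(1, 2), Seq(2, 3), (a: Int, b: Int) => a == b)
// => Seq((Some(1), None), (Some(2), Some(2)), (None, Some(3)))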
ExpandNodeSuite.scala (new file)
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

class ExpandNodeSuite extends LocalNodeTest {

  import testImplicits._

  test("expand") {
    val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value")
    checkAnswer(
      input,
      node =>
        ExpandNode(conf, Seq(
          Seq(
            input.col("key") + input.col("value"), input.col("key") - input.col("value")
          ).map(_.expr),
          Seq(
            input.col("key") * input.col("value"), input.col("key") / input.col("value")
          ).map(_.expr)
        ), node.output, node),
      Seq(
        (2, 0),
        (1, 1),
        (4, 0),
        (4, 1),
        (6, 0),
        (9, 1),
        (8, 0),
        (16, 1),
        (10, 0),
        (25, 1)
      ).toDF().collect()
    )
  }
}
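The expected rows follow directly from expand's contract: each input row produces one output row per projection group, in group order. A plain-Scala model of the first two input rows (illustrative, not test code):

val expanded = Seq((1, 1), (2, 2)).flatMap { case (k, v) =>
  Seq((k + v, k - v), (k * v, k / v))   // the two projection groups above
}
// expanded == List((2,0), (1,1), (4,0), (4,1))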
HashJoinNodeSuite.scala
@@ -24,20 +24,6 @@ class HashJoinNodeSuite extends LocalNodeTest {

   import testImplicits._
 
-  private def wrapForUnsafe(
-      f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = {
-    if (conf.unsafeEnabled) {
-      (left: LocalNode, right: LocalNode) => {
-        val _left = ConvertToUnsafeNode(conf, left)
-        val _right = ConvertToUnsafeNode(conf, right)
-        val r = f(_left, _right)
-        ConvertToSafeNode(conf, r)
-      }
-    } else {
-      f
-    }
-  }
-
   def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = {
     test(s"$suiteName: inner join with one match per row") {
       withSQLConf(confPairs: _*) {
LocalNodeTest.scala
@@ -27,6 +27,20 @@ class LocalNodeTest extends SparkFunSuite with SharedSQLContext {

   def conf: SQLConf = sqlContext.conf
 
+  protected def wrapForUnsafe(
+      f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = {
+    if (conf.unsafeEnabled) {
+      (left: LocalNode, right: LocalNode) => {
+        val _left = ConvertToUnsafeNode(conf, left)
+        val _right = ConvertToUnsafeNode(conf, right)
+        val r = f(_left, _right)
+        ConvertToSafeNode(conf, r)
+      }
+    } else {
+      f
+    }
+  }
+
   /**
    * Runs the LocalNode and makes sure the answer matches the expected result.
    * @param input the input data to be used.
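A possible use of this helper in a join suite (the NestedLoopJoinNode arguments are illustrative; makeJoin is a made-up name):

// Illustrative: build the join under test with unsafe-conversion wrappers.
val makeJoin = wrapForUnsafe { (l, r) =>
  NestedLoopJoinNode(conf, l, r, BuildRight, FullOuter, None)
}
// makeJoin(leftNode, rightNode) now converts inputs to UnsafeRows and the
// result back to safe rows whenever spark.sql.unsafe.enabled is set.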