@@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc.v2.join
import java.sql.Connection
import java.util.Locale

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.{DockerJDBCIntegrationSuite, JdbcDialect, OracleDatabaseOnDocker, OracleDialect}
import org.apache.spark.sql.jdbc.v2.JDBCV2JoinPushdownIntegrationSuiteBase
import org.apache.spark.sql.types.DataTypes
@@ -56,6 +57,15 @@ import org.apache.spark.tags.DockerTest
class OracleJoinPushdownIntegrationSuite
extends DockerJDBCIntegrationSuite
with JDBCV2JoinPushdownIntegrationSuiteBase {
override def excluded: Seq[String] = Seq(
// The following tests are harder to support for Oracle because the Oracle connector does
// casts in predicates. There is a separate test in this suite that is similar to the
// "Test explain formatted" test from the base suite.
"Test self join with condition",
"Test multi-way self join with conditions",
"Test explain formatted"
)

override val namespace: String = "SYSTEM"

override val db = new OracleDatabaseOnDocker
@@ -74,4 +84,40 @@ class OracleJoinPushdownIntegrationSuite
override def dataPreparation(connection: Connection): Unit = {
super.dataPreparation()
}

test("Test explain formatted - Oracle compatible") {
val sqlQuery =
s"""
|SELECT * FROM $catalogAndNamespace.$casedJoinTableName1 a
|JOIN $catalogAndNamespace.$casedJoinTableName2 b
|ON a.id = b.id + 1
|JOIN $catalogAndNamespace.$casedJoinTableName3 c
|ON b.id = c.id + 1
|JOIN $catalogAndNamespace.$casedJoinTableName4 d
|ON c.id = d.id + 1
|""".stripMargin

withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
val df = sql(sqlQuery)

// scalastyle:off line.size.limit
checkJoinPushed(
df,
s"""PushedFilters: [CAST(id_3 AS decimal(11,0)) = (id_4 + 1)], PushedJoins:\u0020
|[L]: PushedFilters: [CAST(ID_1 AS decimal(11,0)) = (id_3 + 1)]
| PushedJoins:
| [L]: PushedFilters: [CAST(ID AS decimal(11,0)) = (ID_1 + 1)]
| PushedJoins:
| [L]: Relation: $catalogAndNamespace.${caseConvert(joinTableName1)}
| PushedFilters: [${caseConvert("id")} IS NOT NULL]
| [R]: Relation: $catalogAndNamespace.${caseConvert(joinTableName2)}
| PushedFilters: [${caseConvert("id")} IS NOT NULL]
| [R]: Relation: $catalogAndNamespace.${caseConvert(joinTableName3)}
| PushedFilters: [id IS NOT NULL]
|[R]: Relation: $catalogAndNamespace.${caseConvert(joinTableName4)}
| PushedFilters: [id IS NOT NULL]""".stripMargin
)
// scalastyle:on line.size.limit
}
}
}
@@ -35,7 +35,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators
import org.apache.spark.sql.execution.datasources.v2.{PushedDownOperators, TableSampleInfo}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
import org.apache.spark.sql.internal.SQLConf
@@ -158,10 +158,12 @@ case class RowDataSourceScanExec(

override def inputRDD: RDD[InternalRow] = rdd

override val metadata: Map[String, String] = {
private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")

def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
private def pushedSampleMetadataString(s: TableSampleInfo): String =
s"SAMPLE (${(s.upperBound - s.lowerBound) * 100}) ${s.withReplacement} SEED(${s.seed})"

override val metadata: Map[String, String] = {
val markedFilters = if (filters.nonEmpty) {
for (filter <- filters) yield {
if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
@@ -188,8 +190,11 @@
seqToString(markedFilters.toSeq)
}

val pushedJoins = if (pushedDownOperators.joinedRelations.length > 1) {
Map("PushedJoins" -> seqToString(pushedDownOperators.joinedRelations))
val pushedJoins = if (pushedDownOperators.joinedRelationPushedDownOperators.nonEmpty) {
Map("PushedJoins" ->
s"\n${getPushedJoinString(
pushedDownOperators.joinedRelationPushedDownOperators(0),
pushedDownOperators.joinedRelationPushedDownOperators(1))}\n")
} else {
Map()
}
@@ -203,12 +208,80 @@
seqToString(v.groupByExpressions.map(_.describe()).toImmutableArraySeq))} ++
topNOrLimitInfo ++
offsetInfo ++
pushedDownOperators.sample.map(v => "PushedSample" ->
s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
) ++
pushedDownOperators.sample.map(v => "PushedSample" -> pushedSampleMetadataString(v)) ++
pushedJoins
}

/**
* Build the string for all the pushed-down join operators. The method is recursive, so if there
* is a join on top of two already-joined relations, all of them will be present in the string.
*
* An example of the resulting string is the following:
*
* PushedFilters: [id_3 = (id_4 + 1)], PushedJoins:
* [L]: PushedFilters: [ID_1 = (id_3 + 1)]
* PushedJoins:
* [L]: PushedFilters: [ID = (ID_1 + 1)]
* PushedJoins:
* [L]: Relation: join_pushdown_catalog.JOIN_SCHEMA.JOIN_TABLE_1
* PushedFilters: [ID IS NOT NULL]
* [R]: Relation: join_pushdown_catalog.JOIN_SCHEMA.JOIN_TABLE_2
* PushedFilters: [ID IS NOT NULL]
* [R]: Relation: join_pushdown_catalog.JOIN_SCHEMA.JOIN_TABLE_3
* PushedFilters: [id IS NOT NULL]
* [R]: Relation: join_pushdown_catalog.JOIN_SCHEMA.JOIN_TABLE_4
* PushedFilters: [id IS NOT NULL]
*/
private def getPushedJoinString(
leftSidePushedDownOperators: PushedDownOperators,
rightSidePushedDownOperators: PushedDownOperators,
indent: Int = 0): String = {
val indentStr = " ".repeat(indent)

val leftSideOperators = buildOperatorParts(leftSidePushedDownOperators, indent)
val leftSideMetadataStr = formatMetadata(leftSideOperators, indentStr + " ".repeat(5))

val rightSideOperators = buildOperatorParts(rightSidePushedDownOperators, indent)
val rightSideMetadataStr = formatMetadata(rightSideOperators, indentStr + " ".repeat(5))

val leftSideString = s"$indentStr[L]: $leftSideMetadataStr"
val rightSideString = s"$indentStr[R]: $rightSideMetadataStr"
Seq(leftSideString, rightSideString).mkString("\n")
}

private def buildOperatorParts(operators: PushedDownOperators, indent: Int): List[String] = {
val parts = List.newBuilder[String]

// Add relation name for leaf nodes (nodes without further joins)
if (operators.joinedRelationPushedDownOperators.isEmpty) {
operators.relationName.foreach(name => parts += s"Relation: $name")
}

if (operators.pushedPredicates.nonEmpty) {
parts += s"PushedFilters: ${seqToString(operators.pushedPredicates.map(_.describe()))}"
}

operators.sample.foreach { sample =>
parts += s"PushedSample: ${pushedSampleMetadataString(sample)}"
}

// Recursively build the pushed join string for the children with correct indentation.
if (operators.joinedRelationPushedDownOperators.nonEmpty) {
val nestedJoins = getPushedJoinString(
operators.joinedRelationPushedDownOperators(0),
operators.joinedRelationPushedDownOperators(1),
indent + 5)
parts += s"PushedJoins:\n$nestedJoins"
}

parts.result()
}

private def formatMetadata(parts: List[String], indentStr: String): String = {
val (basicParts, nestedJoinsParts) = parts.partition(!_.startsWith("PushedJoins:"))
(basicParts ++ nestedJoinsParts).mkString("\n" + indentStr)
}

// Don't care about `rdd` and `tableIdentifier`, and `stream` when canonicalizing.
override def doCanonicalize(): SparkPlan =
copy(
@@ -401,7 +401,7 @@ object DataSourceStrategy
l.output.toStructType,
Set.empty,
Set.empty,
PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty, Seq.empty),
PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty, Seq.empty, None),
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
l.stream,
@@ -476,7 +476,7 @@
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty, Seq.empty),
PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty, Seq.empty, None),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.stream,
@@ -500,7 +500,7 @@
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty, Seq.empty),
PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty, Seq.empty, None),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.stream,
@@ -31,6 +31,8 @@ case class PushedDownOperators(
offset: Option[Int],
sortValues: Seq[SortOrder],
pushedPredicates: Seq[Predicate],
joinedRelations: Seq[String]) {
joinedRelationPushedDownOperators: Seq[PushedDownOperators],
// Relation name for a leaf relation. For join nodes, this is empty.
relationName: Option[String]) {
assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined)
}
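
The construction pattern implied by the two new fields, as a minimal sketch; the argument order follows the call sites in DataSourceStrategy and getPushedDownOperators elsewhere in this diff, and the relation name below is a hypothetical placeholder:

// import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators
// Positional order: aggregation, sample, limit, offset, sortValues, pushedPredicates,
//                   joinedRelationPushedDownOperators, relationName.
val leafOperators = PushedDownOperators(
  None, None, None, None, Seq.empty, Seq.empty,
  Seq.empty,                            // no nested joins: this side is a plain scan
  Some("catalog.schema.table"))         // hypothetical relation name for a leaf node

val joinOperators = PushedDownOperators(
  None, None, None, None, Seq.empty, Seq.empty,
  Seq(leafOperators, leafOperators),    // left and right sides of the pushed join
  None)                                 // join nodes carry no relation name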
@@ -200,9 +200,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
rightSideRequiredColumnsWithAliases,
translatedCondition.get)
) {
val leftSidePushedDownOperators = getPushedDownOperators(leftHolder)
val rightSidePushedDownOperators = getPushedDownOperators(rightHolder)

leftHolder.joinedRelations = leftHolder.joinedRelations ++ rightHolder.joinedRelations
leftHolder.pushedPredicates = leftHolder.pushedPredicates ++
rightHolder.pushedPredicates :+ translatedCondition.get
leftHolder.joinedRelationsPushedDownOperators =
Seq(leftSidePushedDownOperators, rightSidePushedDownOperators)

leftHolder.pushedPredicates = Seq(translatedCondition.get)
leftHolder.pushedSample = None

leftHolder.output = node.output.asInstanceOf[Seq[AttributeReference]]
leftHolder.pushedJoinOutputMap = pushedJoinOutputMap
@@ -792,13 +798,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
f.pushedFilters()
case _ => Array.empty[sources.Filter]
}
val pushedDownOperators = PushedDownOperators(sHolder.pushedAggregate, sHolder.pushedSample,
sHolder.pushedLimit, sHolder.pushedOffset, sHolder.sortOrders, sHolder.pushedPredicates,
sHolder.joinedRelations.map(_.name))
val pushedDownOperators = getPushedDownOperators(sHolder)
V1ScanWrapper(v1, pushedFilters.toImmutableArraySeq, pushedDownOperators)
case _ => scan
}
}

private def getPushedDownOperators(sHolder: ScanBuilderHolder): PushedDownOperators = {
val optRelationName = Option.when(sHolder.joinedRelations.length <= 1)(sHolder.relation.name)
PushedDownOperators(sHolder.pushedAggregate, sHolder.pushedSample,
sHolder.pushedLimit, sHolder.pushedOffset, sHolder.sortOrders, sHolder.pushedPredicates,
sHolder.joinedRelationsPushedDownOperators, optRelationName)
}
}

case class ScanBuilderHolder(
@@ -821,6 +832,8 @@ case class ScanBuilderHolder(

var joinedRelations: Seq[DataSourceV2RelationBase] = Seq(relation)

var joinedRelationsPushedDownOperators: Seq[PushedDownOperators] = Seq.empty[PushedDownOperators]

var pushedJoinOutputMap: AttributeMap[Expression] = AttributeMap.empty[Expression]
}

@@ -145,13 +145,17 @@ trait DataSourcePushdownTestUtils extends ExplainSuiteHelper {
assert(joinNodes.nonEmpty, "Join should not be pushed down")
}

protected def checkJoinPushed(df: DataFrame, expectedTables: String*): Unit = {
protected def checkJoinPushed(df: DataFrame): Unit = {
val joinNodes = df.queryExecution.optimizedPlan.collect {
case j: Join => j
}
assert(joinNodes.isEmpty, "Join should be pushed down")
if (expectedTables.nonEmpty) {
checkPushedInfo(df, s"PushedJoins: [${expectedTables.mkString(", ")}]")
}

protected def checkJoinPushed(df: DataFrame, expectedPushdownString: String): Unit = {
checkJoinPushed(df)
if (expectedPushdownString.nonEmpty) {
checkPushedInfo(df, expectedPushdownString)
}
}
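
A brief usage sketch of the two overloads from a suite mixing in this trait; the expected string passed to the second call is a shortened, purely illustrative fragment:

// Assert only that no logical Join node remains in the optimized plan.
checkJoinPushed(df)

// Additionally assert that the explain output contains the expected pushdown metadata.
// Real tests pass the full expected block, as in the Oracle suite above.
checkJoinPushed(df, "PushedJoins:")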
