Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
7d78be3
core
andrej-db Oct 23, 2024
7839e1d
revert MsSqlServerDialect build method
andrej-db Oct 23, 2024
0172c49
remove visitCaseWhen, more intuitive predicate wrapping
andrej-db Oct 23, 2024
7ae7397
imports
andrej-db Oct 23, 2024
49d742e
fix
andrej-db Oct 23, 2024
c2001d9
MsSqlServerDialect: comment
andrej-db Oct 24, 2024
8b1b2da
JdbcDialects: move aux here
andrej-db Oct 25, 2024
1f01b77
V2ExpressionBuilder: refactor
andrej-db Oct 25, 2024
7e36029
nit
andrej-db Oct 29, 2024
ab79afe
nit
andrej-db Nov 18, 2024
de38a8c
Update JdbcDialects.scala
cloud-fan Nov 19, 2024
56cea3c
Update MsSqlServerDialect.scala
cloud-fan Nov 19, 2024
21ec622
Update sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects…
cloud-fan Nov 19, 2024
468eb89
Update JdbcDialects.scala
cloud-fan Nov 19, 2024
1bc39ee
Update sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerD…
cloud-fan Nov 19, 2024
6a221c6
Update JdbcDialects.scala
cloud-fan Nov 19, 2024
ef1bcc8
Update MsSqlServerDialect.scala
cloud-fan Nov 19, 2024
5d94fa1
Update JdbcDialects.scala
andrej-db Nov 19, 2024
f94f8e5
Update MsSqlServerDialect.scala
andrej-db Nov 19, 2024
55084a3
Update MsSqlServerIntegrationSuite.scala
andrej-db Nov 19, 2024
7fb6b29
Update MsSqlServerDialect.scala
andrej-db Nov 19, 2024
ca0545a
Update MsSqlServerIntegrationSuite.scala
andrej-db Nov 19, 2024
df2fe41
Apply suggestions from code review
cloud-fan Nov 20, 2024
53a4220
Update connector/docker-integration-tests/src/test/scala/org/apache/s…
cloud-fan Nov 20, 2024
cbef9a5
Update MsSqlServerIntegrationSuite.scala
cloud-fan Nov 20, 2024
c990ec6
Update JdbcDialects.scala
cloud-fan Nov 20, 2024
ee4d4fb
Update sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerD…
cloud-fan Nov 20, 2024
6cff9f5
Update MsSqlServerIntegrationSuite.scala
andrej-db Nov 21, 2024
bca7ce8
Merge branch 'apache:master' into SPARK-50087-CaseWhen
andrej-db Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ package org.apache.spark.sql.jdbc.v2
import java.sql.Connection

import org.apache.spark.{SparkConf, SparkSQLFeatureNotSupportedException}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
import org.apache.spark.sql.jdbc.MsSQLServerDatabaseOnDocker
import org.apache.spark.sql.types._
Expand All @@ -37,6 +41,17 @@ import org.apache.spark.tags.DockerTest
@DockerTest
class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {

def getExternalEngineQuery(executedPlan: SparkPlan): String = {
getExternalEngineRdd(executedPlan).asInstanceOf[JDBCRDD].getExternalEngineQuery
}

def getExternalEngineRdd(executedPlan: SparkPlan): RDD[InternalRow] = {
val queryNode = executedPlan.collect { case r: RowDataSourceScanExec =>
r
}.head
queryNode.rdd
}

override def excluded: Seq[String] = Seq(
"simple scan with OFFSET",
"simple scan with LIMIT and OFFSET",
Expand Down Expand Up @@ -146,4 +161,68 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
|""".stripMargin)
assert(df.collect().length == 2)
}

test("SPARK-50087: SqlServer handle booleans in CASE WHEN test") {
val df = sql(
s"""|SELECT * FROM $catalogName.employee
|WHERE CASE WHEN name = 'Legolas' THEN name = 'Elf' ELSE NOT (name = 'Wizard') END
|""".stripMargin
)

// scalastyle:off
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
"""SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE IIF(("name" <> 'Wizard'), 1, 0) END = 1) """
)
// scalastyle:on
df.collect()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have some pushdown check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add an assert with an external query check.

}

test("SPARK-50087: SqlServer handle booleans in CASE WHEN with always true test") {
val df = sql(
s"""|SELECT * FROM $catalogName.employee
|WHERE CASE WHEN (name = 'Legolas') THEN (name = 'Elf') ELSE (1=1) END
|""".stripMargin
)

// scalastyle:off
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
"""SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE 1 END = 1) """
)
// scalastyle:on
df.collect()
}

test("SPARK-50087: SqlServer handle booleans in nested CASE WHEN test") {
val df = sql(
s"""|SELECT * FROM $catalogName.employee
|WHERE CASE WHEN (name = 'Legolas') THEN
| CASE WHEN (name = 'Elf') THEN (name = 'Elrond') ELSE (name = 'Gandalf') END
| ELSE (name = 'Sauron') END
|""".stripMargin
)

// scalastyle:off
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
"""SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF((CASE WHEN ("name" = 'Elf') THEN IIF(("name" = 'Elrond'), 1, 0) ELSE IIF(("name" = 'Gandalf'), 1, 0) END = 1), 1, 0) ELSE IIF(("name" = 'Sauron'), 1, 0) END = 1) """
)
// scalastyle:on
df.collect()
}

test("SPARK-50087: SqlServer handle non-booleans in nested CASE WHEN test") {
val df = sql(
s"""|SELECT * FROM $catalogName.employee
|WHERE CASE WHEN (name = 'Legolas') THEN
| CASE WHEN (name = 'Elf') THEN 'Elf' ELSE 'Wizard' END
| ELSE 'Sauron' END = name
|""".stripMargin
)

// scalastyle:off
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
"""SELECT "dept","name","salary","bonus" FROM "employee" WHERE ("name" IS NOT NULL) AND ((CASE WHEN "name" = 'Legolas' THEN CASE WHEN "name" = 'Elf' THEN 'Elf' ELSE 'Wizard' END ELSE 'Sauron' END) = "name") """
)
// scalastyle:on
df.collect()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure, but do we have a test case for when the type is not boolean?

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate)
case caseWhen @ CaseWhen(branches, elseValue) =>
val conditions = branches.map(_._1).flatMap(generateExpression(_, true))
val values = branches.map(_._2).flatMap(generateExpression(_))
val elseExprOpt = elseValue.flatMap(generateExpression(_))
val values = branches.map(_._2).flatMap(generateExpression(_, isPredicate))
val elseExprOpt = elseValue.flatMap(generateExpression(_, isPredicate))
if (conditions.length == branches.length && values.length == branches.length &&
elseExprOpt.size == elseValue.size) {
val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
Expand Down Expand Up @@ -421,7 +421,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
children: Seq[Expression],
dataType: DataType,
isPredicate: Boolean): Option[V2Expression] = {
val childrenExpressions = children.flatMap(generateExpression(_))
val childrenExpressions = children.flatMap(generateExpression(_, isPredicate))
if (childrenExpressions.length == children.length) {
if (isPredicate && dataType.isInstanceOf[BooleanType]) {
Some(new V2Predicate(v2ExpressionName, childrenExpressions.toArray[V2Expression]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.catalog.index.TableIndex
import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference}
import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcOptionsInWrite, JdbcUtils}
Expand Down Expand Up @@ -377,6 +378,18 @@ abstract class JdbcDialect extends Serializable with Logging {
}

private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder {
// Some dialects do not support boolean type and this convenient util function is
// provided to generate SQL string without boolean values.
protected def inputToSQLNoBool(input: Expression): String = input match {
case p: Predicate if p.name() == "ALWAYS_TRUE" => "1"
case p: Predicate if p.name() == "ALWAYS_FALSE" => "0"
case p: Predicate => predicateToIntSQL(inputToSQL(p))
case _ => inputToSQL(input)
}

protected def predicateToIntSQL(input: String): String =
"CASE WHEN " + input + " THEN 1 ELSE 0 END"

override def visitLiteral(literal: Literal[_]): String = {
Option(literal.value()).map(v =>
compileValue(CatalystTypeConverters.convertToScala(v, literal.dataType())).toString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr
supportedFunctions.contains(funcName)

class MsSqlServerSQLBuilder extends JDBCSQLBuilder {
override protected def predicateToIntSQL(input: String): String =
"IIF(" + input + ", 1, 0)"
override def visitSortOrder(
sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = {
(sortDirection, nullOrdering) match {
Expand Down Expand Up @@ -87,12 +89,24 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr
expr match {
case e: Predicate => e.name() match {
case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" =>
val Array(l, r) = e.children().map {
case p: Predicate => s"CASE WHEN ${inputToSQL(p)} THEN 1 ELSE 0 END"
case o => inputToSQL(o)
}
val Array(l, r) = e.children().map(inputToSQLNoBool)
visitBinaryComparison(e.name(), l, r)
case "CASE_WHEN" => visitCaseWhen(expressionsToStringArray(e.children())) + " = 1"
case "CASE_WHEN" =>
// Since MsSqlServer cannot handle boolean expressions inside
// a CASE WHEN, it is necessary to convert those to another
// CASE WHEN expression that will return 1 or 0 depending on
// the result.
// Example:
// In: ... CASE WHEN a = b THEN c = d ... END
// Out: ... CASE WHEN a = b THEN CASE WHEN c = d THEN 1 ELSE 0 END ... END = 1
val stringArray = e.children().grouped(2).flatMap {
case Array(whenExpression, thenExpression) =>
Array(inputToSQL(whenExpression), inputToSQLNoBool(thenExpression))
case Array(elseExpression) =>
Array(inputToSQLNoBool(elseExpression))
}.toArray

visitCaseWhen(stringArray) + " = 1"
case _ => super.build(expr)
}
case _ => super.build(expr)
Expand Down