Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -230,31 +230,46 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
* Optimize IN predicates:
* 1. Converts the predicate to false when the list is empty and
* the value is not nullable.
* 2. Removes literal repetitions.
* 3. Replaces [[In (value, seq[Literal])]] with optimized version
* 2. Extracts the convertible (literal) part of the list.
* 3. Removes literal repetitions.
* 4. Replaces [[In (value, seq[Literal])]] with optimized version
* [[InSet (value, HashSet[Literal])]] which is much faster.
*/
object OptimizeIn extends Rule[LogicalPlan] {
/**
 * Rewrites an IN predicate over `list` into a cheaper equivalent where possible.
 *
 * The list is first de-duplicated through an [[ExpressionSet]]; the result is then one of:
 *  - `EqualTo(v, e)` when a single element `e` remains and neither side is a
 *    [[CreateNamedStruct]];
 *  - `InSet` over the evaluated elements when more elements remain than
 *    `optimizerInSetConversionThreshold` allows;
 *  - a copy of `expr` carrying the de-duplicated list when duplicates were dropped;
 *  - `expr` itself otherwise.
 *
 * @param expr the original IN expression, returned as-is or copied with a shorter list
 * @param v    the value being tested for membership
 * @param list the candidate expressions; every element is evaluated against `EmptyRow` on the
 *             `InSet` path
 */
def optimizeIn(expr: In, v: Expression, list: Seq[Expression]): Expression = {
  val distinctItems = ExpressionSet(list).toSeq
  distinctItems match {
    // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed,
    // TODO: we exclude them in this rule.
    case Seq(only)
        if !v.isInstanceOf[CreateNamedStruct] && !only.isInstanceOf[CreateNamedStruct] =>
      EqualTo(v, only)
    case _ if distinctItems.length > SQLConf.get.optimizerInSetConversionThreshold =>
      val evaluated = distinctItems.map(_.eval(EmptyRow))
      InSet(v, HashSet() ++ evaluated)
    case _ if distinctItems.length < list.length =>
      // Duplicates were removed: keep the IN form but with the shorter list.
      expr.copy(list = distinctItems)
    case _ =>
      // Nothing to simplify: same length as the input and no cheaper form applies.
      expr
  }
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsDown {
case In(v, list) if list.isEmpty =>
// When v is not nullable, the following expression will be optimized
// to FalseLiteral which is tested in OptimizeInSuite.scala
If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType))
case expr @ In(v, list) if expr.inSetConvertible =>
val newList = ExpressionSet(list).toSeq
if (newList.length == 1
// TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed,
// TODO: we exclude them in this rule.
&& !v.isInstanceOf[CreateNamedStruct]
&& !newList.head.isInstanceOf[CreateNamedStruct]) {
EqualTo(v, newList.head)
} else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) {
val hSet = newList.map(e => e.eval(EmptyRow))
InSet(v, HashSet() ++ hSet)
} else if (newList.length < list.length) {
expr.copy(list = newList)
} else { // newList.length == list.length && newList.length > 1
case expr @ In(v, list) =>
// split list to 2 parts so that we can optimize convertible part
val (convertible, nonConvertible) = list.partition(_.isInstanceOf[Literal])
if (convertible.nonEmpty && nonConvertible.isEmpty) {
optimizeIn(expr, v, list)
} else if (convertible.nonEmpty && nonConvertible.nonEmpty &&
SQLConf.get.optimizerInExtractLiteralPart) {
val optimizedIn = optimizeIn(In(v, convertible), v, convertible)
Or(optimizedIn, In(v, nonConvertible))
} else {
expr
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,14 @@ object SQLConf {
.intConf
.createWithDefault(100)

// Internal boolean switch consulted by the OptimizeIn rule: when enabled, an IN list that mixes
// literals with non-literal expressions has its literal part split off and optimized on its own.
// Defaults to true; read through SQLConf.optimizerInExtractLiteralPart.
val OPTIMIZER_IN_EXTRACT_LITERAL_PART =
buildConf("spark.sql.optimizer.inExtractLiteralPart")
.internal()
.doc("When true, we will extract and optimize the literal part of in if not all are literal.")
.version("3.1.0")
.booleanConf
.createWithDefault(true)

val OPTIMIZER_INSET_CONVERSION_THRESHOLD =
buildConf("spark.sql.optimizer.inSetConversionThreshold")
.internal()
Expand Down Expand Up @@ -2761,6 +2769,8 @@ class SQLConf extends Serializable with Logging {

// Current value of the OPTIMIZER_MAX_ITERATIONS entry.
def optimizerMaxIterations: Int = getConf(OPTIMIZER_MAX_ITERATIONS)

// Whether OptimizeIn may split a mixed IN list and optimize its literal part separately
// (current value of the OPTIMIZER_IN_EXTRACT_LITERAL_PART entry).
def optimizerInExtractLiteralPart: Boolean = getConf(OPTIMIZER_IN_EXTRACT_LITERAL_PART)

// Current value of the OPTIMIZER_INSET_CONVERSION_THRESHOLD entry; OptimizeIn converts an IN
// list to InSet once the de-duplicated list is longer than this.
def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD)

// Current value of the OPTIMIZER_INSET_SWITCH_THRESHOLD entry.
def optimizerInSetSwitchThreshold: Int = getConf(OPTIMIZER_INSET_SWITCH_THRESHOLD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_IN_EXTRACT_LITERAL_PART
import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -91,21 +92,6 @@ class OptimizeInSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("OptimizedIn test: In clause not optimized in case filter has attributes") {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this test: we now support converting part of the list, and the new test covers this case.

val originalQuery =
testRelation
.where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
.analyze

val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
.analyze

comparePlans(optimized, correctAnswer)
}

test("OptimizedIn test: NULL IN (expr1, ..., exprN) gets transformed to Filter(null)") {
val originalQuery =
testRelation
Expand Down Expand Up @@ -238,4 +224,44 @@ class OptimizeInSuite extends PlanTest {

comparePlans(optimized, correctAnswer)
}

test("SPARK-32196: Extract In convertible part if it is not convertible") {
  // Fresh attribute references on every use, mirroring the inline construction of the plans.
  def attrA: UnresolvedAttribute = UnresolvedAttribute("a")
  def attrB: UnresolvedAttribute = UnresolvedAttribute("b")
  def filterOn(condition: Expression): LogicalPlan = testRelation.where(condition).analyze

  // Case 1: a list mixing a literal with an attribute. With the flag on, the literal part is
  // pulled out and optimized (here into an equality); with the flag off the plan is untouched.
  Seq("true", "false").foreach { flag =>
    withSQLConf(OPTIMIZER_IN_EXTRACT_LITERAL_PART.key -> flag) {
      val optimizedMixed = Optimize.execute(filterOn(In(attrA, Seq(Literal(1), attrB))))

      val expectedCondition: Expression =
        if (flag.toBoolean) {
          Or(EqualTo(attrA, Literal(1)), In(attrA, Seq(attrB)))
        } else {
          In(attrA, Seq(Literal(1), attrB))
        }
      comparePlans(optimizedMixed, filterOn(expectedCondition))
    }
  }

  // Case 2: a list with no literal at all must be left exactly as it was.
  val optimizedNonLiteral = Optimize.execute(filterOn(In(attrA, Seq(attrB))))
  comparePlans(optimizedNonLiteral, filterOn(In(attrA, Seq(attrB))))
}
}