Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ object TypeCoercion {
i
}

case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
case i @ In(a, b) if b.exists(_.dataType != i.value.dataType) =>
findWiderCommonType(i.children.map(_.dataType)) match {
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
case None => i
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ package object dsl {
case c: CreateNamedStruct => InSubquery(c.valExprs, l)
case other => InSubquery(Seq(other), l)
}
case _ => In(expr, list)
case _ => expr match {
case c: CreateNamedStruct => In(c.valExprs, list)
case other => In(Seq(other), list)
}
}

def like(other: Expression): Expression = Like(expr, other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ object Canonicalize {
case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r)

// order the list in the In operator
case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode()))
case In(values, list) if list.length > 1 => In(values, list.sortBy(_.hashCode()))

case _ => e
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ import scala.collection.immutable.TreeSet
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.expressions.codegen.Block
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._


Expand Down Expand Up @@ -138,13 +140,12 @@ case class Not(child: Expression)
override def sql: String = s"(NOT ${child.sql})"
}

/**
* Evaluates to `true` if `values` are returned in `query`'s result set.
*/
case class InSubquery(values: Seq[Expression], query: ListQuery)
extends Predicate with Unevaluable {
trait InBase extends Predicate {
def values: Seq[Expression]

@transient protected lazy val isMultiValued = values.length > 1

@transient private lazy val value: Expression = if (values.length > 1) {
@transient lazy val value: Expression = if (isMultiValued) {
CreateNamedStruct(values.zipWithIndex.flatMap {
case (v: NamedExpression, _) => Seq(Literal(v.name), v)
case (v, idx) => Seq(Literal(s"_$idx"), v)
Expand All @@ -153,6 +154,28 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
values.head
}

// Builds the generated-code null test for an already generated expression `e`.
// With several values and SQLConf's inFalseForNullField turned off, the test is
// `e.isNull || e.value.anyNull()`, so a null field also counts as null.
// In every other case only `e.isNull` is tested.
@transient lazy val checkNullGenCode: (ExprCode) => Block = {
if (isMultiValued && !SQLConf.get.inFalseForNullField) {
e => code"${e.isNull} || ${e.value}.anyNull()"
} else {
e => code"${e.isNull}"
}
}

// Interpreted counterpart of the null test: decides whether an evaluated value is null.
// With several values and SQLConf's inFalseForNullField turned off, a non-null input is
// cast to InternalRow and is also treated as null when `anyNull` is true.
// In every other case only `input == null` is checked.
@transient lazy val checkNullEval: (Any) => Boolean = {
if (isMultiValued && !SQLConf.get.inFalseForNullField) {
input => input == null || input.asInstanceOf[InternalRow].anyNull
} else {
input => input == null
}
}
}

/**
* Evaluates to `true` if `values` are returned in `query`'s result set.
*/
case class InSubquery(values: Seq[Expression], query: ListQuery)
extends InBase with Unevaluable {

override def checkInputDataTypes(): TypeCheckResult = {
if (values.length != query.childOutputs.length) {
Expand Down Expand Up @@ -202,7 +225,12 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.",
usage = """
expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr1` equals to any exprN. Otherwise, if
`expr` is a single value and it is null or any exprN is null or `expr` contains multiple
values and spark.sql.legacy.inOperator.falseForNullField is false and any of the exprN or
fields of the exprN is null it returns null, else it returns false.
""",
arguments = """
Arguments:
* expr1, expr2, expr3, ... - the arguments must be same type.
Expand All @@ -219,7 +247,7 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
true
""")
// scalastyle:on line.size.limit
case class In(value: Expression, list: Seq[Expression]) extends Predicate {
case class In(values: Seq[Expression], list: Seq[Expression]) extends InBase {

require(list != null, "list should not be null")

Expand All @@ -234,24 +262,29 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
}

override def children: Seq[Expression] = value +: list
override def children: Seq[Expression] = values ++ list
lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal])
private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType)

override def nullable: Boolean = children.exists(_.nullable)
override def nullable: Boolean = if (isMultiValued && !SQLConf.get.inFalseForNullField) {
children.exists(_.nullable) ||
list.exists(_.dataType.asInstanceOf[StructType].exists(_.nullable))
} else {
value.nullable || list.exists(_.nullable)
}
override def foldable: Boolean = children.forall(_.foldable)

override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"

override def eval(input: InternalRow): Any = {
val evaluatedValue = value.eval(input)
if (evaluatedValue == null) {
if (checkNullEval(evaluatedValue)) {
null
} else {
var hasNull = false
list.foreach { e =>
val v = e.eval(input)
if (v == null) {
if (checkNullEval(v)) {
hasNull = true
} else if (ordering.equiv(v, evaluatedValue)) {
return true
Expand Down Expand Up @@ -283,7 +316,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
val listCode = listGen.map(x =>
s"""
|${x.code}
|if (${x.isNull}) {
|if (${checkNullGenCode(x)}) {
| $tmpResult = $HAS_NULL; // ${ev.isNull} = true;
|} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
| $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true;
Expand Down Expand Up @@ -316,7 +349,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
code"""
|${valueGen.code}
|byte $tmpResult = $HAS_NULL;
|if (!${valueGen.isNull}) {
|if (!(${checkNullGenCode(valueGen)})) {
| $tmpResult = $NOT_MATCHED;
| $javaDataType $valueArg = ${valueGen.value};
| do {
Expand All @@ -339,37 +372,57 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
* Optimized version of In clause, when all filter values of In clause are
* static.
*/
case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate {
case class InSet(values: Seq[Expression], hset: Set[Any]) extends InBase {

require(hset != null, "hset could not be null")

override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}"
override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}"

@transient private[this] lazy val hasNull: Boolean = hset.contains(null)
override def children: Seq[Expression] = values

override def nullable: Boolean = child.nullable || hasNull
@transient private[this] lazy val hasNull: Boolean = {
if (isMultiValued && !SQLConf.get.inFalseForNullField) {
hset.exists(checkNullEval)
} else {
hset.contains(null)
}
}

protected override def nullSafeEval(value: Any): Any = {
if (set.contains(value)) {
true
} else if (hasNull) {
override def nullable: Boolean = {
val isValueNullable = if (isMultiValued && !SQLConf.get.inFalseForNullField) {
values.exists(_.nullable)
} else {
value.nullable
}
isValueNullable || hasNull
}

override def eval(input: InternalRow): Any = {
val inputValue = value.eval(input)
if (checkNullEval(inputValue)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we change behavior here? It seems that `null INSET (null, xxx)` returned true previously.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because previously we were overriding nullSafeEval, which is only invoked when the child of InSet evaluates to a non-null value, so `null INSET (whatever)` was already returning null.

null
} else {
false
if (set.contains(inputValue)) {
true
} else if (hasNull) {
null
} else {
false
}
}
}

@transient lazy val set: Set[Any] = child.dataType match {
@transient lazy val set: Set[Any] = value.dataType match {
case _: AtomicType => hset
case _: NullType => hset
case _ =>
// for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset
TreeSet.empty(TypeUtils.getInterpretedOrdering(value.dataType)) ++ hset
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val setTerm = ctx.addReferenceObj("set", set)
val childGen = child.genCode(ctx)
val childGen = value.genCode(ctx)
val setIsNull = if (hasNull) {
s"${ev.isNull} = !${ev.value};"
} else {
Expand All @@ -378,7 +431,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
ev.copy(code =
code"""
|${childGen.code}
|${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
|${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${checkNullGenCode(childGen)};
|${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false;
|if (!${ev.isNull}) {
| ${ev.value} = $setTerm.contains(${childGen.value});
Expand All @@ -388,7 +441,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}

override def sql: String = {
val valueSQL = child.sql
val valueSQL = value.sql
val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ")
s"($valueSQL IN ($listSQL))"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,27 +212,33 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
* 1. Converts the predicate to false when the list is empty and
* the value is not nullable.
* 2. Removes literal repetitions.
* 3. Replaces [[In (value, seq[Literal])]] with optimized version
* [[InSet (value, HashSet[Literal])]] which is much faster.
* 3. Replaces [[In (values, seq[Literal])]] with optimized version
* [[InSet (values, HashSet[Literal])]] which is much faster.
*/
object OptimizeIn extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsDown {
case In(v, list) if list.isEmpty =>
// When v is not nullable, the following expression will be optimized
case i @ In(values, list) if list.isEmpty =>
// When values are not nullable, the following expression will be optimized
// to FalseLiteral which is tested in OptimizeInSuite.scala
If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType))
case expr @ In(v, list) if expr.inSetConvertible =>
val isNotNull = if (SQLConf.get.inFalseForNullField) {
IsNotNull(i.value)
} else {
values.map(IsNotNull).reduce(And)
}
If(isNotNull, FalseLiteral, Literal(null, BooleanType))
case expr @ In(values, list) if expr.inSetConvertible =>
// if we have more than one element in the values, we have to skip this optimization
val newList = ExpressionSet(list).toSeq
if (newList.length == 1
// TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed,
// TODO: we exclude them in this rule.
&& !v.isInstanceOf[CreateNamedStructLike]
&& !expr.value.isInstanceOf[CreateNamedStructLike]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  @transient protected lazy val isMultiValued = values.length > 1
  @transient lazy val value: Expression = if (isMultiValued) {
    CreateNamedStruct(values.zipWithIndex.flatMap {
      case (v: NamedExpression, _) => Seq(Literal(v.name), v)
      case (v, idx) => Seq(Literal(s"_$idx"), v)
    })
  } else {
    values.head
  }
}

According to the implementation, expr.value.isInstanceOf[CreateNamedStructLike] means expr.values.length > 1, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, right

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we use expr.values.length == 1 here to make it more clear?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really, because expr.value.isInstanceOf[CreateNamedStructLike] means:

  • either expr.values.length > 1;
  • or expr.values.head.isInstanceOf[CreateNamedStructLike];

Basically there are 2 cases: one where we have several attributes in the value before IN; the other where there is a single value before IN but the value is a struct. expr.value.isInstanceOf[CreateNamedStructLike] catches both. I can add a comment explaining these 2 cases if you think it is needed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, you mean
expr.values.length > 1 => expr.value.isInstanceOf[CreateNamedStructLike]
but expr.value.isInstanceOf[CreateNamedStructLike] does not imply expr.values.length > 1

Can you give an example?

Based on my understanding, the code here is trying to optimize a case when it's not a multi-value in and the list has only one element.

Copy link
Contributor Author

@mgaido91 mgaido91 Oct 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I mean that. An example is:

select 1 from (select struct('a', 1, 'b', '2') as a1) t1 where a1 in ((...), ...);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for your case, it's not CreateNamedStructLike, but just a struct type column?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, because of optimizations, it is a CreateNamedStructLike

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, I think for this case we should optimize it.

Anyway it follows the previous behavior, we can change it later.

&& !newList.head.isInstanceOf[CreateNamedStructLike]) {
EqualTo(v, newList.head)
EqualTo(expr.value, newList.head)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, we do this only when value is not a CreateNamedStructLike, so we don't go here if there are multi-values

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we update the match here? I think it should be In(Seq(value), ...) now.

Copy link
Contributor Author

@mgaido91 mgaido91 Oct 26, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, sorry, we can't do that, otherwise we would skip the other possible optimizations here, e.g. converting to InSet, reducing the list of values, etc.

What should be done, instead, is doing the same change to InSet, so that the way nulls are handled is coherent.

} else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) {
val hSet = newList.map(e => e.eval(EmptyRow))
InSet(v, HashSet() ++ hSet)
InSet(values, HashSet() ++ hSet)
} else if (newList.length < list.length) {
expr.copy(list = newList)
} else { // newList.length == list.length && newList.length > 1
Expand Down Expand Up @@ -527,7 +533,7 @@ object NullPropagation extends Rule[LogicalPlan] {
}

// If the value expression is NULL then transform the In expression to null literal.
case In(Literal(null, _), _) => Literal.create(null, BooleanType)
case In(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType)
case InSubquery(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType)

// Non-leaf NullIntolerant expressions will return null, if at least one of its children is
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1120,7 +1120,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case SqlBaseParser.IN if ctx.query != null =>
invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query))))
case SqlBaseParser.IN =>
invertIfNotDefined(In(e, ctx.expression.asScala.map(expression)))
invertIfNotDefined(In(getValueExpressions(e), ctx.expression.asScala.map(expression)))
case SqlBaseParser.LIKE =>
invertIfNotDefined(Like(e, expression(ctx.pattern)))
case SqlBaseParser.RLIKE =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,14 @@ case class FilterEstimation(plan: Filter) extends Logging {
case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) =>
evaluateBinary(LessThanOrEqual(ar, l), ar, l, update)

case In(ar: Attribute, expList)
if expList.forall(e => e.isInstanceOf[Literal]) =>
case In(Seq(ar: Attribute), expList) if expList.forall(e => e.isInstanceOf[Literal]) =>
// Expression [In (value, seq[Literal])] will be replaced with optimized version
// [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10.
// Here we convert In into InSet anyway, because they share the same processing logic.
val hSet = expList.map(e => e.eval())
evaluateInSet(ar, HashSet() ++ hSet, update)

case InSet(ar: Attribute, set) =>
case InSet(Seq(ar: Attribute), set) =>
evaluateInSet(ar, set, update)

// In current stage, we don't have advanced statistics such as sketches or histograms.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1561,6 +1561,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

// Internal boolean config entry with key spark.sql.legacy.inOperator.falseForNullField,
// created with a default of false. Per its doc text, it selects whether a multi-value IN
// comparison involving a null returns false (true) or null (false).
val LEGACY_IN_FALSE_FOR_NULL_FIELD =
buildConf("spark.sql.legacy.inOperator.falseForNullField")
.internal()
.doc("When set to true, the IN operator returns false when comparing multiple values " +
"containing a null. When set to false (default), it returns null, instead. This is " +
"important especially when using NOT IN as in the second case, it filters out the rows " +
"when a null is present in a field; while in the first one, those rows are returned.")
.booleanConf
.createWithDefault(false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we set this true by default in Spark 2.4 at least?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it can be done, WDYT @cloud-fan @gatorsmile ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kindly ping @cloud-fan @gatorsmile


val LEGACY_INTEGRALDIVIDE_RETURN_LONG = buildConf("spark.sql.legacy.integralDivide.returnBigint")
.doc("If it is set to true, the div operator returns always a bigint. This behavior was " +
"inherited from Hive. Otherwise, the return type is the data type of the operands.")
Expand Down Expand Up @@ -1978,6 +1988,8 @@ class SQLConf extends Serializable with Logging {

def setOpsPrecedenceEnforced: Boolean = getConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED)

// Returns the current value of SQLConf.LEGACY_IN_FALSE_FOR_NULL_FIELD, read through getConf.
def inFalseForNullField: Boolean = getConf(SQLConf.LEGACY_IN_FALSE_FOR_NULL_FIELD)

def integralDivideReturnLong: Boolean = getConf(SQLConf.LEGACY_INTEGRALDIVIDE_RETURN_LONG)

/** ********************** SQLConf functionality methods ************ */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,21 +280,22 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}

test("SPARK-8654: invalid CAST in NULL IN(...) expression") {
val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil,
val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(2))), "a")() :: Nil,
LocalRelation()
)
assertAnalysisSuccess(plan)
}

test("SPARK-8654: different types in inlist but can be converted to a common type") {
val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil,
val plan = Project(
Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil,
LocalRelation()
)
assertAnalysisSuccess(plan)
}

test("SPARK-8654: check type compatibility error") {
val plan = Project(Alias(In(Literal(null), Seq(Literal(true), Literal(1))), "a")() :: Nil,
val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(true), Literal(1))), "a")() :: Nil,
LocalRelation()
)
assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type"))
Expand Down
Loading