Skip to content

Commit df68beb

Browse files
viirya authored and marmbrus committed
[SPARK-13995][SQL] Extract correct IsNotNull constraints for Expression
## What changes were proposed in this pull request?

JIRA: https://issues.apache.org/jira/browse/SPARK-13995

We currently infer the relevant `IsNotNull` constraints from a logical plan's expressions in `constructIsNotNullConstraints`. However, we don't consider the case of (nested) `Cast`. For example:

    val tr = LocalRelation('a.int, 'b.long)
    val plan = tr.where('a.attr === 'b.attr).analyze

Then the plan's constraints will contain `IsNotNull(Cast(resolveColumn(tr, "a"), LongType))` instead of `IsNotNull(resolveColumn(tr, "a"))`. This PR fixes it.

Besides, as `IsNotNull` constraints are most useful for `Attribute`s, we should recurse through any `Expression` that is null intolerant and construct `IsNotNull` constraints for all `Attribute`s under these expressions. For example, consider the following constraints:

    val df = Seq((1,2,3)).toDF("a", "b", "c")
    df.where("a + b = c").queryExecution.analyzed.constraints

The inferred isnotnull constraints should be isnotnull(a), isnotnull(b), isnotnull(c), instead of isnotnull(a + b) and isnotnull(c).

## How was this patch tested?

A test is added to `ConstraintPropagationSuite`.

Author: Liang-Chi Hsieh <[email protected]>

Closes #11809 from viirya/constraint-cast.
1 parent 381358f commit df68beb

File tree

7 files changed

+134
-37
lines changed

7 files changed

+134
-37
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ object Cast {
112112
}
113113

114114
/** Cast the child expression to the target data type. */
115-
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
115+
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant {
116116

117117
override def toString: String = s"cast($child as ${dataType.simpleString})"
118118

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ import org.apache.spark.sql.types._
2424
import org.apache.spark.unsafe.types.CalendarInterval
2525

2626

27-
case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
27+
case class UnaryMinus(child: Expression) extends UnaryExpression
28+
with ExpectsInputTypes with NullIntolerant {
2829

2930
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
3031

@@ -58,7 +59,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
5859
override def sql: String = s"(-${child.sql})"
5960
}
6061

61-
case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
62+
case class UnaryPositive(child: Expression)
63+
extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
6264
override def prettyName: String = "positive"
6365

6466
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
@@ -79,7 +81,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
7981
@ExpressionDescription(
8082
usage = "_FUNC_(expr) - Returns the absolute value of the numeric value",
8183
extended = "> SELECT _FUNC_('-1');\n1")
82-
case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {
84+
case class Abs(child: Expression)
85+
extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
8386

8487
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
8588

@@ -123,7 +126,7 @@ private[sql] object BinaryArithmetic {
123126
def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
124127
}
125128

126-
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
129+
case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
127130

128131
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
129132

@@ -152,7 +155,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
152155
}
153156
}
154157

155-
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
158+
case class Subtract(left: Expression, right: Expression)
159+
extends BinaryArithmetic with NullIntolerant {
156160

157161
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
158162

@@ -181,7 +185,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
181185
}
182186
}
183187

184-
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
188+
case class Multiply(left: Expression, right: Expression)
189+
extends BinaryArithmetic with NullIntolerant {
185190

186191
override def inputType: AbstractDataType = NumericType
187192

@@ -193,7 +198,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
193198
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
194199
}
195200

196-
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
201+
case class Divide(left: Expression, right: Expression)
202+
extends BinaryArithmetic with NullIntolerant {
197203

198204
override def inputType: AbstractDataType = NumericType
199205

@@ -269,7 +275,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
269275
}
270276
}
271277

272-
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
278+
case class Remainder(left: Expression, right: Expression)
279+
extends BinaryArithmetic with NullIntolerant {
273280

274281
override def inputType: AbstractDataType = NumericType
275282

@@ -457,7 +464,7 @@ case class MinOf(left: Expression, right: Expression)
457464
override def symbol: String = "min"
458465
}
459466

460-
case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
467+
case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
461468

462469
override def toString: String = s"pmod($left, $right)"
463470

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ trait NamedExpression extends Expression {
9797
}
9898
}
9999

100-
abstract class Attribute extends LeafExpression with NamedExpression {
100+
abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant {
101101

102102
override def references: AttributeSet = AttributeSet(this)
103103

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,11 @@ package object expressions {
9292
StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
9393
}
9494
}
95+
96+
/**
97+
* An expression that inherits this trait is null intolerant (i.e. any null
98+
* input will result in null output). This information is used when constructing IsNotNull
99+
* constraints.
100+
*/
101+
trait NullIntolerant
95102
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ trait PredicateHelper {
9090

9191

9292
case class Not(child: Expression)
93-
extends UnaryExpression with Predicate with ImplicitCastInputTypes {
93+
extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant {
9494

9595
override def toString: String = s"NOT $child"
9696

@@ -402,7 +402,8 @@ private[sql] object Equality {
402402
}
403403

404404

405-
case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
405+
case class EqualTo(left: Expression, right: Expression)
406+
extends BinaryComparison with NullIntolerant {
406407

407408
override def inputType: AbstractDataType = AnyDataType
408409

@@ -467,7 +468,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
467468
}
468469

469470

470-
case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
471+
case class LessThan(left: Expression, right: Expression)
472+
extends BinaryComparison with NullIntolerant {
471473

472474
override def inputType: AbstractDataType = TypeCollection.Ordered
473475

@@ -479,7 +481,8 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
479481
}
480482

481483

482-
case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
484+
case class LessThanOrEqual(left: Expression, right: Expression)
485+
extends BinaryComparison with NullIntolerant {
483486

484487
override def inputType: AbstractDataType = TypeCollection.Ordered
485488

@@ -491,7 +494,8 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
491494
}
492495

493496

494-
case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
497+
case class GreaterThan(left: Expression, right: Expression)
498+
extends BinaryComparison with NullIntolerant {
495499

496500
override def inputType: AbstractDataType = TypeCollection.Ordered
497501

@@ -503,7 +507,8 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
503507
}
504508

505509

506-
case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
510+
case class GreaterThanOrEqual(left: Expression, right: Expression)
511+
extends BinaryComparison with NullIntolerant {
507512

508513
override def inputType: AbstractDataType = TypeCollection.Ordered
509514

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,25 +44,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
4444
* returns a constraint of the form `isNotNull(a)`
4545
*/
4646
private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
47-
var isNotNullConstraints = Set.empty[Expression]
48-
49-
// First, we propagate constraints if the condition consists of equality and ranges. For all
50-
// other cases, we return an empty set of constraints
51-
constraints.foreach {
52-
case EqualTo(l, r) =>
53-
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
54-
case GreaterThan(l, r) =>
55-
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
56-
case GreaterThanOrEqual(l, r) =>
57-
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
58-
case LessThan(l, r) =>
59-
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
60-
case LessThanOrEqual(l, r) =>
61-
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
62-
case Not(EqualTo(l, r)) =>
63-
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
64-
case _ => // No inference
65-
}
47+
// First, we propagate constraints from the null intolerant expressions.
48+
var isNotNullConstraints: Set[Expression] =
49+
constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_))
6650

6751
// Second, we infer additional constraints from non-nullable attributes that are part of the
6852
// operator's output
@@ -72,6 +56,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
7256
isNotNullConstraints -- constraints
7357
}
7458

59+
/**
60+
* Recursively explores the expressions which are null intolerant and returns all attributes
61+
* in these expressions.
62+
*/
63+
private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
64+
case a: Attribute => Seq(a)
65+
case _: NullIntolerant | IsNotNull(_: NullIntolerant) =>
66+
expr.children.flatMap(scanNullIntolerantExpr)
67+
case _ => Seq.empty[Attribute]
68+
}
69+
7570
/**
7671
* Infers an additional set of constraints from a given set of equality constraints.
7772
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
2323
import org.apache.spark.sql.catalyst.dsl.plans._
2424
import org.apache.spark.sql.catalyst.expressions._
2525
import org.apache.spark.sql.catalyst.plans.logical._
26-
import org.apache.spark.sql.types.{IntegerType, StringType}
26+
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType}
2727

2828
class ConstraintPropagationSuite extends SparkFunSuite {
2929

@@ -219,6 +219,89 @@ class ConstraintPropagationSuite extends SparkFunSuite {
219219
IsNotNull(resolveColumn(tr, "b")))))
220220
}
221221

222+
test("infer constraints on cast") {
223+
val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
224+
verifyConstraints(
225+
tr.where('a.attr === 'b.attr &&
226+
'c.attr + 100 > 'd.attr &&
227+
IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints,
228+
ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"),
229+
Cast(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"),
230+
IsNotNull(resolveColumn(tr, "a")),
231+
IsNotNull(resolveColumn(tr, "b")),
232+
IsNotNull(resolveColumn(tr, "c")),
233+
IsNotNull(resolveColumn(tr, "d")),
234+
IsNotNull(resolveColumn(tr, "e")),
235+
IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType)))))
236+
}
237+
238+
test("infer isnotnull constraints from compound expressions") {
239+
val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
240+
verifyConstraints(
241+
tr.where('a.attr + 'b.attr === 'c.attr &&
242+
IsNotNull(
243+
Cast(
244+
Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints,
245+
ExpressionSet(Seq(
246+
Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") ===
247+
Cast(resolveColumn(tr, "c"), LongType),
248+
IsNotNull(resolveColumn(tr, "a")),
249+
IsNotNull(resolveColumn(tr, "b")),
250+
IsNotNull(resolveColumn(tr, "c")),
251+
IsNotNull(resolveColumn(tr, "e")),
252+
IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType)))))
253+
254+
verifyConstraints(
255+
tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints,
256+
ExpressionSet(Seq(
257+
Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) ===
258+
Cast(resolveColumn(tr, "c"), LongType),
259+
Cast(resolveColumn(tr, "d"), DoubleType) /
260+
Cast(Cast(10, LongType), DoubleType) ===
261+
Cast(resolveColumn(tr, "e"), DoubleType),
262+
IsNotNull(resolveColumn(tr, "a")),
263+
IsNotNull(resolveColumn(tr, "b")),
264+
IsNotNull(resolveColumn(tr, "c")),
265+
IsNotNull(resolveColumn(tr, "d")),
266+
IsNotNull(resolveColumn(tr, "e")))))
267+
268+
verifyConstraints(
269+
tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints,
270+
ExpressionSet(Seq(
271+
Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >=
272+
Cast(resolveColumn(tr, "c"), LongType),
273+
Cast(resolveColumn(tr, "d"), DoubleType) /
274+
Cast(Cast(10, LongType), DoubleType) <
275+
Cast(resolveColumn(tr, "e"), DoubleType),
276+
IsNotNull(resolveColumn(tr, "a")),
277+
IsNotNull(resolveColumn(tr, "b")),
278+
IsNotNull(resolveColumn(tr, "c")),
279+
IsNotNull(resolveColumn(tr, "d")),
280+
IsNotNull(resolveColumn(tr, "e")))))
281+
282+
verifyConstraints(
283+
tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints,
284+
ExpressionSet(Seq(
285+
(Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) -
286+
(Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) >
287+
Cast(resolveColumn(tr, "e") * 1000, LongType),
288+
IsNotNull(resolveColumn(tr, "a")),
289+
IsNotNull(resolveColumn(tr, "b")),
290+
IsNotNull(resolveColumn(tr, "c")),
291+
IsNotNull(resolveColumn(tr, "d")),
292+
IsNotNull(resolveColumn(tr, "e")))))
293+
294+
// The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null.
295+
verifyConstraints(
296+
tr.where('a.attr === 'c.attr &&
297+
IsNotNull(IsNotNull(resolveColumn(tr, "b")))).analyze.constraints,
298+
ExpressionSet(Seq(
299+
resolveColumn(tr, "a") === resolveColumn(tr, "c"),
300+
IsNotNull(IsNotNull(resolveColumn(tr, "b"))),
301+
IsNotNull(resolveColumn(tr, "a")),
302+
IsNotNull(resolveColumn(tr, "c")))))
303+
}
304+
222305
test("infer IsNotNull constraints from non-nullable attributes") {
223306
val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(),
224307
AttributeReference("c", StringType, nullable = false)())

0 commit comments

Comments
 (0)