[SPARK-29213][SQL] Generate extra IsNotNull predicate in FilterExec #25902
Conversation
@viirya @cloud-fan gentle ping, would you like to review this? Thank you. :)
Force-pushed from 045df47 to 7e54ccb
ok to test
```
override def output: Seq[Attribute] = {
  child.output.map { a =>
-   if (a.nullable && notNullAttributes.contains(a.exprId)) {
+   if (a.nullable && notNullPreds.exists(_.semanticEquals(a))) {
```
What's the real difference? AFAIK `Attribute#semanticEquals` just compares the expr id.
Test build #111232 has finished for PR 25902 at commit […].
On the same change:
`notNullPreds` is a list of `IsNotNull` expressions. I think this `semanticEquals` comparison will always fail?
Yeah, this will always fail. This change is not good enough; I'll try another way.
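For illustration, a self-contained sketch of why the comparison can never hold (the setup below is illustrative, not from the PR):

```
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, IsNotNull}
import org.apache.spark.sql.types.StringType

// A bare attribute and an IsNotNull predicate over it are different
// expression trees, so semanticEquals between them is always false.
val a = AttributeReference("a", StringType)()
IsNotNull(a).semanticEquals(a)  // false: IsNotNull(a) vs. a
a.semanticEquals(a)             // true: same attribute, same exprId
```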
@cloud-fan @viirya The real problem is here: […] At the same time, the filter output attribute […] To fix this, I think we can filter attributes both in […]
Test build #111255 has finished for PR 25902 at commit […].

Test build #111256 has finished for PR 25902 at commit […].
Some slight modifications so we no longer depend on Hive.
```
// The columns that will be filtered out by `IsNotNull` could be considered as not nullable.
private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
  .diff(otherPreds.flatMap(_.references).distinct.map(_.exprId))
```
This fix seems sub-optimal to me. If we have an `IsNotNull(a)`, then `a` should be a not-null attribute even if `a` appears in `otherPreds`, e.g. in `EqualTo(a, b)`.
I think the bug is about how we codegen null checks. The current logic is:
- split the predicates into `notNullPreds` (the `IsNotNull` expressions) and `otherPreds`
- try to codegen `otherPreds` first, then `notNullPreds`
- if an other-predicate can leverage a specific not-null-predicate, then codegen that not-null-predicate first, so that we can codegen the other-predicate with non-nullable attributes

There is a problem in step 3. Imagine that we have an `IsNotNull(SubString(a))` and a `SomeFunc(a)`. Then `a` is a not-null attribute for sure, even if there is no `IsNotNull(a)` predicate.

When we codegen `SomeFunc(a)`, we can't find an `IsNotNull(a)` expression, so we skip the null check. However, we assume `a` is not nullable when we codegen `SomeFunc(a)`, which is wrong, as `IsNotNull(SubString(a))` has not been codegened yet.
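For reference, a minimal sketch of the split in the first step above (simplified: the real `FilterExec` partition also requires the `IsNotNull` child to be null-intolerant and its references to come from the child's output):

```
import org.apache.spark.sql.catalyst.expressions.{And, Expression, IsNotNull}

// Split a conjunctive filter condition into IsNotNull conjuncts and the rest.
def splitPredicates(condition: Expression): (Seq[Expression], Seq[Expression]) = {
  def conjuncts(e: Expression): Seq[Expression] = e match {
    case And(l, r) => conjuncts(l) ++ conjuncts(r)
    case other     => Seq(other)
  }
  conjuncts(condition).partition {
    case IsNotNull(_) => true
    case _            => false
  }
}
```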
I think the problem is that the child of `IsNotNull` is not a reference. So when we look for the `IsNotNull` that can filter null values for a non-`IsNotNull` predicate, we fail to look it up:

```
val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r) }
```

Wouldn't the easiest fix look like:

```
val idx = notNullPreds.indexWhere { n =>
  n.asInstanceOf[IsNotNull].child.references.contains(r)
}
```
@viirya your proposal works, but is a little against the original design goal. Think about […]. When codegen […]. I have a slightly different proposal: when codegen […].

Ah, I see. You are correct. Your proposal sounds good.
To be more comprehensive, should we check whether there is an `IsNotNull` predicate for a predicate and all its children before generating `IsNotNull` for its references? But this is a little complicated.
@wangshuo128 you are right that we can improve the algorithm further, and we can do it once the bug is fixed. For now let's fix the bug without introducing a regression. That said, for any predicate (not […])
Ok, I'll try to fix this bug first.
Test build #111341 has finished for PR 25902 at commit […].
Sorry @JoshRosen, do you mind if I use your code to move the test from sql/hive to sql/core?
Yes, please feel free to use that! This bug is somewhat hard to reproduce in unit tests because filter pushdown to data source scans will prevent the […].

(I'm following this PR because I'm interested in other aspects of […].)
```
// TODO: revisit this. We can consider reordering predicates as well.
val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length)

val extraIsNotNullReferences = mutable.Set[Attribute]()
```
nit: extraIsNotNullAttrs
```
val extraIsNotNullReferences = mutable.Set[Attribute]()

def outputContainsNotNull(ref: Attribute): Boolean = {
```
This method always returns true, so we add `IsNotNull` for all attributes, which is sub-optimal. We should check `notNullAttributes.contains(ref.exprId)`.
Ah, my mistake.
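A sketch of the suggested check (a fragment, assuming `FilterExec`'s `notNullAttributes` shown earlier in this thread):

```
// Only attributes actually covered by a not-null predicate should be
// treated as non-nullable when deciding whether to add an extra IsNotNull.
def outputContainsNotNull(ref: Attribute): Boolean =
  notNullAttributes.contains(ref.exprId)
```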
```
test("SPARK-29213 Make it consistent when get notnull output and generate null " +
```
We can simply describe the bug in the test name: SPARK-29213: FilterExec should not throw NPE
@JoshRosen Thanks a lot!
```
test("SPARK-29213: FilterExec should not throw NPE") {
  withView("t1", "t2", "t3") {
```
nit: withView -> withTempView
| sql("select ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t3") | ||
| sql( | ||
| """ | ||
| |select t1.x | 
nit: No strict rule though, I like capitalized words for SQL keywords: e.g., select -> SELECT, ....
Test build #111377 has finished for PR 25902 at commit […].

Test build #111378 has finished for PR 25902 at commit […].

Test build #111393 has finished for PR 25902 at commit […].
@wangshuo128 Can you resolve the conflict with a git rebase?

Ok
Force-pushed from f33d34c to 124ad87 (… null checks in FilterExec)
Test build #111460 has finished for PR 25902 at commit […].
thanks, merging to master/2.4!

Thanks!
What changes were proposed in this pull request?

Currently the behavior of getting the output and generating null checks in `FilterExec` is inconsistent, so a nullable attribute could be treated as non-nullable by mistake.

In `FilterExec.output`, an attribute is marked as nullable or not by looking up its `exprId` in `notNullAttributes`:
```
a.nullable && notNullAttributes.contains(a.exprId)
```
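For context, the surrounding `output` method looks roughly like the snippet reviewed in the conversation above (a sketch, assuming `FilterExec`'s members):

```
override def output: Seq[Attribute] = {
  child.output.map { a =>
    // Attributes covered by an IsNotNull predicate are reported as
    // non-nullable to parent operators.
    if (a.nullable && notNullAttributes.contains(a.exprId)) {
      a.withNullability(false)
    } else {
      a
    }
  }
}
```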
But in `FilterExec.doConsume`, whether a `nullCheck` is generated for a predicate is decided by whether a semantically equal not-null predicate exists:
```
      val nullChecks = c.references.map { r =>
        val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
        if (idx != -1 && !generatedIsNotNullChecks(idx)) {
          generatedIsNotNullChecks(idx) = true
          // Use the child's output. The nullability is what the child produced.
          genPredicate(notNullPreds(idx), input, child.output)
        } else {
          ""
        }
      }.mkString("\n").trim
```
An NPE happens when running the SQL below (per the analysis above, the filter's `length(...)` call is code-generated without a preceding null check):
```
sql("create table table1(x string)")
sql("create table table2(x bigint)")
sql("create table table3(x string)")
sql("insert into table2 select null as x")
sql(
  """
    |select t1.x
    |from (
    |    select x from table1) t1
    |left join (
    |    select x from (
    |        select x from table2
    |        union all
    |        select substr(x,5) x from table3
    |    ) a
    |    where length(x)>0
    |) t3
    |on t1.x=t3.x
  """.stripMargin).collect()
```
NPE Exception:
```
java.lang.NullPointerException
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(generated.java:40)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:726)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
    at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:135)
    at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:94)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
    at org.apache.spark.scheduler.Task.run(Task.scala:127)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:449)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:452)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)
```
The generated code (note that `(filter_value_2).numChars()` is invoked without a null check):
```
== Subtree 4 / 5 ==
*(2) Project [cast(x#7L as string) AS x#9]
+- *(2) Filter ((length(cast(x#7L as string)) > 0) AND isnotnull(cast(x#7L as string)))
   +- Scan hive default.table2 [x#7L], HiveTableRelation `default`.`table2`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [x#7L]
Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator inputadapter_input_0;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
/* 011 */
/* 012 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 013 */     this.references = references;
/* 014 */   }
/* 015 */
/* 016 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 017 */     partitionIndex = index;
/* 018 */     this.inputs = inputs;
/* 019 */     inputadapter_input_0 = inputs[0];
/* 020 */     filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 021 */     filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 022 */
/* 023 */   }
/* 024 */
/* 025 */   protected void processNext() throws java.io.IOException {
/* 026 */     while ( inputadapter_input_0.hasNext()) {
/* 027 */       InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
/* 028 */
/* 029 */       do {
/* 030 */         boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
/* 031 */         long inputadapter_value_0 = inputadapter_isNull_0 ?
/* 032 */         -1L : (inputadapter_row_0.getLong(0));
/* 033 */
/* 034 */         boolean filter_isNull_2 = inputadapter_isNull_0;
/* 035 */         UTF8String filter_value_2 = null;
/* 036 */         if (!inputadapter_isNull_0) {
/* 037 */           filter_value_2 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
/* 038 */         }
/* 039 */         int filter_value_1 = -1;
/* 040 */         filter_value_1 = (filter_value_2).numChars();
/* 041 */
/* 042 */         boolean filter_value_0 = false;
/* 043 */         filter_value_0 = filter_value_1 > 0;
/* 044 */         if (!filter_value_0) continue;
/* 045 */
/* 046 */         boolean filter_isNull_6 = inputadapter_isNull_0;
/* 047 */         UTF8String filter_value_6 = null;
/* 048 */         if (!inputadapter_isNull_0) {
/* 049 */           filter_value_6 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
/* 050 */         }
/* 051 */         if (!(!filter_isNull_6)) continue;
/* 052 */
/* 053 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 054 */
/* 055 */         boolean project_isNull_0 = false;
/* 056 */         UTF8String project_value_0 = null;
/* 057 */         if (!false) {
/* 058 */           project_value_0 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
/* 059 */         }
/* 060 */         filter_mutableStateArray_0[1].reset();
/* 061 */
/* 062 */         filter_mutableStateArray_0[1].zeroOutNullBytes();
/* 063 */
/* 064 */         if (project_isNull_0) {
/* 065 */           filter_mutableStateArray_0[1].setNullAt(0);
/* 066 */         } else {
/* 067 */           filter_mutableStateArray_0[1].write(0, project_value_0);
/* 068 */         }
/* 069 */         append((filter_mutableStateArray_0[1].getRow()));
/* 070 */
/* 071 */       } while(false);
/* 072 */       if (shouldStop()) return;
/* 073 */     }
/* 074 */   }
/* 075 */
/* 076 */ }
```
This PR proposes to make `FilterExec.output` and `FilterExec.doConsume` consistent for nullable attributes: when `doConsume` finds no semantically equal `IsNotNull` predicate for a not-null attribute, it generates an extra `IsNotNull` check, as sketched below.
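A sketch of the resulting null-check generation in `doConsume` (simplified and reconstructed from the review snippets above, using the reviewer-suggested name `extraIsNotNullAttrs` for the dedup set; treat it as illustrative rather than the exact patch):

```
val nullChecks = c.references.map { r =>
  val idx = notNullPreds.indexWhere { n =>
    n.asInstanceOf[IsNotNull].child.semanticEquals(r)
  }
  if (idx != -1 && !generatedIsNotNullChecks(idx)) {
    generatedIsNotNullChecks(idx) = true
    // Use the child's output. The nullability is what the child produced.
    genPredicate(notNullPreds(idx), input, child.output)
  } else if (notNullAttributes.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
    // No IsNotNull(r) predicate exists (e.g. there is only
    // IsNotNull(substring(r))), so generate an extra null check before
    // assuming r is non-nullable.
    extraIsNotNullAttrs += r
    genPredicate(IsNotNull(r), input, child.output)
  } else {
    ""
  }
}.mkString("\n").trim
```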
With this PR, the generated code snippet is below:
```
== Subtree 2 / 5 ==
*(3) Project [substring(x#8, 5, 2147483647) AS x#5]
+- *(3) Filter ((length(substring(x#8, 5, 2147483647)) > 0) AND isnotnull(substring(x#8, 5, 2147483647)))
   +- Scan hive default.table3 [x#8], HiveTableRelation `default`.`table3`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [x#8]
Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage3(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=3
/* 006 */ final class GeneratedIteratorForCodegenStage3 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator inputadapter_input_0;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
/* 011 */
/* 012 */   public GeneratedIteratorForCodegenStage3(Object[] references) {
/* 013 */     this.references = references;
/* 014 */   }
/* 015 */
/* 016 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 017 */     partitionIndex = index;
/* 018 */     this.inputs = inputs;
/* 019 */     inputadapter_input_0 = inputs[0];
/* 020 */     filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 021 */     filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 022 */
/* 023 */   }
/* 024 */
/* 025 */   protected void processNext() throws java.io.IOException {
/* 026 */     while ( inputadapter_input_0.hasNext()) {
/* 027 */       InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
/* 028 */
/* 029 */       do {
/* 030 */         boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
/* 031 */         UTF8String inputadapter_value_0 = inputadapter_isNull_0 ?
/* 032 */         null : (inputadapter_row_0.getUTF8String(0));
/* 033 */
/* 034 */         boolean filter_isNull_0 = true;
/* 035 */         boolean filter_value_0 = false;
/* 036 */         boolean filter_isNull_2 = true;
/* 037 */         UTF8String filter_value_2 = null;
/* 038 */
/* 039 */         if (!inputadapter_isNull_0) {
/* 040 */           filter_isNull_2 = false; // resultCode could change nullability.
/* 041 */           filter_value_2 = inputadapter_value_0.substringSQL(5, 2147483647);
/* 042 */
/* 043 */         }
/* 044 */         boolean filter_isNull_1 = filter_isNull_2;
/* 045 */         int filter_value_1 = -1;
/* 046 */
/* 047 */         if (!filter_isNull_2) {
/* 048 */           filter_value_1 = (filter_value_2).numChars();
/* 049 */         }
/* 050 */         if (!filter_isNull_1) {
/* 051 */           filter_isNull_0 = false; // resultCode could change nullability.
/* 052 */           filter_value_0 = filter_value_1 > 0;
/* 053 */
/* 054 */         }
/* 055 */         if (filter_isNull_0 || !filter_value_0) continue;
/* 056 */         boolean filter_isNull_8 = true;
/* 057 */         UTF8String filter_value_8 = null;
/* 058 */
/* 059 */         if (!inputadapter_isNull_0) {
/* 060 */           filter_isNull_8 = false; // resultCode could change nullability.
/* 061 */           filter_value_8 = inputadapter_value_0.substringSQL(5, 2147483647);
/* 062 */
/* 063 */         }
/* 064 */         if (!(!filter_isNull_8)) continue;
/* 065 */
/* 066 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 067 */
/* 068 */         boolean project_isNull_0 = true;
/* 069 */         UTF8String project_value_0 = null;
/* 070 */
/* 071 */         if (!inputadapter_isNull_0) {
/* 072 */           project_isNull_0 = false; // resultCode could change nullability.
/* 073 */           project_value_0 = inputadapter_value_0.substringSQL(5, 2147483647);
/* 074 */
/* 075 */         }
/* 076 */         filter_mutableStateArray_0[1].reset();
/* 077 */
/* 078 */         filter_mutableStateArray_0[1].zeroOutNullBytes();
/* 079 */
/* 080 */         if (project_isNull_0) {
/* 081 */           filter_mutableStateArray_0[1].setNullAt(0);
/* 082 */         } else {
/* 083 */           filter_mutableStateArray_0[1].write(0, project_value_0);
/* 084 */         }
/* 085 */         append((filter_mutableStateArray_0[1].getRow()));
/* 086 */
/* 087 */       } while(false);
/* 088 */       if (shouldStop()) return;
/* 089 */     }
/* 090 */   }
/* 091 */
/* 092 */ }
```
Why are the changes needed?

Fix an NPE bug in `FilterExec`.

Does this PR introduce any user-facing change?

No.

How was this patch tested?

A new unit test.
Closes #25902 from wangshuo128/filter-codegen-npe.
Authored-by: Wang Shuo <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit bd28e8e)
Signed-off-by: Wenchen Fan <[email protected]>
    