Skip to content

Commit f9c38cf

Browse files
committed
take nullable into consideration
1 parent 9c66274 commit f9c38cf

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,9 +1083,12 @@ class Analyzer(
10831083
assert(parameterTypes.length == inputs.length)
10841084

10851085
val inputsNullCheck = parameterTypes.zip(inputs)
1086-
.filter(_._1.isPrimitive)
1087-
.map(i => IsNull(i._2))
1088-
.reduceLeftOption[Expression]((i1, i2) => Or(i1, i2))
1086+
// TODO: skip null handling for not-nullable primitive inputs after we can completely
1087+
// trust the `nullable` information.
1088+
// .filter { case (cls, expr) => cls.isPrimitive && expr.nullable }
1089+
.filter { case (cls, _) => cls.isPrimitive }
1090+
.map { case (_, expr) => IsNull(expr) }
1091+
.reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
10891092
inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf)
10901093
}
10911094
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,19 +188,34 @@ class AnalysisSuite extends AnalysisTest {
188188
)
189189
}
190190

191+
// non-primitive parameters do not need special null handling
191192
val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil)
192193
val expected1 = udf1
193194
checkUDF(udf1, expected1)
194195

196+
// only primitive parameter needs special null handling
195197
val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil)
196198
val expected2 = If(IsNull(double), nullResult, udf2)
197199
checkUDF(udf2, expected2)
198200

201+
// special null handling should apply to all primitive parameters
199202
val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil)
200203
val expected3 = If(
201204
IsNull(short) || IsNull(double),
202205
nullResult,
203206
udf3)
204207
checkUDF(udf3, expected3)
208+
209+
// we can skip special null handling for primitive parameters that are not nullable
210+
// TODO: this is disabled for now as we can not completely trust `nullable`.
211+
val udf4 = ScalaUDF(
212+
(s: Short, d: Double) => "x",
213+
StringType,
214+
short :: double.withNullability(false) :: Nil)
215+
val expected4 = If(
216+
IsNull(short),
217+
nullResult,
218+
udf4)
219+
// checkUDF(udf4, expected4)
205220
}
206221
}

0 commit comments

Comments
 (0)