Skip to content

Commit fecd23d

Browse files
viiryahvanhovell
authored andcommitted
[SPARK-18634][PYSPARK][SQL] Corruption and Correctness issues with exploding Python UDFs
## What changes were proposed in this pull request? As reported in the Jira, there are some weird issues with exploding Python UDFs in SparkSQL. The following test code can reproduce it. Notice: the following test code is reported to return wrong results in the Jira. However, as I tested on master branch, it causes exception and so can't return any result. >>> from pyspark.sql.functions import * >>> from pyspark.sql.types import * >>> >>> df = spark.range(10) >>> >>> def return_range(value): ... return [(i, str(i)) for i in range(value - 1, value + 1)] ... >>> range_udf = udf(return_range, ArrayType(StructType([StructField("integer_val", IntegerType()), ... StructField("string_val", StringType())]))) >>> >>> df.select("id", explode(range_udf(df.id))).show() Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/spark/python/pyspark/sql/dataframe.py", line 318, in show print(self._jdf.showString(n, 20)) File "/spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py", line 1133, in __call__ File "/spark/python/pyspark/sql/utils.py", line 63, in deco return f(*a, **kw) File "/spark/python/lib/py4j-0.10.4-src.zip/py4j/protocol.py", line 319, in get_return_value py4j.protocol.Py4JJavaError: An error occurred while calling o126.showString.: java.lang.AssertionError: assertion failed at scala.Predef$.assert(Predef.scala:156) at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:120) at org.apache.spark.sql.execution.GenerateExec.consume(GenerateExec.scala:57) The cause of this issue is, in `ExtractPythonUDFs` we insert `BatchEvalPythonExec` to run PythonUDFs in batch. `BatchEvalPythonExec` will add extra outputs (e.g., `pythonUDF0`) to original plan. In above case, the original `Range` only has one output `id`. After `ExtractPythonUDFs`, the added `BatchEvalPythonExec` has two outputs `id` and `pythonUDF0`. Because the output of `GenerateExec` is given after analysis phase, in above case, it is the combination of `id`, i.e., the output of `Range`, and `col`. But in planning phase, we change `GenerateExec`'s child plan to `BatchEvalPythonExec` with additional output attributes. It will cause no problem in non wholestage codegen. Because when evaluating the additional attributes are projected out the final output of `GenerateExec`. However, as `GenerateExec` now supports wholestage codegen, the framework will input all the outputs of the child plan to `GenerateExec`. Then when consuming `GenerateExec`'s output data (i.e., calling `consume`), the number of output attributes is different to the output variables in wholestage codegen. To solve this issue, this patch only gives the generator's output to `GenerateExec` after analysis phase. `GenerateExec`'s output is the combination of its child plan's output and the generator's output. So when we change `GenerateExec`'s child, its output is still correct. ## How was this patch tested? Added test cases to PySpark. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh <[email protected]> Closes #16120 from viirya/fix-py-udf-with-generator. (cherry picked from commit 3ba69b6) Signed-off-by: Herman van Hovell <[email protected]>
1 parent c6a4e3d commit fecd23d

File tree

4 files changed

+40
-10
lines changed

4 files changed

+40
-10
lines changed

python/pyspark/sql/tests.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,26 @@ def test_udf_in_generate(self):
384384
row = df.select(explode(f(*df))).groupBy().sum().first()
385385
self.assertEqual(row[0], 10)
386386

387+
df = self.spark.range(3)
388+
res = df.select("id", explode(f(df.id))).collect()
389+
self.assertEqual(res[0][0], 1)
390+
self.assertEqual(res[0][1], 0)
391+
self.assertEqual(res[1][0], 2)
392+
self.assertEqual(res[1][1], 0)
393+
self.assertEqual(res[2][0], 2)
394+
self.assertEqual(res[2][1], 1)
395+
396+
range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType()))
397+
res = df.select("id", explode(range_udf(df.id))).collect()
398+
self.assertEqual(res[0][0], 0)
399+
self.assertEqual(res[0][1], -1)
400+
self.assertEqual(res[1][0], 0)
401+
self.assertEqual(res[1][1], 0)
402+
self.assertEqual(res[2][0], 1)
403+
self.assertEqual(res[2][1], 0)
404+
self.assertEqual(res[3][0], 1)
405+
self.assertEqual(res[3][1], 1)
406+
387407
def test_udf_with_order_by_and_limit(self):
388408
from pyspark.sql.functions import udf
389409
my_copy = udf(lambda x: x, IntegerType())

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,13 @@ case class Generate(
9494

9595
override def producedAttributes: AttributeSet = AttributeSet(generatorOutput)
9696

97-
def output: Seq[Attribute] = {
98-
val qualified = qualifier.map(q =>
99-
// prepend the new qualifier to the existed one
100-
generatorOutput.map(a => a.withQualifier(Some(q)))
101-
).getOrElse(generatorOutput)
97+
val qualifiedGeneratorOutput: Seq[Attribute] = qualifier.map { q =>
98+
// prepend the new qualifier to the existed one
99+
generatorOutput.map(a => a.withQualifier(Some(q)))
100+
}.getOrElse(generatorOutput)
102101

103-
if (join) child.output ++ qualified else qualified
102+
def output: Seq[Attribute] = {
103+
if (join) child.output ++ qualifiedGeneratorOutput else qualifiedGeneratorOutput
104104
}
105105
}
106106

sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,26 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In
4545
* it.
4646
* @param outer when true, each input row will be output at least once, even if the output of the
4747
* given `generator` is empty. `outer` has no effect when `join` is false.
48-
* @param output the output attributes of this node, which constructed in analysis phase,
49-
* and we can not change it, as the parent node bound with it already.
48+
* @param generatorOutput the qualified output attributes of the generator of this node, which
49+
* constructed in analysis phase, and we can not change it, as the
50+
* parent node bound with it already.
5051
*/
5152
case class GenerateExec(
5253
generator: Generator,
5354
join: Boolean,
5455
outer: Boolean,
55-
output: Seq[Attribute],
56+
generatorOutput: Seq[Attribute],
5657
child: SparkPlan)
5758
extends UnaryExecNode {
5859

60+
override def output: Seq[Attribute] = {
61+
if (join) {
62+
child.output ++ generatorOutput
63+
} else {
64+
generatorOutput
65+
}
66+
}
67+
5968
override lazy val metrics = Map(
6069
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
6170

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
403403
execution.UnionExec(unionChildren.map(planLater)) :: Nil
404404
case g @ logical.Generate(generator, join, outer, _, _, child) =>
405405
execution.GenerateExec(
406-
generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
406+
generator, join = join, outer = outer, g.qualifiedGeneratorOutput,
407+
planLater(child)) :: Nil
407408
case logical.OneRowRelation =>
408409
execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
409410
case r: logical.Range =>

0 commit comments

Comments
 (0)