From a5594f7ffcbdc9ab2e83008a99d5878fa9fae2b8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 2 Dec 2016 14:41:17 +0000 Subject: [PATCH] Change GenerateExec's output so PySpark's UDF can work with Generator. --- python/pyspark/sql/tests.py | 20 +++++++++++++++++++ .../plans/logical/basicLogicalOperators.scala | 12 +++++------ .../spark/sql/execution/GenerateExec.scala | 15 +++++++++++--- .../spark/sql/execution/SparkStrategies.scala | 3 ++- 4 files changed, 40 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b7b2a5923c07..de5555d0061c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -384,6 +384,26 @@ def test_udf_in_generate(self): row = df.select(explode(f(*df))).groupBy().sum().first() self.assertEqual(row[0], 10) + df = self.spark.range(3) + res = df.select("id", explode(f(df.id))).collect() + self.assertEqual(res[0][0], 1) + self.assertEqual(res[0][1], 0) + self.assertEqual(res[1][0], 2) + self.assertEqual(res[1][1], 0) + self.assertEqual(res[2][0], 2) + self.assertEqual(res[2][1], 1) + + range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType())) + res = df.select("id", explode(range_udf(df.id))).collect() + self.assertEqual(res[0][0], 0) + self.assertEqual(res[0][1], -1) + self.assertEqual(res[1][0], 0) + self.assertEqual(res[1][1], 0) + self.assertEqual(res[2][0], 1) + self.assertEqual(res[2][1], 0) + self.assertEqual(res[3][0], 1) + self.assertEqual(res[3][1], 1) + def test_udf_with_order_by_and_limit(self): from pyspark.sql.functions import udf my_copy = udf(lambda x: x, IntegerType()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 7aaefc8529a5..324662e5bda8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -93,13 +93,13 @@ case class Generate( override def producedAttributes: AttributeSet = AttributeSet(generatorOutput) - def output: Seq[Attribute] = { - val qualified = qualifier.map(q => - // prepend the new qualifier to the existed one - generatorOutput.map(a => a.withQualifier(Some(q))) - ).getOrElse(generatorOutput) + val qualifiedGeneratorOutput: Seq[Attribute] = qualifier.map { q => + // prepend the new qualifier to the existed one + generatorOutput.map(a => a.withQualifier(Some(q))) + }.getOrElse(generatorOutput) - if (join) child.output ++ qualified else qualified + def output: Seq[Attribute] = { + if (join) child.output ++ qualifiedGeneratorOutput else qualifiedGeneratorOutput } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f80214af43fc..04b16af4ea26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -51,17 +51,26 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In * it. * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. `outer` has no effect when `join` is false. - * @param output the output attributes of this node, which constructed in analysis phase, - * and we can not change it, as the parent node bound with it already. + * @param generatorOutput the qualified output attributes of the generator of this node, which + * constructed in analysis phase, and we can not change it, as the + * parent node bound with it already. */ case class GenerateExec( generator: Generator, join: Boolean, outer: Boolean, - output: Seq[Attribute], + generatorOutput: Seq[Attribute], child: SparkPlan) extends UnaryExecNode with CodegenSupport { + override def output: Seq[Attribute] = { + if (join) { + child.output ++ generatorOutput + } else { + generatorOutput + } + } + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2308ae8a6c61..d88cbdfbcfa0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -403,7 +403,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.UnionExec(unionChildren.map(planLater)) :: Nil case g @ logical.Generate(generator, join, outer, _, _, child) => execution.GenerateExec( - generator, join = join, outer = outer, g.output, planLater(child)) :: Nil + generator, join = join, outer = outer, g.qualifiedGeneratorOutput, + planLater(child)) :: Nil case logical.OneRowRelation => execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil case r: logical.Range =>