From 8b0ef1ebed2091ceacd1abb823fba97d6daa984e Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 2 Dec 2016 05:50:47 +0000
Subject: [PATCH] Change the way rows are grouped in BatchEvalPythonExec so
 UDFs work with input_file_name in PySpark.

---
 python/pyspark/sql/tests.py          |  8 +++++
 .../python/BatchEvalPythonExec.scala | 35 +++++++++----------
 2 files changed, 24 insertions(+), 19 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b7b2a5923c07..55ba6a9cb9ec 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -392,6 +392,14 @@ def test_udf_with_order_by_and_limit(self):
         res.explain(True)
         self.assertEqual(res.collect(), [Row(id=0, copy=0)])
 
+    def test_udf_with_input_file_name(self):
+        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.types import StringType
+        sourceFile = udf(lambda path: path, StringType())
+        filePath = "python/test_support/sql/people1.json"
+        row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
+        self.assertTrue(row[0].find("people1.json") != -1)
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index dcaf2c76d479..7a5ac48f1b69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -119,26 +119,23 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
     val pickle = new Pickler(needConversion)
     // Input iterator to Python: input rows are grouped so we send them in batches to Python.
     // For each row, add it to the queue.
-    val inputIterator = iter.grouped(100).map { inputRows =>
-      val toBePickled = inputRows.map { inputRow =>
-        queue.add(inputRow.asInstanceOf[UnsafeRow])
-        val row = projection(inputRow)
-        if (needConversion) {
-          EvaluatePython.toJava(row, schema)
-        } else {
-          // fast path for these types that does not need conversion in Python
-          val fields = new Array[Any](row.numFields)
-          var i = 0
-          while (i < row.numFields) {
-            val dt = dataTypes(i)
-            fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
-            i += 1
-          }
-          fields
+    val inputIterator = iter.map { inputRow =>
+      queue.add(inputRow.asInstanceOf[UnsafeRow])
+      val row = projection(inputRow)
+      if (needConversion) {
+        EvaluatePython.toJava(row, schema)
+      } else {
+        // fast path for these types that does not need conversion in Python
+        val fields = new Array[Any](row.numFields)
+        var i = 0
+        while (i < row.numFields) {
+          val dt = dataTypes(i)
+          fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+          i += 1
         }
-      }.toArray
-      pickle.dumps(toBePickled)
-    }
+        fields
+      }
+    }.grouped(100).map(x => pickle.dumps(x.toArray))
 
     val context = TaskContext.get()
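
Notes (editor's sketch, not part of the patch): input_file_name() reads a task-local holder that Spark's file scan updates as each row is pulled from the input iterator. With the old ordering, iter.grouped(100) buffered a batch of raw rows before the projection ran, so by the time the batch was projected the holder no longer referred to each row's own file (the reported symptom was a wrong or empty file name inside the UDF). Moving grouped(100) after the per-row map makes the projection run at the moment each row is read, while the holder still points at that row's file. The standalone Scala sketch below illustrates the ordering difference; currentFile is a plain mutable variable standing in for Spark's holder, and every name in it is hypothetical.

// Standalone sketch, not Spark code: shows why the order of map() and
// grouped() on an iterator matters when the map reads mutable state that
// the source updates per element.
object GroupedVsMapped {
  // Stand-in for Spark's input-file holder, updated as each row is read.
  var currentFile: String = ""

  // Two "files" of three rows each; pulling a row records its source file,
  // the way Spark's file scan updates the holder before handing out a row.
  def sourceRows(): Iterator[Int] = {
    val rows = Seq(
      ("a.json", 1), ("a.json", 2), ("a.json", 3),
      ("b.json", 4), ("b.json", 5), ("b.json", 6))
    rows.iterator.map { case (file, row) => currentFile = file; row }
  }

  def main(args: Array[String]): Unit = {
    // Old order: buffer a whole batch, then evaluate the projection. By the
    // time the projection runs, the holder points at the last buffered
    // row's file, so every row in the batch reports "b.json".
    val buggy = sourceRows().grouped(6)
      .flatMap(batch => batch.map(row => (row, currentFile))).toList
    println(buggy)  // all six rows paired with "b.json"

    // New order: evaluate the projection as each row is pulled, then batch
    // the already-projected values. Each row sees its own file.
    val fixed = sourceRows().map(row => (row, currentFile)).grouped(6)
      .flatMap(identity).toList
    println(fixed)  // rows 1-3 paired with "a.json", rows 4-6 with "b.json"
  }
}

Under these assumptions the fix is purely a reordering of iterator operations: the rows sent to Python and the batch size of 100 are unchanged; only the point at which the projection observes the task-local state moves from batch time to row-read time.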