From 176cd150fdcba20433cf8b1537957b4b00f2c4b1 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 28 Sep 2017 11:39:29 -0400 Subject: [PATCH 01/34] Initial commit of groupby apply --- python/pyspark/sql/functions.py | 3 +- python/pyspark/sql/group.py | 22 ++ python/pyspark/sql/tests.py | 211 +++++++++++++++++- python/pyspark/worker.py | 38 +++- .../catalyst/expressions/AttributeSet.scala | 3 + .../sql/catalyst/optimizer/Optimizer.scala | 2 + .../sql/catalyst/plans/logical/object.scala | 19 +- .../spark/sql/RelationalGroupedDataset.scala | 33 ++- .../spark/sql/execution/SparkStrategies.scala | 2 + .../python/ArrowEvalPythonExec.scala | 9 +- .../execution/python/ArrowPythonRunner.scala | 15 +- .../execution/python/ExtractPythonUDFs.scala | 5 +- .../python/FlatMapGroupsInPandasExec.scala | 94 ++++++++ 13 files changed, 420 insertions(+), 36 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b45a59db9367..4bc7b5fbce48 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2129,7 +2129,8 @@ def _create_udf(f, returnType, vectorized): def _udf(f, returnType=StringType(), vectorized=vectorized): if vectorized: import inspect - if len(inspect.getargspec(f).args) == 0: + argspec = inspect.getargspec(f) + if len(argspec.args) == 0 and argspec.varargs is None: raise NotImplementedError("0-parameter pandas_udfs are not currently supported") udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) return udf_obj._wrapped() diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index f2092f9c6305..35893a109ded 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -194,6 +194,28 @@ def pivot(self, pivot_col, values=None): jgd = self._jgd.pivot(pivot_col, values) return GroupedData(jgd, self.sql_ctx) + def apply(self, udf_obj): + """ + Maps each group of the current [[DataFrame]] using a pandas udf and returns the result as a :class:`DataFrame`. 
+ + """ + from pyspark.sql.functions import pandas_udf + + df = DataFrame(self._jgd.df(), self.sql_ctx) + func = udf_obj.func + returnType = udf_obj.returnType + + # The python executors expects the function to take a list of pd.Series as input + # So we to create a wrapper function that turns that to a pd.DataFrame before passing down to the user function + columns = df.columns + def wrapped(*cols): + import pandas as pd + return func(pd.concat(cols, axis=1, keys=columns)) + + wrapped_udf_obj = pandas_udf(wrapped, returnType) + udf_column = wrapped_udf_obj(*[df[col] for col in df.columns]) + jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) + return DataFrame(jdf, self.sql_ctx) def _test(): import doctest diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1b3af42c47ad..64ed0fad522a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3106,10 +3106,11 @@ def assertFramesEqual(self, df_with_arrow, df_without): self.assertTrue(df_without.equals(df_with_arrow), msg=msg) def test_unsupported_datatype(self): - schema = StructType([StructField("dt", DateType(), True)]) - df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: df.toPandas()) + schema = StructType([StructField("dt", TimestampType(), True)]) + df = self.spark.createDataFrame([(datetime.datetime(1970, 1, 1),)], schema=schema) + + #with QuietTest(self.sc): + # self.assertRaises(Exception, lambda: df.toPandas()) def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + @@ -3147,6 +3148,180 @@ def test_filtered_frame(self): self.assertEqual(pdf.columns[0], "i") self.assertTrue(pdf.empty) + def test_groupby_apply(self): + from pyspark.sql.functions import col, udf, pandas_udf, sum + from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType + df1 = self.spark.createDataFrame(self.data, schema=self.schema) + import pandas as pd + + expected = (df1.withColumn('4_long_t', df1['2_int_t'] * df1['3_long_t']) + .select('1_str_t', '2_int_t', '3_long_t', '4_long_t') + .toPandas()) + + result_schema = StructType([ + StructField('1_str_t', StringType()), + StructField('2_int_t', IntegerType()), + StructField('3_long_t', LongType()), + StructField('4_long_t', LongType()) + ]) + + @pandas_udf(result_schema) + def foo(pdf): + pdf['4_long_t'] = pdf['2_int_t'] * pdf['3_long_t'] * 1.0 + return pdf + + result = (df1.groupby('1_str_t') + .apply(foo(df1[['1_str_t', '2_int_t', '3_long_t']])) + .sort('1_str_t') + .toPandas()) + + def foo2(pdf): + pdf['4_long_t'] = pdf['2_int_t'] * pdf['3_long_t'] + return pdf[['1_str_t', '2_int_t', '3_long_t', '4_long_t']] + + foo2_udf = pandas_udf(foo2, result_schema) + result2 = (df1.groupby('1_str_t') + .apply(foo2_udf) + .sort('1_str_t') + .toPandas()) + + @pandas_udf(add=('4_long_t', LongType())) + def foo3(pdf): + pdf['4_long_t'] = pdf['2_int_t'] * pdf['3_long_t'] + return pdf + result3 = (df1.groupby('1_str_t') + .apply(foo3) + .sort('1_str_t') + .select('1_str_t', '2_int_t', '3_long_t', '4_long_t') + .toPandas()) + + #@pandas_udf(add=[('4_long_t', LongType())]) + #def foo4(pdf): + # pdf['4_long_t'] = pdf['2_int_t'] * pdf['3_long_t'] * 1.0 + # return pdf + #result4 = (df1.groupby('1_str_t') + # .apply(foo4(df1[['1_str_t', '2_int_t', '3_long_t']])) + # .sort('1_str_t') + # .toPandas()) + + @pandas_udf([('1_str_t', StringType()), + ('2_int_t', IntegerType()), + ('3_long_t', 
LongType()), + ('4_long_t', LongType())]) + def foo5(pdf): + pdf['4_long_t'] = pdf['2_int_t'] * pdf['3_long_t'] + return pdf + result5 = (df1.select('1_str_t', '2_int_t', '3_long_t') + .groupby('1_str_t') + .apply(foo5) + .sort('1_str_t') + .toPandas()) + + def foo6(pdf): + pdf['4_long_t'] = pdf['2_int_t'] * pdf['3_long_t'] + return pdf + + foo6_udf = pandas_udf( + foo6, + [('1_str_t', StringType()), + ('2_int_t', IntegerType()), + ('3_long_t', LongType()), + ('4_long_t', LongType()) + ]) + + result6 = (df1.select('1_str_t', '2_int_t', '3_long_t') + .groupby('1_str_t') + .apply(foo6_udf) + .sort('1_str_t') + .toPandas()) + + def foo7(pdf): + pdf['4_long_t'] = pdf['2_int_t'] * pdf['3_long_t'] + return pdf + + foo7_udf = pandas_udf(foo7, expected.dtypes) + + result7 = (df1.select('1_str_t', '2_int_t', '3_long_t') + .groupby('1_str_t') + .apply(foo7_udf) + .sort('1_str_t') + .toPandas()) + + self.assertFramesEqual(result, expected) + self.assertFramesEqual(result2, expected) + self.assertFramesEqual(result3, expected) + #self.assertFramesEqual(result4, expected) + self.assertFramesEqual(result5, expected) + self.assertFramesEqual(result6, expected) + self.assertFramesEqual(result7, expected) + + def test_groupby_apply_timestamp(self): + from pyspark.sql.functions import col, udf, pandas_udf + from pyspark.sql.types import TimestampType, DateType + import datetime + + df = self.spark.createDataFrame(self.data, schema=self.schema) + df1 = df.select('3_long_t').withColumn('time', df['3_long_t'].cast(TimestampType())) + pdf1 = df1.toPandas() + + def foo(pdf): + return pdf.assign(time=pdf['time'] + datetime.timedelta(days=15)) + foo_udf = pandas_udf(foo, foo(pdf1).dtypes) + + result = df1.groupby('3_long_t').apply(foo_udf).sort('3_long_t').toPandas() + expected = foo(pdf1) + + self.assertFramesEqual(result, expected) + + def test_groupby_apply_series(self): + from pyspark.sql.functions import col, udf, pandas_udf + from pyspark.sql.types import DoubleType + import pandas as pd + + df = self.spark.createDataFrame(self.data, schema=self.schema) + + expected = pd.DataFrame({ + '1_str_t': pd.Series(['a', 'b', 'c']), + 'v1': pd.Series([0.0, 0.0, 0.0]), + 'v2': pd.Series([1.0, 1.0, 1.0]), + }) + + def foo1(pdf): + return pd.Series([0.0, 1.0]) + + foo1_udf = pandas_udf(foo1, group_add=[('v1', DoubleType()), ('v2', DoubleType())]) + result1 = df.groupby('1_str_t').apply(foo1_udf).sort('1_str_t').toPandas() + self.assertFramesEqual(result1, expected) + + def test_groupby_apply_cache(self): + from pyspark.sql.functions import col, udf, pandas_udf + from pyspark.sql.types import DoubleType + import pandas as pd + + df = self.spark.createDataFrame(self.data, schema=self.schema) + + def foo1(pdf): + return pd.Series([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]) + + foo1_udf = pandas_udf(foo1, group_add=[ + ('v1', DoubleType()), + ('v2', DoubleType()), + ('v3', DoubleType()), + ('v4', DoubleType()), + ('v5', DoubleType()), + ('v6', DoubleType()), + ('v7', DoubleType()), + ('v8', DoubleType()), + ('v9', DoubleType()), + ('v10', DoubleType()), + ('v11', DoubleType()), + ('v12', DoubleType()), + ('v13', DoubleType()), + ]) + result1 = df.groupby('1_str_t').apply(foo1_udf).sort('1_str_t').cache() + result1.count() + result1.show() + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class VectorizedUDFTests(ReusedPySparkTestCase): @@ -3376,6 +3551,34 @@ def test_vectorized_udf_empty_partition(self): res = df.select(f(col('id'))) 
self.assertEquals(df.collect(), res.collect()) +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class GroupbyApplyTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() + + def test_groupby_apply(self): + from pyspark.sql.functions import pandas_udf, array, explode, col, lit + df = self.spark.range(10).toDF('id').withColumn("vs", array([lit(i) for i in range(20, 30)])).withColumn("v", explode(col('vs'))).drop('vs') + + def foo(df): + import pandas as pd + return pd.DataFrame({'mean': [df.v.mean()], 'std': [df.v.std()]}) + + foo_udf = pandas_udf( + foo, + StructType([StructField('mean', DoubleType()), StructField('std', DoubleType())])) + + df2 = df.groupby('id').apply(foo_udf) + df2.show(1000) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4e24789cf010..6e89f1da5910 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -34,6 +34,7 @@ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import toArrowType from pyspark import shuffle +from pyspark.sql.types import StructType, IntegerType, LongType, FloatType, DoubleType pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -74,17 +75,32 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): - arrow_return_type = toArrowType(return_type) - - def verify_result_length(*a): - result = f(*a) - if not hasattr(result, "__len__"): - raise TypeError("Return type of pandas_udf should be a Pandas.Series") - if len(result) != len(a[0]): - raise RuntimeError("Result vector from pandas_udf was not the required length: " - "expected %d, got %d" % (len(a[0]), len(result))) - return result - return lambda *a: (verify_result_length(*a), arrow_return_type) + if isinstance(return_type, StructType): + arrow_return_types = list(toArrowType(field.dataType) for field in return_type) + + def fn(*a): + import pandas as pd + out = f(*a) + assert isinstance(out, pd.DataFrame), 'Must return a pd.DataFrame' + assert len(out.columns) == len(arrow_return_types), \ + 'Columns of pd.DataFrame don\'t match return schema' + + return list((out[out.columns[i]], arrow_return_types[i]) for i in range(len(arrow_return_types))) + return fn + + else: + arrow_return_type = toArrowType(return_type) + + def verify_result_length(*a): + result = f(*a) + if not hasattr(result, "__len__"): + raise TypeError("Return type of pandas_udf should be a Pandas.Series") + if len(result) != len(a[0]): + raise RuntimeError("Result vector from pandas_udf was not the required length: " \ + "expected %d, got %d" % (len(a[0]), len(result))) + return result + + return lambda *a: (verify_result_length(*a), arrow_return_type) def read_single_udf(pickleSer, infile, eval_type): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 7420b6b57d8e..d37fdc0d8354 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -37,6 +37,9 @@ object AttributeSet { /** Constructs a new [[AttributeSet]] that contains a single [[Attribute]]. 
*/ def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a))) + def apply(as: Attribute*): AttributeSet = + new AttributeSet(Set(as.map(new AttributeEquals(_)): _*)) + /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ def apply(baseSet: Iterable[Expression]): AttributeSet = { new AttributeSet( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a391c513ad38..ced0820e179c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -444,6 +444,8 @@ object ColumnPruning extends Rule[LogicalPlan] { // Prunes the unused columns from child of Aggregate/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = prunedChild(child, a.references)) + case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty => + f.copy(child = prunedChild(child, f.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => e.copy(child = prunedChild(child, e.references)) case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index bfb70c2ef4c8..52721eb341c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -24,9 +24,9 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.{Encoder, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{AttributeSet, _} import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode } +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -519,3 +519,18 @@ case class CoGroup( outputObjAttr: Attribute, left: LogicalPlan, right: LogicalPlan) extends BinaryNode with ObjectProducer + +case class FlatMapGroupsInPandas( + groupingExprs: Seq[Expression], + functionExpr: Expression, + override val output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + /** + * This is needed because output attributes is considered `reference` when + * passed through the constructor. + * + * Without this, catalyst will complain that output attributes are missing + * from the input. 
+ */ + override val producedAttributes = AttributeSet(output) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 147b54996491..a563ce8f0ebb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -27,12 +27,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR, Pivot} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.NumericType -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{NumericType, StructField, StructType} /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -47,8 +47,8 @@ import org.apache.spark.sql.types.StructType */ @InterfaceStability.Stable class RelationalGroupedDataset protected[sql]( - df: DataFrame, - groupingExprs: Seq[Expression], + val df: DataFrame, + val groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) { private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { @@ -435,6 +435,29 @@ class RelationalGroupedDataset protected[sql]( df.logicalPlan.output, df.logicalPlan)) } + + private[sql] def flatMapGroupsInPandas( + expr: PythonUDF + ): DataFrame = { + val output = expr.dataType match { + case s: StructType => s.map { + case StructField(name, dataType, nullable, metadata) => + AttributeReference(name, dataType, nullable, metadata)() + } + } + + val plan = FlatMapGroupsInPandas( + groupingExprs, + expr, + output, + df.logicalPlan + ) + + Dataset.ofRows( + df.sparkSession, + plan + ) + } } private[sql] object RelationalGroupedDataset { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 92eaab5cd8f8..4cdcc73faacd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -392,6 +392,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) => execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping, data, objAttr, planLater(child)) :: Nil + case logical.FlatMapGroupsInPandas(grouping, func, output, child) => + execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil case logical.MapElements(f, _, _, objAttr, child) => execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil case logical.AppendColumns(f, _, _, in, out, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index f7e8cbe41612..7bba0ff300a8 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -44,14 +44,17 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex .map { case (attr, i) => attr.withName(s"_$i") }) + val batchedIter: Iterator[Iterator[InternalRow]] = + iter.grouped(conf.arrowMaxRecordsPerBatch).map(_.iterator) + val columnarBatchIter = new ArrowPythonRunner( - funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker, + funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) - .compute(iter, context.partitionId(), context) + .compute(batchedIter, context.partitionId(), context) new Iterator[InternalRow] { - var currentIter = if (columnarBatchIter.hasNext) { + private var currentIter = if (columnarBatchIter.hasNext) { val batch = columnarBatchIter.next() assert(schemaOut.equals(batch.schema), s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index bbad9d6b631f..f6c03c415dc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -39,19 +39,18 @@ import org.apache.spark.util.Utils */ class ArrowPythonRunner( funcs: Seq[ChainedPythonFunctions], - batchSize: Int, bufferSize: Int, reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]], schema: StructType) - extends BasePythonRunner[InternalRow, ColumnarBatch]( + extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { protected override def newWriterThread( env: SparkEnv, worker: Socket, - inputIterator: Iterator[InternalRow], + inputIterator: Iterator[Iterator[InternalRow]], partitionIndex: Int, context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { @@ -82,12 +81,12 @@ class ArrowPythonRunner( Utils.tryWithSafeFinally { while (inputIterator.hasNext) { - var rowCount = 0 - while (inputIterator.hasNext && (batchSize <= 0 || rowCount < batchSize)) { - val row = inputIterator.next() - arrowWriter.write(row) - rowCount += 1 + val nextBatch = inputIterator.next() + + while (nextBatch.hasNext) { + arrowWriter.write(nextBatch.next()) } + arrowWriter.finish() writer.writeBatch() arrowWriter.reset() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index fec456d86dbe..b28b1efad044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution -import org.apache.spark.sql.execution.{FilterExec, SparkPlan} +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} /** @@ -111,7 +111,8 @@ 
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } def apply(plan: SparkPlan): SparkPlan = plan transformUp { - case plan: SparkPlan => extract(plan) + case plan: ProjectExec => extract(plan) + case plan: FilterExec => extract(plan) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala new file mode 100644 index 000000000000..ab6c966ef3f3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} + +case class FlatMapGroupsInPandasExec( + grouping: Seq[Expression], + func: Expression, + override val output: Seq[Attribute], + override val child: SparkPlan +) extends UnaryExecNode { + + val groupingAttributes: Seq[Attribute] = grouping.map { + case ne: NamedExpression => ne.toAttribute + } + + private val pandasFunction = func.asInstanceOf[PythonUDF].func + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingAttributes) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + val argOffsets = Array((0 until child.schema.length).toArray) + + inputRDD.mapPartitionsInternal { iter => + val grouped = GroupedIterator(iter, groupingAttributes, child.output) + val context = TaskContext.get() + + val columnarBatchIter = new ArrowPythonRunner( + chainedFunc, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_UDF, argOffsets, child.schema) + .compute(grouped.map(_._2), 
context.partitionId(), context) + + + new Iterator[InternalRow] { + private var currentIter = if (columnarBatchIter.hasNext) { + val batch = columnarBatchIter.next() + // assert(schemaOut.equals(batch.schema), + // s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") + batch.rowIterator.asScala + } else { + Iterator.empty + } + + override def hasNext: Boolean = currentIter.hasNext || { + if (columnarBatchIter.hasNext) { + currentIter = columnarBatchIter.next().rowIterator.asScala + hasNext + } else { + false + } + } + + override def next(): InternalRow = currentIter.next() + } + } + } +} From f109afb2bf43ac35e2af6811114b0141c5bb6a47 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 28 Sep 2017 13:50:16 -0400 Subject: [PATCH 02/34] Clean up tests --- python/pyspark/sql/tests.py | 199 ++---------------- .../python/FlatMapGroupsInPandasExec.scala | 7 +- 2 files changed, 23 insertions(+), 183 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 64ed0fad522a..ff0045238ffe 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3148,180 +3148,6 @@ def test_filtered_frame(self): self.assertEqual(pdf.columns[0], "i") self.assertTrue(pdf.empty) - def test_groupby_apply(self): - from pyspark.sql.functions import col, udf, pandas_udf, sum - from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType - df1 = self.spark.createDataFrame(self.data, schema=self.schema) - import pandas as pd - - expected = (df1.withColumn('4_long_t', df1['2_int_t'] * df1['3_long_t']) - .select('1_str_t', '2_int_t', '3_long_t', '4_long_t') - .toPandas()) - - result_schema = StructType([ - StructField('1_str_t', StringType()), - StructField('2_int_t', IntegerType()), - StructField('3_long_t', LongType()), - StructField('4_long_t', LongType()) - ]) - - @pandas_udf(result_schema) - def foo(pdf): - pdf['4_long_t'] = pdf['2_int_t'] * pdf['3_long_t'] * 1.0 - return pdf - - result = (df1.groupby('1_str_t') - .apply(foo(df1[['1_str_t', '2_int_t', '3_long_t']])) - .sort('1_str_t') - .toPandas()) - - def foo2(pdf): - pdf['4_long_t'] = pdf['2_int_t'] * pdf['3_long_t'] - return pdf[['1_str_t', '2_int_t', '3_long_t', '4_long_t']] - - foo2_udf = pandas_udf(foo2, result_schema) - result2 = (df1.groupby('1_str_t') - .apply(foo2_udf) - .sort('1_str_t') - .toPandas()) - - @pandas_udf(add=('4_long_t', LongType())) - def foo3(pdf): - pdf['4_long_t'] = pdf['2_int_t'] * pdf['3_long_t'] - return pdf - result3 = (df1.groupby('1_str_t') - .apply(foo3) - .sort('1_str_t') - .select('1_str_t', '2_int_t', '3_long_t', '4_long_t') - .toPandas()) - - #@pandas_udf(add=[('4_long_t', LongType())]) - #def foo4(pdf): - # pdf['4_long_t'] = pdf['2_int_t'] * pdf['3_long_t'] * 1.0 - # return pdf - #result4 = (df1.groupby('1_str_t') - # .apply(foo4(df1[['1_str_t', '2_int_t', '3_long_t']])) - # .sort('1_str_t') - # .toPandas()) - - @pandas_udf([('1_str_t', StringType()), - ('2_int_t', IntegerType()), - ('3_long_t', LongType()), - ('4_long_t', LongType())]) - def foo5(pdf): - pdf['4_long_t'] = pdf['2_int_t'] * pdf['3_long_t'] - return pdf - result5 = (df1.select('1_str_t', '2_int_t', '3_long_t') - .groupby('1_str_t') - .apply(foo5) - .sort('1_str_t') - .toPandas()) - - def foo6(pdf): - pdf['4_long_t'] = pdf['2_int_t'] * pdf['3_long_t'] - return pdf - - foo6_udf = pandas_udf( - foo6, - [('1_str_t', StringType()), - ('2_int_t', IntegerType()), - ('3_long_t', LongType()), - ('4_long_t', LongType()) - ]) - - result6 = (df1.select('1_str_t', '2_int_t', 
'3_long_t') - .groupby('1_str_t') - .apply(foo6_udf) - .sort('1_str_t') - .toPandas()) - - def foo7(pdf): - pdf['4_long_t'] = pdf['2_int_t'] * pdf['3_long_t'] - return pdf - - foo7_udf = pandas_udf(foo7, expected.dtypes) - - result7 = (df1.select('1_str_t', '2_int_t', '3_long_t') - .groupby('1_str_t') - .apply(foo7_udf) - .sort('1_str_t') - .toPandas()) - - self.assertFramesEqual(result, expected) - self.assertFramesEqual(result2, expected) - self.assertFramesEqual(result3, expected) - #self.assertFramesEqual(result4, expected) - self.assertFramesEqual(result5, expected) - self.assertFramesEqual(result6, expected) - self.assertFramesEqual(result7, expected) - - def test_groupby_apply_timestamp(self): - from pyspark.sql.functions import col, udf, pandas_udf - from pyspark.sql.types import TimestampType, DateType - import datetime - - df = self.spark.createDataFrame(self.data, schema=self.schema) - df1 = df.select('3_long_t').withColumn('time', df['3_long_t'].cast(TimestampType())) - pdf1 = df1.toPandas() - - def foo(pdf): - return pdf.assign(time=pdf['time'] + datetime.timedelta(days=15)) - foo_udf = pandas_udf(foo, foo(pdf1).dtypes) - - result = df1.groupby('3_long_t').apply(foo_udf).sort('3_long_t').toPandas() - expected = foo(pdf1) - - self.assertFramesEqual(result, expected) - - def test_groupby_apply_series(self): - from pyspark.sql.functions import col, udf, pandas_udf - from pyspark.sql.types import DoubleType - import pandas as pd - - df = self.spark.createDataFrame(self.data, schema=self.schema) - - expected = pd.DataFrame({ - '1_str_t': pd.Series(['a', 'b', 'c']), - 'v1': pd.Series([0.0, 0.0, 0.0]), - 'v2': pd.Series([1.0, 1.0, 1.0]), - }) - - def foo1(pdf): - return pd.Series([0.0, 1.0]) - - foo1_udf = pandas_udf(foo1, group_add=[('v1', DoubleType()), ('v2', DoubleType())]) - result1 = df.groupby('1_str_t').apply(foo1_udf).sort('1_str_t').toPandas() - self.assertFramesEqual(result1, expected) - - def test_groupby_apply_cache(self): - from pyspark.sql.functions import col, udf, pandas_udf - from pyspark.sql.types import DoubleType - import pandas as pd - - df = self.spark.createDataFrame(self.data, schema=self.schema) - - def foo1(pdf): - return pd.Series([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]) - - foo1_udf = pandas_udf(foo1, group_add=[ - ('v1', DoubleType()), - ('v2', DoubleType()), - ('v3', DoubleType()), - ('v4', DoubleType()), - ('v5', DoubleType()), - ('v6', DoubleType()), - ('v7', DoubleType()), - ('v8', DoubleType()), - ('v9', DoubleType()), - ('v10', DoubleType()), - ('v11', DoubleType()), - ('v12', DoubleType()), - ('v13', DoubleType()), - ]) - result1 = df.groupby('1_str_t').apply(foo1_udf).sort('1_str_t').cache() - result1.count() - result1.show() - @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class VectorizedUDFTests(ReusedPySparkTestCase): @@ -3563,20 +3389,33 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() cls.spark.stop() + def assertFramesEqual(self, expected, result): + msg = ("DataFrames are not equal: " + + ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) + + ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) + self.assertTrue(expected.equals(result), msg=msg) + def test_groupby_apply(self): from pyspark.sql.functions import pandas_udf, array, explode, col, lit df = self.spark.range(10).toDF('id').withColumn("vs", array([lit(i) for i in range(20, 30)])).withColumn("v", explode(col('vs'))).drop('vs') def foo(df): - import pandas as pd - return pd.DataFrame({'mean': 
[df.v.mean()], 'std': [df.v.std()]}) + ret = df + ret = ret.assign(v1=df.v * df.id * 1.0) + ret = ret.assign(v2=df.v + df.id) + return ret foo_udf = pandas_udf( foo, - StructType([StructField('mean', DoubleType()), StructField('std', DoubleType())])) - - df2 = df.groupby('id').apply(foo_udf) - df2.show(1000) + StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('v1', DoubleType()), + StructField('v2', LongType())])) + + result = df.groupby('id').apply(foo_udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo).reset_index(drop=True) + self.assertFramesEqual(expected, result) if __name__ == "__main__": diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index ab6c966ef3f3..64be86db97f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, NamedExpression, SortOrder, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} @@ -67,8 +67,7 @@ case class FlatMapGroupsInPandasExec( PythonEvalType.SQL_PANDAS_UDF, argOffsets, child.schema) .compute(grouped.map(_._2), context.partitionId(), context) - - new Iterator[InternalRow] { + val vectorRowIter = new Iterator[InternalRow] { private var currentIter = if (columnarBatchIter.hasNext) { val batch = columnarBatchIter.next() // assert(schemaOut.equals(batch.schema), @@ -89,6 +88,8 @@ case class FlatMapGroupsInPandasExec( override def next(): InternalRow = currentIter.next() } + + vectorRowIter.map(UnsafeProjection.create(output, output)) } } } From 07bcccaa73f6cd9b63bd776475f89356d60898fa Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 28 Sep 2017 14:37:42 -0400 Subject: [PATCH 03/34] Add support for dtypes as returnType --- python/pyspark/sql/functions.py | 6 +++++- python/pyspark/sql/tests.py | 31 +++++++++++++++++++++++++++++-- python/pyspark/sql/types.py | 27 ++++++++++++++++++++++++++- python/pyspark/worker.py | 6 +++--- 4 files changed, 63 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4bc7b5fbce48..cd4218c8921e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -28,7 +28,7 @@ from pyspark import since, SparkContext from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.sql.types import StringType, DataType, _parse_datatype_string +from pyspark.sql.types import StringType, DataType, _parse_datatype_string, from_pandas_dtypes from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame @@ -2207,6 +2207,10 @@ def pandas_udf(f=None, returnType=StringType()): | 8| 
JOHN DOE| 22| +----------+--------------+------------+ """ + import pandas as pd + if isinstance(returnType, pd.Series): + returnType = from_pandas_dtypes(returnType) + return _create_udf(f, returnType=returnType, vectorized=True) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ff0045238ffe..153abd903daa 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3395,9 +3395,16 @@ def assertFramesEqual(self, expected, result): ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) self.assertTrue(expected.equals(result), msg=msg) - def test_groupby_apply(self): + @property + def data(self): from pyspark.sql.functions import pandas_udf, array, explode, col, lit - df = self.spark.range(10).toDF('id').withColumn("vs", array([lit(i) for i in range(20, 30)])).withColumn("v", explode(col('vs'))).drop('vs') + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))).drop('vs') + + def test_groupby_apply_simple(self): + from pyspark.sql.functions import pandas_udf + df = self.data def foo(df): ret = df @@ -3417,6 +3424,26 @@ def foo(df): expected = df.toPandas().groupby('id').apply(foo).reset_index(drop=True) self.assertFramesEqual(expected, result) + def test_groupby_apply_dtypes(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + def foo(df): + ret = df + ret = ret.assign(v3=df.v * 5.0 + 1) + return ret + + sample_df = df.filter(df.id == 1).toPandas() + + foo_udf = pandas_udf( + foo, + foo(sample_df).dtypes + ) + + result = df.groupby('id').apply(foo_udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo).reset_index(drop=True) + self.assertFramesEqual(expected, result) + if __name__ == "__main__": from pyspark.sql.tests import * diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ebdc11c3b744..87b09a0d6924 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1597,7 +1597,7 @@ def convert(self, obj, gateway_client): register_input_converter(DateConverter()) -def toArrowType(dt): +def to_arrow_type(dt): """ Convert Spark data type to pyarrow type """ import pyarrow as pa @@ -1623,6 +1623,31 @@ def toArrowType(dt): raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) return arrow_type +def from_pandas_type(dt): + """ Convert pandas data type to Spark data type + """ + import pandas as pd + import numpy as np + if dt == np.int32: + return IntegerType() + elif dt == np.int64: + return LongType() + elif dt == np.float32: + return FloatType() + elif dt == np.float64: + return DoubleType() + elif dt == np.object: + return StringType() + elif dt == np.dtype('datetime64[ns]') or type(dt) == pd.api.types.DatetimeTZDtype: + return TimestampType() + else: + raise ValueError("Unsupported numpy type in conversion to Spark: {}".format(dt)) + +def from_pandas_dtypes(dtypes): + """ Convert pandas DataFrame dtypes to Spark schema + """ + return StructType([StructField(dtypes.axes[0][i], from_pandas_type(dtypes[i])) + for i in range(len(dtypes))]) def _test(): import doctest diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 6e89f1da5910..ee7d82210e41 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -32,7 +32,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer -from 
pyspark.sql.types import toArrowType +from pyspark.sql.types import to_arrow_type from pyspark import shuffle from pyspark.sql.types import StructType, IntegerType, LongType, FloatType, DoubleType @@ -76,7 +76,7 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): if isinstance(return_type, StructType): - arrow_return_types = list(toArrowType(field.dataType) for field in return_type) + arrow_return_types = list(to_arrow_type(field.dataType) for field in return_type) def fn(*a): import pandas as pd @@ -89,7 +89,7 @@ def fn(*a): return fn else: - arrow_return_type = toArrowType(return_type) + arrow_return_type = to_arrow_type(return_type) def verify_result_length(*a): result = f(*a) From e7a9b27cb1e3ca3c823a092166e34a2a02d510c1 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 28 Sep 2017 17:31:54 -0400 Subject: [PATCH 04/34] Fix pep8 sytle check --- python/pyspark/sql/group.py | 8 ++++++-- python/pyspark/sql/tests.py | 5 +++-- python/pyspark/sql/types.py | 2 ++ python/pyspark/worker.py | 5 +++-- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 35893a109ded..47427ca87725 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -196,7 +196,8 @@ def pivot(self, pivot_col, values=None): def apply(self, udf_obj): """ - Maps each group of the current [[DataFrame]] using a pandas udf and returns the result as a :class:`DataFrame`. + Maps each group of the current [[DataFrame]] using a pandas udf and returns the result + as a :class:`DataFrame`. """ from pyspark.sql.functions import pandas_udf @@ -206,8 +207,10 @@ def apply(self, udf_obj): returnType = udf_obj.returnType # The python executors expects the function to take a list of pd.Series as input - # So we to create a wrapper function that turns that to a pd.DataFrame before passing down to the user function + # So we to create a wrapper function that turns that to a pd.DataFrame before passing + # down to the user function columns = df.columns + def wrapped(*cols): import pandas as pd return func(pd.concat(cols, axis=1, keys=columns)) @@ -217,6 +220,7 @@ def wrapped(*cols): jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.sql_ctx) + def _test(): import doctest from pyspark.sql import Row, SparkSession diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 153abd903daa..db9c8b3c3697 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3109,8 +3109,8 @@ def test_unsupported_datatype(self): schema = StructType([StructField("dt", TimestampType(), True)]) df = self.spark.createDataFrame([(datetime.datetime(1970, 1, 1),)], schema=schema) - #with QuietTest(self.sc): - # self.assertRaises(Exception, lambda: df.toPandas()) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: df.toPandas()) def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + @@ -3377,6 +3377,7 @@ def test_vectorized_udf_empty_partition(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedPySparkTestCase): @classmethod diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 87b09a0d6924..200e98059a3e 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1623,6 +1623,7 @@ def to_arrow_type(dt): raise 
TypeError("Unsupported type in conversion to Arrow: " + str(dt)) return arrow_type + def from_pandas_type(dt): """ Convert pandas data type to Spark data type """ @@ -1643,6 +1644,7 @@ def from_pandas_type(dt): else: raise ValueError("Unsupported numpy type in conversion to Spark: {}".format(dt)) + def from_pandas_dtypes(dtypes): """ Convert pandas DataFrame dtypes to Spark schema """ diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index ee7d82210e41..3e11ef323f0f 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -85,7 +85,8 @@ def fn(*a): assert len(out.columns) == len(arrow_return_types), \ 'Columns of pd.DataFrame don\'t match return schema' - return list((out[out.columns[i]], arrow_return_types[i]) for i in range(len(arrow_return_types))) + return list((out[out.columns[i]], arrow_return_types[i]) + for i in range(len(arrow_return_types))) return fn else: @@ -96,7 +97,7 @@ def verify_result_length(*a): if not hasattr(result, "__len__"): raise TypeError("Return type of pandas_udf should be a Pandas.Series") if len(result) != len(a[0]): - raise RuntimeError("Result vector from pandas_udf was not the required length: " \ + raise RuntimeError("Result vector from pandas_udf was not the required length: " "expected %d, got %d" % (len(a[0]), len(result))) return result From 83b647e9504a921f1b4e98ec5f1f2eb2963a2411 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Fri, 29 Sep 2017 17:21:14 -0400 Subject: [PATCH 05/34] Address comments. Updated doc string for pandas_udf. --- python/pyspark/sql/functions.py | 80 ++++++++++++++----- python/pyspark/sql/group.py | 5 ++ .../catalyst/expressions/AttributeSet.scala | 3 - .../sql/catalyst/plans/logical/object.scala | 6 +- .../spark/sql/RelationalGroupedDataset.scala | 14 ++-- .../python/FlatMapGroupsInPandasExec.scala | 12 +-- 6 files changed, 80 insertions(+), 40 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index cd4218c8921e..cfb1c807f4da 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2120,6 +2120,7 @@ def wrapper(*args): else self.func.__class__.__module__) wrapper.func = self.func wrapper.returnType = self.returnType + wrapper._vectorized = self._vectorized return wrapper @@ -2131,7 +2132,10 @@ def _udf(f, returnType=StringType(), vectorized=vectorized): import inspect argspec = inspect.getargspec(f) if len(argspec.args) == 0 and argspec.varargs is None: - raise NotImplementedError("0-parameter pandas_udfs are not currently supported") + raise ValueError( + "0-arg pandas_udf are not supported. " + "Instead, create a 1-arg pandas_udf and ignore the arg in your function." + ) udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) return udf_obj._wrapped() @@ -2182,30 +2186,64 @@ def udf(f=None, returnType=StringType()): @since(2.3) def pandas_udf(f=None, returnType=StringType()): """ - Creates a :class:`Column` expression representing a user defined function (UDF) that accepts - `Pandas.Series` as input arguments and outputs a `Pandas.Series` of the same length. + Creates a :class:`Column` expression representing a vectorized user defined function (UDF). + + The user-defined function can define one of the following transformations: + 1. One or more `pandas.Series` -> A `pandas.Series` + + This udf is used with `DataFrame.withColumn` and `DataFrame.select`. 
+ The returnType should be a primitive data type, e.g., DoubleType() + + Example: + + >>> from pyspark.sql.types import IntegerType, StringType + >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) + >>> @pandas_udf(returnType=StringType()) + ... def to_upper(s): + ... return s.str.upper() + ... + >>> @pandas_udf(returnType="integer") + ... def add_one(x): + ... return x + 1 + ... + >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) + >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ + ... .show() # doctest: +SKIP + +----------+--------------+------------+ + |slen(name)|to_upper(name)|add_one(age)| + +----------+--------------+------------+ + | 8| JOHN DOE| 22| + +----------+--------------+------------+ + + 2. A `pandas.DataFrame` -> A `pandas.DataFrame` + + This udf is used with `GroupedData.apply` + The returnType should be a StructType describing the schema of the returned + `pandas.DataFrame`. + + Example: + + >>> df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 4.0)], ("id", "v")) + >>> @pandas_udf(returnType=df.schema) + ... def normalize(df): + ... v = df.v + ... ret = df.assign(v=(v - v.mean()) / v.std()) + >>> df.groupby('id').apply(normalize).show() # doctest: + SKIP + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.7071067811865475| + | 2| 0.7071067811865475| + +---+-------------------+ + + + .. note:: The user-defined functions must be deterministic. :param f: python function if used as a standalone function :param returnType: a :class:`pyspark.sql.types.DataType` object - >>> from pyspark.sql.types import IntegerType, StringType - >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) - >>> @pandas_udf(returnType=StringType()) - ... def to_upper(s): - ... return s.str.upper() - ... - >>> @pandas_udf(returnType="integer") - ... def add_one(x): - ... return x + 1 - ... - >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) - >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ - ... .show() # doctest: +SKIP - +----------+--------------+------------+ - |slen(name)|to_upper(name)|add_one(age)| - +----------+--------------+------------+ - | 8| JOHN DOE| 22| - +----------+--------------+------------+ """ import pandas as pd if isinstance(returnType, pd.Series): diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 47427ca87725..88c9ddacd84a 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -202,6 +202,11 @@ def apply(self, udf_obj): """ from pyspark.sql.functions import pandas_udf + if not udf_obj._vectorized: + raise ValueError("Must pass a pandas_udf") + if not isinstance(udf_obj.returnType, StructType): + raise ValueError("Must pass a StructType as return type in pandas_udf") + df = DataFrame(self._jgd.df(), self.sql_ctx) func = udf_obj.func returnType = udf_obj.returnType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index d37fdc0d8354..7420b6b57d8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -37,9 +37,6 @@ object AttributeSet { /** Constructs a new [[AttributeSet]] that contains a single [[Attribute]]. 
*/ def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a))) - def apply(as: Attribute*): AttributeSet = - new AttributeSet(Set(as.map(new AttributeEquals(_)): _*)) - /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ def apply(baseSet: Iterable[Expression]): AttributeSet = { new AttributeSet( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 52721eb341c9..e0d91309d342 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -24,7 +24,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.{Encoder, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ -import org.apache.spark.sql.catalyst.expressions.{AttributeSet, _} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} import org.apache.spark.sql.types._ @@ -521,9 +521,9 @@ case class CoGroup( right: LogicalPlan) extends BinaryNode with ObjectProducer case class FlatMapGroupsInPandas( - groupingExprs: Seq[Expression], + groupingAttributes: Seq[Attribute], functionExpr: Expression, - override val output: Seq[Attribute], + output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { /** * This is needed because output attributes is considered `reference` when diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index a563ce8f0ebb..caaa20415741 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.types.{NumericType, StructField, StructType} @InterfaceStability.Stable class RelationalGroupedDataset protected[sql]( val df: DataFrame, - val groupingExprs: Seq[Expression], + groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) { private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { @@ -436,9 +436,9 @@ class RelationalGroupedDataset protected[sql]( df.logicalPlan)) } - private[sql] def flatMapGroupsInPandas( - expr: PythonUDF - ): DataFrame = { + private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { + require(expr.vectorized, "Must pass a vectorized python udf") + val output = expr.dataType match { case s: StructType => s.map { case StructField(name, dataType, nullable, metadata) => @@ -446,8 +446,12 @@ class RelationalGroupedDataset protected[sql]( } } + val groupingAttributes: Seq[Attribute] = groupingExprs.map { + case ne: NamedExpression => ne.toAttribute + } + val plan = FlatMapGroupsInPandas( - groupingExprs, + groupingAttributes, expr, output, df.logicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 64be86db97f5..61ad1223b6a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -28,15 +28,11 @@ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Dist import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} case class FlatMapGroupsInPandasExec( - grouping: Seq[Expression], + groupingAttributes: Seq[Attribute], func: Expression, - override val output: Seq[Attribute], - override val child: SparkPlan -) extends UnaryExecNode { - - val groupingAttributes: Seq[Attribute] = grouping.map { - case ne: NamedExpression => ne.toAttribute - } + output: Seq[Attribute], + child: SparkPlan) + extends UnaryExecNode { private val pandasFunction = func.asInstanceOf[PythonUDF].func From 8d98b3e2d56c1b5ae9e7924a077793c4ec51bd76 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Fri, 29 Sep 2017 17:49:01 -0400 Subject: [PATCH 06/34] Replace iter.grouped with BatchIterator --- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/tests.py | 8 ++--- .../python/ArrowEvalPythonExec.scala | 35 +++++++++++++++++-- 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index cfb1c807f4da..5f4500eb5432 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2133,7 +2133,7 @@ def _udf(f, returnType=StringType(), vectorized=vectorized): argspec = inspect.getargspec(f) if len(argspec.args) == 0 and argspec.varargs is None: raise ValueError( - "0-arg pandas_udf are not supported. " + "0-arg pandas_udfs are not supported. " "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index db9c8b3c3697..2902f76a5545 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3257,17 +3257,17 @@ def test_vectorized_udf_null_string(self): def test_vectorized_udf_zero_parameter(self): from pyspark.sql.functions import pandas_udf - error_str = '0-parameter pandas_udfs.*not.*supported' + error_str = '0-arg pandas_udfs.*not.*supported' with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): pandas_udf(lambda: 1, LongType()) - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): @pandas_udf def zero_no_type(): return 1 - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): @pandas_udf(LongType()) def zero_with_type(): return 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 7bba0ff300a8..73bbbdcea80a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -26,6 +26,28 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.StructType +private object BatchIterator { + class InnerIterator[T](iter: Iterator[T], batchSize: Int) extends Iterator[T] { + var count = 0 + override def hasNext: Boolean = iter.hasNext && count < batchSize + + override def next(): T = { + count += 1 + iter.next() + } + } +} + +private class 
BatchIterator[T](iter: Iterator[T], batchSize: Int) + extends Iterator[Iterator[T]] { + + override def hasNext: Boolean = iter.hasNext + + override def next(): Iterator[T] = { + new BatchIterator.InnerIterator[T](iter, batchSize) + } +} + /** * A physical plan that evaluates a [[PythonUDF]], */ @@ -44,13 +66,20 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex .map { case (attr, i) => attr.withName(s"_$i") }) - val batchedIter: Iterator[Iterator[InternalRow]] = - iter.grouped(conf.arrowMaxRecordsPerBatch).map(_.iterator) + val batchSize = conf.arrowMaxRecordsPerBatch + + val batchIter = if (batchSize > 0) { + new BatchIterator(iter, batchSize) + } else if (batchSize == 0) { + Iterator(iter) + } else { + throw new IllegalArgumentException(s"MaxRecordsPerBatch must be >= 0, but is $batchSize") + } val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) - .compute(batchedIter, context.partitionId(), context) + .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { From 96ce587285dc73ab0cc134f97c06bcef8ff5fe4e Mon Sep 17 00:00:00 2001 From: Li Jin Date: Fri, 29 Sep 2017 17:58:35 -0400 Subject: [PATCH 07/34] [Minor] Fix pep8 --- python/pyspark/sql/types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 200e98059a3e..cb80139f1fd0 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1651,6 +1651,7 @@ def from_pandas_dtypes(dtypes): return StructType([StructField(dtypes.axes[0][i], from_pandas_type(dtypes[i])) for i in range(len(dtypes))]) + def _test(): import doctest from pyspark.context import SparkContext From 213dd1a4bfa1132946471a3c1411db52aa6e44cd Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 3 Oct 2017 16:20:43 -0400 Subject: [PATCH 08/34] Clean up code. Refine doc for pandas_udf() and apply(). Address comments. Add more tests. 
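For reference, the user-facing pattern documented by this patch can be sketched as a minimal, self-contained example (it assumes an active SparkSession bound to `spark` and mirrors the docstring example added in the hunks below; it is an illustration, not part of the patch):

    from pyspark.sql.functions import pandas_udf

    df = spark.createDataFrame(
        [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
        ("id", "v"))

    @pandas_udf(returnType=df.schema)
    def normalize(pdf):
        # each group arrives as one pandas.DataFrame; the returned frame must
        # match the declared StructType (here the input schema is reused)
        v = pdf.v
        return pdf.assign(v=(v - v.mean()) / v.std())

    df.groupby('id').apply(normalize).show()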
--- python/pyspark/sql/functions.py | 67 ++++++++------- python/pyspark/sql/group.py | 46 ++++++++-- python/pyspark/sql/tests.py | 84 ++++++++++++++++++- python/pyspark/worker.py | 10 ++- .../python/ArrowEvalPythonExec.scala | 27 +++--- .../python/FlatMapGroupsInPandasExec.scala | 8 +- 6 files changed, 174 insertions(+), 68 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5f4500eb5432..ea831da5756b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2058,7 +2058,7 @@ def __init__(self, func, returnType, name=None, vectorized=False): self._name = name or ( func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) - self._vectorized = vectorized + self.vectorized = vectorized @property def returnType(self): @@ -2090,7 +2090,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self._vectorized) + self._name, wrapped_func, jdt, self.vectorized) return judf def __call__(self, *cols): @@ -2118,9 +2118,10 @@ def wrapper(*args): wrapper.__name__ = self._name wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') else self.func.__class__.__module__) + wrapper.func = self.func wrapper.returnType = self.returnType - wrapper._vectorized = self._vectorized + wrapper.vectorized = self.vectorized return wrapper @@ -2151,7 +2152,7 @@ def _udf(f, returnType=StringType(), vectorized=vectorized): @since(1.3) def udf(f=None, returnType=StringType()): - """Creates a :class:`Column` expression representing a user defined function (UDF). + """Creates a user defined function (UDF). .. note:: The user-defined functions must be deterministic. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than @@ -2186,15 +2187,19 @@ def udf(f=None, returnType=StringType()): @since(2.3) def pandas_udf(f=None, returnType=StringType()): """ - Creates a :class:`Column` expression representing a vectorized user defined function (UDF). + Creates a vectorized user defined function (UDF). + + :param f: user-defined function. A python function if used as a standalone function + :param returnType: a :class:`pyspark.sql.types.DataType` object The user-defined function can define one of the following transformations: - 1. One or more `pandas.Series` -> A `pandas.Series` - This udf is used with `DataFrame.withColumn` and `DataFrame.select`. - The returnType should be a primitive data type, e.g., DoubleType() + 1. One or more `pandas.Series` -> A `pandas.Series` - Example: + This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and + :meth:`pyspark.sql.DataFrame.select`. + The returnType should be a primitive data type, e.g., `DoubleType()`. + The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. >>> from pyspark.sql.types import IntegerType, StringType >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) @@ -2217,33 +2222,31 @@ def pandas_udf(f=None, returnType=StringType()): 2. A `pandas.DataFrame` -> A `pandas.DataFrame` - This udf is used with `GroupedData.apply` - The returnType should be a StructType describing the schema of the returned + This udf is used with :meth:`pyspark.sql.GroupedData.apply`. 
+ The returnType should be a :class:`StructType` describing the schema of the returned `pandas.DataFrame`. - Example: - - >>> df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 4.0)], ("id", "v")) + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) >>> @pandas_udf(returnType=df.schema) - ... def normalize(df): - ... v = df.v - ... ret = df.assign(v=(v - v.mean()) / v.std()) + ... def normalize(pdf): + ... v = pdf.v + ... return pdf.assign(v=(v - v.mean()) / v.std()) >>> df.groupby('id').apply(normalize).show() # doctest: + SKIP - +---+-------------------+ - | id| v| - +---+-------------------+ - | 1|-0.7071067811865475| - | 1| 0.7071067811865475| - | 2|-0.7071067811865475| - | 2| 0.7071067811865475| - +---+-------------------+ - - - .. note:: The user-defined functions must be deterministic. - - :param f: python function if used as a standalone function - :param returnType: a :class:`pyspark.sql.types.DataType` object - + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.8320502943378437| + | 2|-0.2773500981126146| + | 2| 1.1094003924504583| + +---+-------------------+ + + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` + + .. note:: The user-defined function must be deterministic. """ import pandas as pd if isinstance(returnType, pd.Series): diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 88c9ddacd84a..1aebb50de7f9 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -170,7 +170,7 @@ def sum(self, *cols): @since(1.6) def pivot(self, pivot_col, values=None): """ - Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + Pivots a column of the current :class:`DataFrame` and perform the specified aggregation. There are two versions of pivot function: one that requires the caller to specify the list of distinct values to pivot on, and one that does not. The latter is more concise but less efficient, because Spark needs to first compute the list of distinct values internally. @@ -194,22 +194,50 @@ def pivot(self, pivot_col, values=None): jgd = self._jgd.pivot(pivot_col, values) return GroupedData(jgd, self.sql_ctx) - def apply(self, udf_obj): + def apply(self, udf): """ - Maps each group of the current [[DataFrame]] using a pandas udf and returns the result + Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result as a :class:`DataFrame`. + The user-function should take a `pandas.DataFrame` and return another `pandas.DataFrame`. + Each group is passed as a `pandas.DataFrame` to the user-function and the returned + `pandas.DataFrame` are combined as a :class:`DataFrame`. The returned `pandas.DataFrame` + can be arbitrary length and its schema should match the returnType of the pandas udf. + + :param udf: A wrapped function returned by `pandas_udf` + + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf(returnType=df.schema) + ... def normalize(pdf): + ... v = pdf.v + ... return pdf.assign(v=(v - v.mean()) / v.std()) + >>> df.groupby('id').apply(normalize).show() # doctest: + SKIP + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.8320502943378437| + | 2|-0.2773500981126146| + | 2| 1.1094003924504583| + +---+-------------------+ + + .. 
seealso:: :meth:`pyspark.sql.functions.pandas_udf` + """ from pyspark.sql.functions import pandas_udf - if not udf_obj._vectorized: - raise ValueError("Must pass a pandas_udf") - if not isinstance(udf_obj.returnType, StructType): - raise ValueError("Must pass a StructType as return type in pandas_udf") + # Columns are special because hasattr always return True + if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized: + raise ValueError("The argument to apply must be a pandas_udf") + if not isinstance(udf.returnType, StructType): + raise ValueError("The returnType of the pandas_udf must be a StructType") df = DataFrame(self._jgd.df(), self.sql_ctx) - func = udf_obj.func - returnType = udf_obj.returnType + func = udf.func + returnType = udf.returnType # The python executors expects the function to take a list of pd.Series as input # So we to create a wrapper function that turns that to a pd.DataFrame before passing diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2902f76a5545..5144c37db12c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3349,7 +3349,7 @@ def test_vectorized_udf_wrong_return_type(self): df = self.spark.range(10) f = pandas_udf(lambda x: x * 1.0, StringType()) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Invalid.*type.*string'): + with self.assertRaisesRegexp(Exception, 'Invalid.*type'): df.select(f(col('id'))).collect() def test_vectorized_udf_return_scalar(self): @@ -3377,6 +3377,13 @@ def test_vectorized_udf_empty_partition(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_varargs(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) + f = pandas_udf(lambda *v: v[0], LongType()) + res = df.select(f(col('id'))) + self.assertEquals(df.collect(), res.collect()) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedPySparkTestCase): @@ -3403,7 +3410,7 @@ def data(self): .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ .withColumn("v", explode(col('vs'))).drop('vs') - def test_groupby_apply_simple(self): + def test_simple(self): from pyspark.sql.functions import pandas_udf df = self.data @@ -3425,7 +3432,27 @@ def foo(df): expected = df.toPandas().groupby('id').apply(foo).reset_index(drop=True) self.assertFramesEqual(expected, result) - def test_groupby_apply_dtypes(self): + def test_decorator(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('v1', DoubleType()), + StructField('v2', LongType())])) + def foo(df): + ret = df + ret = ret.assign(v1=df.v * df.id * 1.0) + ret = ret.assign(v2=df.v + df.id) + return ret + + + result = df.groupby('id').apply(foo).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + + def test_dtypes(self): from pyspark.sql.functions import pandas_udf df = self.data @@ -3445,6 +3472,57 @@ def foo(df): expected = df.toPandas().groupby('id').apply(foo).reset_index(drop=True) self.assertFramesEqual(expected, result) + def test_coerce(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + def foo(df): + ret = df + ret = ret.assign(v=df.v + 1) + return ret + + @pandas_udf(StructType([StructField('id', 
LongType()), StructField('v', DoubleType())])) + def foo(df): + return df + + result = df.groupby('id').apply(foo).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) + expected = expected.assign(v=expected.v.astype('float64')) + self.assertFramesEqual(expected, result) + + def test_wrong_return_type(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + def foo(df): + ret = df + ret = ret.assign(v=df.v + 1) + return ret + + @pandas_udf(StructType([StructField('id', LongType()), StructField('v', StringType())])) + def foo(df): + return df + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Invalid.*type'): + df.groupby('id').apply(foo).sort('id').toPandas() + + def test_wrong_args(self): + from pyspark.sql.functions import udf, pandas_udf, sum + df = self.data + + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(lambda x: x) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(udf(lambda x: x, DoubleType())) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(sum(df.v)) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(df.v + 1) + with self.assertRaisesRegexp(ValueError, 'returnType'): + df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType())) + if __name__ == "__main__": from pyspark.sql.tests import * diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 3e11ef323f0f..7f49e5c26690 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -34,7 +34,7 @@ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type from pyspark import shuffle -from pyspark.sql.types import StructType, IntegerType, LongType, FloatType, DoubleType +from pyspark.sql.types import StructType pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -76,14 +76,18 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): if isinstance(return_type, StructType): + import pyarrow as pa + arrow_return_types = list(to_arrow_type(field.dataType) for field in return_type) def fn(*a): import pandas as pd out = f(*a) - assert isinstance(out, pd.DataFrame), 'Must return a pd.DataFrame' + assert isinstance(out, pd.DataFrame), \ + 'Return value from the user function is not a pandas.DataFrame.' assert len(out.columns) == len(arrow_return_types), \ - 'Columns of pd.DataFrame don\'t match return schema' + 'Number of columns of the returned pd.DataFrame doesn\'t match output schema. 
' \ + 'Expected: {} Actual: {}'.format(len(arrow_return_types), len(out.columns)) return list((out[out.columns[i]], arrow_return_types[i]) for i in range(len(arrow_return_types))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 73bbbdcea80a..3e7d49bad218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -26,25 +26,22 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.StructType -private object BatchIterator { - class InnerIterator[T](iter: Iterator[T], batchSize: Int) extends Iterator[T] { - var count = 0 - override def hasNext: Boolean = iter.hasNext && count < batchSize - - override def next(): T = { - count += 1 - iter.next() - } - } -} - private class BatchIterator[T](iter: Iterator[T], batchSize: Int) extends Iterator[Iterator[T]] { override def hasNext: Boolean = iter.hasNext override def next(): Iterator[T] = { - new BatchIterator.InnerIterator[T](iter, batchSize) + new Iterator[T] { + var count = 0 + + override def hasNext: Boolean = iter.hasNext && count < batchSize + + override def next(): T = { + count += 1 + iter.next() + } + } } } @@ -70,10 +67,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val batchIter = if (batchSize > 0) { new BatchIterator(iter, batchSize) - } else if (batchSize == 0) { - Iterator(iter) } else { - throw new IllegalArgumentException(s"MaxRecordsPerBatch must be >= 0, but is $batchSize") + Iterator(iter) } val columnarBatchIter = new ArrowPythonRunner( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 61ad1223b6a2..397d34a5c2f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, NamedExpression, SortOrder, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} @@ -63,11 +63,9 @@ case class FlatMapGroupsInPandasExec( PythonEvalType.SQL_PANDAS_UDF, argOffsets, child.schema) .compute(grouped.map(_._2), context.partitionId(), context) - val vectorRowIter = new Iterator[InternalRow] { + val rowIter = new Iterator[InternalRow] { private var currentIter = if (columnarBatchIter.hasNext) { val batch = columnarBatchIter.next() - // assert(schemaOut.equals(batch.schema), - // s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") batch.rowIterator.asScala } else { Iterator.empty @@ -85,7 +83,7 @@ case class FlatMapGroupsInPandasExec( override def next(): InternalRow = currentIter.next() } - 
vectorRowIter.map(UnsafeProjection.create(output, output)) + rowIter.map(UnsafeProjection.create(output, output)) } } } From d37a9e6a19e3f2b5bef796ba20cdb5bc46817f62 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 3 Oct 2017 16:25:21 -0400 Subject: [PATCH 09/34] Remove dynamic returnType support --- python/pyspark/sql/functions.py | 6 +----- python/pyspark/sql/tests.py | 20 -------------------- python/pyspark/sql/types.py | 28 ---------------------------- 3 files changed, 1 insertion(+), 53 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ea831da5756b..0f2b852b8765 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -28,7 +28,7 @@ from pyspark import since, SparkContext from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.sql.types import StringType, DataType, _parse_datatype_string, from_pandas_dtypes +from pyspark.sql.types import StringType, DataType, _parse_datatype_string from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame @@ -2248,10 +2248,6 @@ def pandas_udf(f=None, returnType=StringType()): .. note:: The user-defined function must be deterministic. """ - import pandas as pd - if isinstance(returnType, pd.Series): - returnType = from_pandas_dtypes(returnType) - return _create_udf(f, returnType=returnType, vectorized=True) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5144c37db12c..634ddfee6cd9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3452,26 +3452,6 @@ def foo(df): expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) self.assertFramesEqual(expected, result) - def test_dtypes(self): - from pyspark.sql.functions import pandas_udf - df = self.data - - def foo(df): - ret = df - ret = ret.assign(v3=df.v * 5.0 + 1) - return ret - - sample_df = df.filter(df.id == 1).toPandas() - - foo_udf = pandas_udf( - foo, - foo(sample_df).dtypes - ) - - result = df.groupby('id').apply(foo_udf).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(foo).reset_index(drop=True) - self.assertFramesEqual(expected, result) - def test_coerce(self): from pyspark.sql.functions import pandas_udf df = self.data diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index cb80139f1fd0..f65273d5f0b6 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1624,34 +1624,6 @@ def to_arrow_type(dt): return arrow_type -def from_pandas_type(dt): - """ Convert pandas data type to Spark data type - """ - import pandas as pd - import numpy as np - if dt == np.int32: - return IntegerType() - elif dt == np.int64: - return LongType() - elif dt == np.float32: - return FloatType() - elif dt == np.float64: - return DoubleType() - elif dt == np.object: - return StringType() - elif dt == np.dtype('datetime64[ns]') or type(dt) == pd.api.types.DatetimeTZDtype: - return TimestampType() - else: - raise ValueError("Unsupported numpy type in conversion to Spark: {}".format(dt)) - - -def from_pandas_dtypes(dtypes): - """ Convert pandas DataFrame dtypes to Spark schema - """ - return StructType([StructField(dtypes.axes[0][i], from_pandas_type(dtypes[i])) - for i in range(len(dtypes))]) - - def _test(): import doctest from pyspark.context import SparkContext From 1ea2b71801784842d1797863456a07c7fcfc2531 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 3 Oct 
2017 16:46:24 -0400 Subject: [PATCH 10/34] Fix pep8 style check --- python/pyspark/sql/tests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 634ddfee6cd9..2d0ec3b68b62 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3447,7 +3447,6 @@ def foo(df): ret = ret.assign(v2=df.v + df.id) return ret - result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) self.assertFramesEqual(expected, result) From 4943ceb8b57041a7bf911e8f8637380c5fc263ff Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 3 Oct 2017 22:35:35 -0400 Subject: [PATCH 11/34] Fix ExtractPythonUDFs --- .../spark/sql/execution/python/ExtractPythonUDFs.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index b28b1efad044..d13489d928b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -111,8 +111,10 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } def apply(plan: SparkPlan): SparkPlan = plan transformUp { - case plan: ProjectExec => extract(plan) - case plan: FilterExec => extract(plan) + // FlatMapGroupsInPandas and be evaluated in python worker + // Therefore we don't need to extract the UDFs + case plan: FlatMapGroupsInPandasExec => plan + case plan: SparkPlan => extract(plan) } /** From 21fed0dfeefe5775d193d7bb4176c38f0c2b91eb Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 3 Oct 2017 23:12:02 -0400 Subject: [PATCH 12/34] Address new PR comments --- python/pyspark/sql/dataframe.py | 6 +-- python/pyspark/sql/group.py | 9 +++-- python/pyspark/sql/tests.py | 39 +++++-------------- python/pyspark/worker.py | 11 ++---- .../spark/sql/RelationalGroupedDataset.scala | 2 +- .../python/ArrowEvalPythonExec.scala | 7 +--- 6 files changed, 24 insertions(+), 50 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b7ce9a83a616..733ff84c6c7b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1227,7 +1227,7 @@ def groupBy(self, *cols): """ jgd = self._jdf.groupBy(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.4) def rollup(self, *cols): @@ -1248,7 +1248,7 @@ def rollup(self, *cols): """ jgd = self._jdf.rollup(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.4) def cube(self, *cols): @@ -1271,7 +1271,7 @@ def cube(self, *cols): """ jgd = self._jdf.cube(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.3) def agg(self, *exprs): diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 1aebb50de7f9..9262469fb0a4 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -54,9 +54,10 @@ class GroupedData(object): .. 
versionadded:: 1.3 """ - def __init__(self, jgd, sql_ctx): + def __init__(self, jgd, df): + self._df = df self._jgd = jgd - self.sql_ctx = sql_ctx + self.sql_ctx = df.sql_ctx @ignore_unicode_prefix @since(1.3) @@ -192,7 +193,7 @@ def pivot(self, pivot_col, values=None): jgd = self._jgd.pivot(pivot_col) else: jgd = self._jgd.pivot(pivot_col, values) - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) def apply(self, udf): """ @@ -235,7 +236,7 @@ def apply(self, udf): if not isinstance(udf.returnType, StructType): raise ValueError("The returnType of the pandas_udf must be a StructType") - df = DataFrame(self._jgd.df(), self.sql_ctx) + df = self._df func = udf.func returnType = udf.returnType diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2d0ec3b68b62..b5d4887605dd 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3405,7 +3405,7 @@ def assertFramesEqual(self, expected, result): @property def data(self): - from pyspark.sql.functions import pandas_udf, array, explode, col, lit + from pyspark.sql.functions import array, explode, col, lit return self.spark.range(10).toDF('id') \ .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ .withColumn("v", explode(col('vs'))).drop('vs') @@ -3414,14 +3414,8 @@ def test_simple(self): from pyspark.sql.functions import pandas_udf df = self.data - def foo(df): - ret = df - ret = ret.assign(v1=df.v * df.id * 1.0) - ret = ret.assign(v2=df.v + df.id) - return ret - foo_udf = pandas_udf( - foo, + lambda df: df.assign(v1=df.v * df.id * 1.0, v2=df.v + df.id), StructType( [StructField('id', LongType()), StructField('v', IntegerType()), @@ -3429,7 +3423,7 @@ def foo(df): StructField('v2', LongType())])) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(foo).reset_index(drop=True) + expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertFramesEqual(expected, result) def test_decorator(self): @@ -3442,10 +3436,7 @@ def test_decorator(self): StructField('v1', DoubleType()), StructField('v2', LongType())])) def foo(df): - ret = df - ret = ret.assign(v1=df.v * df.id * 1.0) - ret = ret.assign(v2=df.v + df.id) - return ret + return df.assign(v1=df.v * df.id * 1.0, v2=df.v + df.id) result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) @@ -3455,14 +3446,9 @@ def test_coerce(self): from pyspark.sql.functions import pandas_udf df = self.data - def foo(df): - ret = df - ret = ret.assign(v=df.v + 1) - return ret - - @pandas_udf(StructType([StructField('id', LongType()), StructField('v', DoubleType())])) - def foo(df): - return df + foo = pandas_udf( + lambda df: df, + StructType([StructField('id', LongType()), StructField('v', DoubleType())])) result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) @@ -3473,14 +3459,9 @@ def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf df = self.data - def foo(df): - ret = df - ret = ret.assign(v=df.v + 1) - return ret - - @pandas_udf(StructType([StructField('id', LongType()), StructField('v', StringType())])) - def foo(df): - return df + foo = pandas_udf( + lambda df: df, + StructType([StructField('id', LongType()), StructField('v', StringType())])) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Invalid.*type'): diff --git a/python/pyspark/worker.py 
b/python/pyspark/worker.py index 7f49e5c26690..87d2322b3666 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -32,9 +32,8 @@ from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer -from pyspark.sql.types import to_arrow_type +from pyspark.sql.types import to_arrow_type, StructType from pyspark import shuffle -from pyspark.sql.types import StructType pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -76,9 +75,7 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): if isinstance(return_type, StructType): - import pyarrow as pa - - arrow_return_types = list(to_arrow_type(field.dataType) for field in return_type) + arrow_return_types = [to_arrow_type(field.dataType) for field in return_type] def fn(*a): import pandas as pd @@ -89,8 +86,8 @@ def fn(*a): 'Number of columns of the returned pd.DataFrame doesn\'t match output schema. ' \ 'Expected: {} Actual: {}'.format(len(arrow_return_types), len(out.columns)) - return list((out[out.columns[i]], arrow_return_types[i]) - for i in range(len(arrow_return_types))) + return [(out[out.columns[i]], arrow_return_types[i]) + for i in range(len(arrow_return_types))] return fn else: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index caaa20415741..4ecaef9e2122 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.types.{NumericType, StructField, StructType} */ @InterfaceStability.Stable class RelationalGroupedDataset protected[sql]( - val df: DataFrame, + df: DataFrame, groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 3e7d49bad218..985bd94f8940 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -64,12 +64,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi .map { case (attr, i) => attr.withName(s"_$i") }) val batchSize = conf.arrowMaxRecordsPerBatch - - val batchIter = if (batchSize > 0) { - new BatchIterator(iter, batchSize) - } else { - Iterator(iter) - } + val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, From 40d7e8acf4cb1fcc988eb16fddc915f78a402ef7 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 4 Oct 2017 16:43:05 +0900 Subject: [PATCH 13/34] Add a test for complex groupby. 
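As a rough illustration of what the new test exercises, grouping by a derived expression (rather than a named column) looks like the sketch below; `df` is assumed to carry a long `id` column and a double `v` column, and `result_schema` is just one example of a matching returnType:

    from pyspark.sql.functions import pandas_udf, col
    from pyspark.sql.types import StructType, StructField, LongType, DoubleType

    result_schema = StructType([StructField('id', LongType()),
                                StructField('v', DoubleType()),
                                StructField('norm', DoubleType())])

    @pandas_udf(result_schema)
    def normalize(pdf):
        v = pdf.v
        return pdf.assign(norm=(v - v.mean()) / v.std())

    # the grouping key is an expression, not a plain column; the follow-up patch
    # below aliases such expressions so they can be used as grouping attributes
    df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').show()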
--- python/pyspark/sql/tests.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b5d4887605dd..097424e652f6 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3455,6 +3455,25 @@ def test_coerce(self): expected = expected.assign(v=expected.v.astype('float64')) self.assertFramesEqual(expected, result) + def test_complex_groupby(self): + from pyspark.sql.functions import pandas_udf, col + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('norm', DoubleType())])) + def normalize(pdf): + v = pdf.v + return pdf.assign(norm=(v - v.mean()) / v.std()) + + result = df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').toPandas() + pdf = df.toPandas() + expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) + expected = expected.sort_values(['id', 'v']).reset_index(drop=True) + expected = expected.assign(norm=expected.norm.astype('float64')) + self.assertFramesEqual(expected, result) + def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf df = self.data From 427a84717233a67a2de02a6e7cb327d524f64cf9 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 4 Oct 2017 16:46:19 +0900 Subject: [PATCH 14/34] Fix complex groupby. --- .../spark/sql/RelationalGroupedDataset.scala | 20 ++++++++++--------- .../python/FlatMapGroupsInPandasExec.scala | 14 +++++++++---- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 4ecaef9e2122..1ae58f95b476 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -438,23 +438,25 @@ class RelationalGroupedDataset protected[sql]( private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { require(expr.vectorized, "Must pass a vectorized python udf") + require(expr.dataType.isInstanceOf[StructType], + "The returnType of the vectorized python udf must be a StructType") - val output = expr.dataType match { - case s: StructType => s.map { - case StructField(name, dataType, nullable, metadata) => - AttributeReference(name, dataType, nullable, metadata)() - } + val groupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() } + val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) - val groupingAttributes: Seq[Attribute] = groupingExprs.map { - case ne: NamedExpression => ne.toAttribute - } + val child = df.logicalPlan + val project = Project(groupingNamedExpressions ++ child.output, child) + + val output = expr.dataType.asInstanceOf[StructType].toAttributes val plan = FlatMapGroupsInPandas( groupingAttributes, expr, output, - df.logicalPlan + project ) Dataset.ofRows( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 397d34a5c2f5..70ebfdd67c10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.StructType case class FlatMapGroupsInPandasExec( groupingAttributes: Seq[Attribute], @@ -52,16 +53,21 @@ case class FlatMapGroupsInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - val argOffsets = Array((0 until child.schema.length).toArray) + val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) + val schema = StructType(child.schema.drop(groupingAttributes.length)) inputRDD.mapPartitionsInternal { iter => - val grouped = GroupedIterator(iter, groupingAttributes, child.output) + val dropGrouping = + UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + val grouped = GroupedIterator(iter, groupingAttributes, child.output).map { + case (_, iter) => iter.map(dropGrouping) + } val context = TaskContext.get() val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, child.schema) - .compute(grouped.map(_._2), context.partitionId(), context) + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + .compute(grouped, context.partitionId(), context) val rowIter = new Iterator[InternalRow] { private var currentIter = if (columnarBatchIter.hasNext) { From 0929d4d38293cd698521468544833af299b2510e Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 4 Oct 2017 18:22:21 +0900 Subject: [PATCH 15/34] Add support for empty groupby. 
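In user terms this commit allows an empty groupby; a minimal sketch, assuming `normalize` is a pandas_udf with a StructType returnType as in the earlier examples:

    # with no grouping columns the udf receives the whole DataFrame as a single
    # pandas.DataFrame, so the entire dataset forms one group on one worker
    result = df.groupby().apply(normalize)
    result.show()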
--- python/pyspark/sql/tests.py | 19 +++++++++++++++++++ .../python/FlatMapGroupsInPandasExec.scala | 11 ++++++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 097424e652f6..4a8bcbae2687 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3474,6 +3474,25 @@ def normalize(pdf): expected = expected.assign(norm=expected.norm.astype('float64')) self.assertFramesEqual(expected, result) + def test_empty_groupby(self): + from pyspark.sql.functions import pandas_udf, col + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('norm', DoubleType())])) + def normalize(pdf): + v = pdf.v + return pdf.assign(norm=(v - v.mean()) / v.std()) + + result = df.groupby().apply(normalize).sort('id', 'v').toPandas() + pdf = df.toPandas() + expected = normalize.func(pdf) + expected = expected.sort_values(['id', 'v']).reset_index(drop=True) + expected = expected.assign(norm=expected.norm.astype('float64')) + self.assertFramesEqual(expected, result) + def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf df = self.data diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 70ebfdd67c10..26053e440a11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -24,7 +24,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.StructType @@ -41,8 +41,13 @@ case class FlatMapGroupsInPandasExec( override def producedAttributes: AttributeSet = AttributeSet(output) - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingAttributes.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingAttributes) :: Nil + } + } override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(groupingAttributes.map(SortOrder(_, Ascending))) From d9a3e8d3470a806de2db20301e49854ace0de912 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 4 Oct 2017 18:54:49 +0900 Subject: [PATCH 16/34] Skip grouping if groupingAttributes is empty. 
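Conceptually, the executor-side change behaves like the following Python analogue (an illustration only; the real implementation is the Scala code in FlatMapGroupsInPandasExec, which relies on the input already being sorted by the grouping keys):

    from itertools import groupby

    def batches_per_group(rows, grouping_keys):
        # no grouping keys: hand the whole partition to the udf as one batch
        if not grouping_keys:
            yield rows
        else:
            # consecutive rows with equal keys form one group (input is sorted)
            keyfunc = lambda row: tuple(row[k] for k in grouping_keys)
            for _, group in groupby(rows, key=keyfunc):
                yield group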
--- .../python/FlatMapGroupsInPandasExec.scala | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 26053e440a11..c4cd8d3cb5a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -62,11 +62,17 @@ case class FlatMapGroupsInPandasExec( val schema = StructType(child.schema.drop(groupingAttributes.length)) inputRDD.mapPartitionsInternal { iter => - val dropGrouping = - UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) - val grouped = GroupedIterator(iter, groupingAttributes, child.output).map { - case (_, iter) => iter.map(dropGrouping) + val grouped = if (groupingAttributes.isEmpty) { + Iterator(iter) + } else { + val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) + val dropGrouping = + UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + groupedIter.map { + case (_, iter) => iter.map(dropGrouping) + } } + val context = TaskContext.get() val columnarBatchIter = new ArrowPythonRunner( From ce0d54c1b6e3c7377c26bd8fe0b175879f9948db Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 4 Oct 2017 17:48:39 -0400 Subject: [PATCH 17/34] Address some new comments --- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/group.py | 15 +++++----- python/pyspark/sql/tests.py | 5 ++-- python/pyspark/worker.py | 17 ++++++----- .../spark/sql/RelationalGroupedDataset.scala | 15 ++-------- .../python/ArrowEvalPythonExec.scala | 15 ++++++++-- .../execution/python/ExtractPythonUDFs.scala | 2 +- .../python/FlatMapGroupsInPandasExec.scala | 29 +++++-------------- 8 files changed, 43 insertions(+), 57 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0f2b852b8765..744b7b0f503b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2233,7 +2233,7 @@ def pandas_udf(f=None, returnType=StringType()): ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby('id').apply(normalize).show() # doctest: + SKIP + >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP +---+-------------------+ | id| v| +---+-------------------+ diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 9262469fb0a4..fddbbc6a1cc8 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -55,8 +55,8 @@ class GroupedData(object): """ def __init__(self, jgd, df): - self._df = df self._jgd = jgd + self._df = df self.sql_ctx = df.sql_ctx @ignore_unicode_prefix @@ -193,17 +193,18 @@ def pivot(self, pivot_col, values=None): jgd = self._jgd.pivot(pivot_col) else: jgd = self._jgd.pivot(pivot_col, values) - return GroupedData(jgd, self) + return GroupedData(jgd, self._df) def apply(self, udf): """ Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result as a :class:`DataFrame`. - The user-function should take a `pandas.DataFrame` and return another `pandas.DataFrame`. - Each group is passed as a `pandas.DataFrame` to the user-function and the returned - `pandas.DataFrame` are combined as a :class:`DataFrame`. 
The returned `pandas.DataFrame` - can be arbitrary length and its schema should match the returnType of the pandas udf. + The user-defined function should take a `pandas.DataFrame` and return another + `pandas.DataFrame`. Each group is passed as a `pandas.DataFrame` to the user-function and + the returned`pandas.DataFrame` are combined as a :class:`DataFrame`. The returned + `pandas.DataFrame` can be arbitrary length and its schema should match the returnType of + the pandas udf. :param udf: A wrapped function returned by `pandas_udf` @@ -214,7 +215,7 @@ def apply(self, udf): ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby('id').apply(normalize).show() # doctest: + SKIP + >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP +---+-------------------+ | id| v| +---+-------------------+ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4a8bcbae2687..49494e6a768c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3106,9 +3106,8 @@ def assertFramesEqual(self, df_with_arrow, df_without): self.assertTrue(df_without.equals(df_with_arrow), msg=msg) def test_unsupported_datatype(self): - schema = StructType([StructField("dt", TimestampType(), True)]) - df = self.spark.createDataFrame([(datetime.datetime(1970, 1, 1),)], schema=schema) - + schema = StructType([StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) with QuietTest(self.sc): self.assertRaises(Exception, lambda: df.toPandas()) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 87d2322b3666..a9e21e551f9d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -79,14 +79,15 @@ def wrap_pandas_udf(f, return_type): def fn(*a): import pandas as pd - out = f(*a) - assert isinstance(out, pd.DataFrame), \ - 'Return value from the user function is not a pandas.DataFrame.' - assert len(out.columns) == len(arrow_return_types), \ - 'Number of columns of the returned pd.DataFrame doesn\'t match output schema. ' \ - 'Expected: {} Actual: {}'.format(len(arrow_return_types), len(out.columns)) - - return [(out[out.columns[i]], arrow_return_types[i]) + result = f(*a) + assert isinstance(result, pd.DataFrame), \ + 'Return value of the user-defined function is not a pandas.DataFrame.' + assert len(result.columns) == len(arrow_return_types), \ + 'Number of columns of the returned pandas.DataFrame doesn\'t match ' \ + 'specified schema. 
' \ + 'Expected: {} Actual: {}'.format(len(arrow_return_types), len(result.columns)) + + return [(result[result.columns[i]], arrow_return_types[i]) for i in range(len(arrow_return_types))] return fn diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 1ae58f95b476..2cd49207bfe8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -446,23 +446,12 @@ class RelationalGroupedDataset protected[sql]( case other => Alias(other, other.toString)() } val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) - val child = df.logicalPlan val project = Project(groupingNamedExpressions ++ child.output, child) - val output = expr.dataType.asInstanceOf[StructType].toAttributes + val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project) - val plan = FlatMapGroupsInPandas( - groupingAttributes, - expr, - output, - project - ) - - Dataset.ofRows( - df.sparkSession, - plan - ) + Dataset.ofRows(df.sparkSession, plan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 985bd94f8940..81896187ecc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -26,6 +26,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.StructType +/** + * Grouped a iterator into batches. + * This is similar to iter.grouped but returns Iterator[T] instead of Seq[T]. + * This is necessary because sometimes we cannot hold reference of input rows + * because the some input rows are mutable and can be reused. + */ private class BatchIterator[T](iter: Iterator[T], batchSize: Int) extends Iterator[Iterator[T]] { @@ -38,8 +44,12 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int) override def hasNext: Boolean = iter.hasNext && count < batchSize override def next(): T = { - count += 1 - iter.next() + if (!hasNext) { + Iterator.empty.next() + } else { + count += 1 + iter.next() + } } } } @@ -64,6 +74,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi .map { case (attr, i) => attr.withName(s"_$i") }) val batchSize = conf.arrowMaxRecordsPerBatch + // DO NOT use iter.grouped(). See BatchIterator. 
val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) val columnarBatchIter = new ArrowPythonRunner( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index d13489d928b0..8dee59c3d61d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -111,7 +111,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } def apply(plan: SparkPlan): SparkPlan = plan transformUp { - // FlatMapGroupsInPandas and be evaluated in python worker + // FlatMapGroupsInPandas can be evaluated directly in python worker // Therefore we don't need to extract the UDFs case plan: FlatMapGroupsInPandasExec => plan case plan: SparkPlan => extract(plan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index c4cd8d3cb5a5..c0bf95cb9dba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -28,6 +28,11 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.StructType +/** + * FlatMap groups using a pandas udf. + * + * This is used by pyspark.sql.DataFrame.groupby().apply() + */ case class FlatMapGroupsInPandasExec( groupingAttributes: Seq[Attribute], func: Expression, @@ -69,7 +74,7 @@ case class FlatMapGroupsInPandasExec( val dropGrouping = UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) groupedIter.map { - case (_, iter) => iter.map(dropGrouping) + case (_, groupedRowIter) => groupedRowIter.map(dropGrouping) } } @@ -80,27 +85,7 @@ case class FlatMapGroupsInPandasExec( PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) .compute(grouped, context.partitionId(), context) - val rowIter = new Iterator[InternalRow] { - private var currentIter = if (columnarBatchIter.hasNext) { - val batch = columnarBatchIter.next() - batch.rowIterator.asScala - } else { - Iterator.empty - } - - override def hasNext: Boolean = currentIter.hasNext || { - if (columnarBatchIter.hasNext) { - currentIter = columnarBatchIter.next().rowIterator.asScala - hasNext - } else { - false - } - } - - override def next(): InternalRow = currentIter.next() - } - - rowIter.map(UnsafeProjection.create(output, output)) + columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) } } } From 657942b1b4080c30fa5c60bcd700c862fb571465 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 4 Oct 2017 22:44:30 -0400 Subject: [PATCH 18/34] Fix minor typo --- .../org/apache/spark/sql/catalyst/plans/logical/object.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index e0d91309d342..3a394d2340f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -526,7 +526,7 @@ case class FlatMapGroupsInPandas( output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { /** - * This is needed because output attributes is considered `reference` when + * This is needed because output attributes are considered `references` when * passed through the constructor. * * Without this, catalyst will complain that output attributes are missing From fa88c881a2fa3a7bb49af882ee9c482314184ff1 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 4 Oct 2017 23:16:11 -0400 Subject: [PATCH 19/34] Add doc for FlatMapGroupsInPandasExec --- .../python/FlatMapGroupsInPandasExec.scala | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index c0bf95cb9dba..6319a8102da6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -29,9 +29,22 @@ import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode import org.apache.spark.sql.types.StructType /** - * FlatMap groups using a pandas udf. + * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame. + * This is used by pyspark.sql.DataFrame.groupby().apply(). * - * This is used by pyspark.sql.DataFrame.groupby().apply() + * Rows in each group are passed to the python worker as a Arrow record batch. + * The python worker turns the record batch to a pandas.DataFrame, invoke the + * use-defined function, and passes the resulting pandas.DataFrame + * as a Arrow record batch. Finally, each record batch is turned to + * Iterator[InternalRow] using ColumnarBatch. + * + * Note on memory usage: + * Both the python worker and the java executor need to have enough memory to + * hold the largest group. The memory on the java side is used to construct the + * record batch (off heap memory). The memory on the python side is used for + * holding the pandas.DataFrame. It's possible to further split one group into + * multiple record batches to reduce the memory footprint on the java side, this + * is left as future work. */ case class FlatMapGroupsInPandasExec( groupingAttributes: Seq[Attribute], From e4efb3281008a2b450f9013aeb8f1ac9cf4ffa9e Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 5 Oct 2017 00:23:19 -0400 Subject: [PATCH 20/34] Fix doctest in group.py --- python/pyspark/sql/group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index fddbbc6a1cc8..673d6fc4726e 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -215,7 +215,7 @@ def apply(self, udf): ... def normalize(pdf): ... v = pdf.v ... 
return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP +---+-------------------+ | id| v| +---+-------------------+ From f572385e28a1ccd2f8663adf64910d5f0a0ce67c Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 5 Oct 2017 10:12:41 -0400 Subject: [PATCH 21/34] Fix doctest for group.py --- python/pyspark/sql/group.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 673d6fc4726e..693e28abba4d 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -208,6 +208,7 @@ def apply(self, udf): :param udf: A wrapped function returned by `pandas_udf` + >>> from pyspark.sql.functions import pandas_udf >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) @@ -267,6 +268,7 @@ def _test(): .getOrCreate() sc = spark.sparkContext globs['sc'] = sc + globs['spark'] = spark globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) From 5162ed1774bd60477f43bfb020047c3bebe5cc48 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 5 Oct 2017 15:32:51 -0400 Subject: [PATCH 22/34] Add comments and standardize exception handling in wrap_pandas_udf --- python/pyspark/sql/group.py | 3 ++- python/pyspark/worker.py | 29 ++++++++++++++++++++--------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 693e28abba4d..8e4af4d9298b 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -195,6 +195,7 @@ def pivot(self, pivot_col, values=None): jgd = self._jgd.pivot(pivot_col, values) return GroupedData(jgd, self._df) + @since(2.3) def apply(self, udf): """ Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result @@ -206,7 +207,7 @@ def apply(self, udf): `pandas.DataFrame` can be arbitrary length and its schema should match the returnType of the pandas udf. - :param udf: A wrapped function returned by `pandas_udf` + :param udf: A wrapped udf function returned by :meth:`pyspark.sql.functions.pandas_udf` >>> from pyspark.sql.functions import pandas_udf >>> df = spark.createDataFrame( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a9e21e551f9d..a09b1fae1b6f 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -74,22 +74,32 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): + # If the return_type is a StructType, it indicates this is a groupby apply udf, + # otherwise, it's a vectorized column udf. + # We can distinguish these two by return type because in groupby apply, we always specify + # returnType as a StructType, and in vectorized column udf, StructType is not supported. + # + # TODO: This logic is a bit hacky and might not work for future pandas udfs. Need refactoring. if isinstance(return_type, StructType): arrow_return_types = [to_arrow_type(field.dataType) for field in return_type] - def fn(*a): + # Verify the return type and number of columns in result + def verify_result_type(*a): import pandas as pd result = f(*a) - assert isinstance(result, pd.DataFrame), \ - 'Return value of the user-defined function is not a pandas.DataFrame.' - assert len(result.columns) == len(arrow_return_types), \ - 'Number of columns of the returned pandas.DataFrame doesn\'t match ' \ - 'specified schema. 
' \ - 'Expected: {} Actual: {}'.format(len(arrow_return_types), len(result.columns)) + if not isinstance(result, pd.DataFrame): + raise TypeError("Return type of the user-defined function should be a " + "Pandas.DataFrame") + if not len(result.columns) == len(arrow_return_types): + raise RuntimeError( + "Number of columns of the returned Pandas.DataFrame " \ + "doesn't match specified schema. " \ + "Expected: {} Actual: {}".format(len(arrow_return_types), len(result.columns))) return [(result[result.columns[i]], arrow_return_types[i]) for i in range(len(arrow_return_types))] - return fn + + return verify_result_type else: arrow_return_type = to_arrow_type(return_type) @@ -97,7 +107,8 @@ def fn(*a): def verify_result_length(*a): result = f(*a) if not hasattr(result, "__len__"): - raise TypeError("Return type of pandas_udf should be a Pandas.Series") + raise TypeError("Return type of the user-defined functon should be a " + "Pandas.Series") if len(result) != len(a[0]): raise RuntimeError("Result vector from pandas_udf was not the required length: " "expected %d, got %d" % (len(a[0]), len(result))) From d628f4ede208b2206d6469676fb4e0779dc8f320 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 5 Oct 2017 19:24:09 -0400 Subject: [PATCH 23/34] Minor: Fix pep8 --- python/pyspark/worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a09b1fae1b6f..efa9aa1897e7 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -92,8 +92,8 @@ def verify_result_type(*a): "Pandas.DataFrame") if not len(result.columns) == len(arrow_return_types): raise RuntimeError( - "Number of columns of the returned Pandas.DataFrame " \ - "doesn't match specified schema. " \ + "Number of columns of the returned Pandas.DataFrame " + "doesn't match specified schema. 
" "Expected: {} Actual: {}".format(len(arrow_return_types), len(result.columns))) return [(result[result.columns[i]], arrow_return_types[i]) From 20fb1fe9cbf033d73ecf2851f9cb1dc94f41fb3e Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 5 Oct 2017 23:08:52 -0400 Subject: [PATCH 24/34] Fix test --- python/pyspark/sql/tests.py | 2 +- python/pyspark/worker.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 49494e6a768c..70a75371c0da 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3356,7 +3356,7 @@ def test_vectorized_udf_return_scalar(self): df = self.spark.range(10) f = pandas_udf(lambda x: 1.0, DoubleType()) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Return.*type.*pandas_udf.*Series'): + with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'): df.select(f(col('id'))).collect() def test_vectorized_udf_decorator(self): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index efa9aa1897e7..2601e987593c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -88,8 +88,8 @@ def verify_result_type(*a): import pandas as pd result = f(*a) if not isinstance(result, pd.DataFrame): - raise TypeError("Return type of the user-defined function should be a " - "Pandas.DataFrame") + raise TypeError("Return type of the user-defined function should be " + "Pandas.DataFrame, but is {}".format(type(result))) if not len(result.columns) == len(arrow_return_types): raise RuntimeError( "Number of columns of the returned Pandas.DataFrame " @@ -107,8 +107,8 @@ def verify_result_type(*a): def verify_result_length(*a): result = f(*a) if not hasattr(result, "__len__"): - raise TypeError("Return type of the user-defined functon should be a " - "Pandas.Series") + raise TypeError("Return type of the user-defined functon should be " + "Pandas.Series, but is {}".format(type(result))) if len(result) != len(a[0]): raise RuntimeError("Result vector from pandas_udf was not the required length: " "expected %d, got %d" % (len(a[0]), len(result))) From 284ba00be0cbc357ac42900f8c4af57901d147c5 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Fri, 6 Oct 2017 10:57:48 -0400 Subject: [PATCH 25/34] Minor edit to groupby apply doc --- python/pyspark/sql/group.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 8e4af4d9298b..ef5039830b60 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -202,10 +202,10 @@ def apply(self, udf): as a :class:`DataFrame`. The user-defined function should take a `pandas.DataFrame` and return another - `pandas.DataFrame`. Each group is passed as a `pandas.DataFrame` to the user-function and - the returned`pandas.DataFrame` are combined as a :class:`DataFrame`. The returned - `pandas.DataFrame` can be arbitrary length and its schema should match the returnType of - the pandas udf. + `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` + to the user-function and the returned `pandas.DataFrame` are combined as a + :class:`DataFrame`. The returned `pandas.DataFrame` can be arbitrary length and its schema + must match the returnType of the pandas udf. 
:param udf: A wrapped udf function returned by :meth:`pyspark.sql.functions.pandas_udf` From 876b118ebb9689d9b8895945d5b409fab1d5a8e8 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 9 Oct 2017 11:42:38 -0400 Subject: [PATCH 26/34] Improve documentation. FlatMapGroupsInPandas logical node to pythonLogicalOperators.scala. --- python/pyspark/sql/group.py | 11 +++-- .../sql/catalyst/plans/logical/object.scala | 14 ------ .../logical/pythonLogicalOperators.scala | 43 +++++++++++++++++++ .../spark/sql/RelationalGroupedDataset.scala | 11 +++++ .../python/FlatMapGroupsInPandasExec.scala | 3 +- 5 files changed, 62 insertions(+), 20 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ef5039830b60..fc7e3937ca06 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -199,13 +199,16 @@ def pivot(self, pivot_col, values=None): def apply(self, udf): """ Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result - as a :class:`DataFrame`. + as a `DataFrame`. The user-defined function should take a `pandas.DataFrame` and return another `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` - to the user-function and the returned `pandas.DataFrame` are combined as a - :class:`DataFrame`. The returned `pandas.DataFrame` can be arbitrary length and its schema - must match the returnType of the pandas udf. + to the user-function and the returned `pandas.DataFrame` are combined as a `DataFrame`. + The returned `pandas.DataFrame` can be arbitrary length and its schema must match the + returnType of the pandas udf. + + This function does not support partial aggregation, and requires shuffling all the data in + the `DataFrame`. :param udf: A wrapped udf function returned by :meth:`pyspark.sql.functions.pandas_udf` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 3a394d2340f2..2535b80264eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -520,17 +520,3 @@ case class CoGroup( left: LogicalPlan, right: LogicalPlan) extends BinaryNode with ObjectProducer -case class FlatMapGroupsInPandas( - groupingAttributes: Seq[Attribute], - functionExpr: Expression, - output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { - /** - * This is needed because output attributes are considered `references` when - * passed through the constructor. - * - * Without this, catalyst will complain that output attributes are missing - * from the input. - */ - override val producedAttributes = AttributeSet(output) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala new file mode 100644 index 000000000000..cf30a5dbacbc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} + +/** + * Logical nodes specific to PySpark. + */ + +/** + * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame. + * This is used by DataFrame.groupby().apply(). + */ +case class FlatMapGroupsInPandas( + groupingAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + /** + * This is needed because output attributes are considered `references` when + * passed through the constructor. + * + * Without this, catalyst will complain that output attributes are missing + * from the input. + */ + override val producedAttributes = AttributeSet(output) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 2cd49207bfe8..e996f5a4a53e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -436,6 +436,17 @@ class RelationalGroupedDataset protected[sql]( df.logicalPlan)) } + /** + * Applies a vectorized python use-defined function to each group of data. + * The user-defined function defines a transformation: `Pandas.DataFrame` -> `Pandas.DataFrame`. + * For each group, all elements in the group are passed as a `Pandas.DataFrame` and the results + * for all groups are combined into a new `DataFrame`. + * + * This function does not support partial aggregation, and requires shuffling all the data in + * the `DataFrame`. + * + * This function uses `Arrow` as serialization format between JVM and python workers. + */ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { require(expr.vectorized, "Must pass a vectorized python udf") require(expr.dataType.isInstanceOf[StructType], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 6319a8102da6..db50f4cec3b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -29,8 +29,7 @@ import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode import org.apache.spark.sql.types.StructType /** - * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame. - * This is used by pyspark.sql.DataFrame.groupby().apply(). 
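To make the combine step described above concrete: each group's result is simply concatenated into the output DataFrame. The pure-pandas sketch below imitates that behavior outside Spark; the group values mirror the doctest data used elsewhere in this series, and `one_row_summary` is an illustrative name, not part of the patch.

    import pandas as pd

    def one_row_summary(pdf):
        # A user-style function: one output row per input group.
        return pd.DataFrame({"id": [pdf["id"].iloc[0]], "mean_v": [pdf["v"].mean()]})

    groups = [pd.DataFrame({"id": [1, 1], "v": [1.0, 2.0]}),
              pd.DataFrame({"id": [2, 2, 2], "v": [3.0, 5.0, 10.0]})]

    # The executor-side combine amounts to concatenating the per-group results.
    combined = pd.concat([one_row_summary(g) for g in groups], ignore_index=True)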
+ * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]] * * Rows in each group are passed to the python worker as a Arrow record batch. * The python worker turns the record batch to a pandas.DataFrame, invoke the From b0410a25f710029e93caf69d9037c843e63f0c41 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 9 Oct 2017 23:19:00 -0400 Subject: [PATCH 27/34] Fix use-defined -> user-defined --- .../scala/org/apache/spark/sql/RelationalGroupedDataset.scala | 2 +- .../spark/sql/execution/python/FlatMapGroupsInPandasExec.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index e996f5a4a53e..76e177e3d232 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -437,7 +437,7 @@ class RelationalGroupedDataset protected[sql]( } /** - * Applies a vectorized python use-defined function to each group of data. + * Applies a vectorized python user-defined function to each group of data. * The user-defined function defines a transformation: `Pandas.DataFrame` -> `Pandas.DataFrame`. * For each group, all elements in the group are passed as a `Pandas.DataFrame` and the results * for all groups are combined into a new `DataFrame`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index db50f4cec3b9..ba19ea96ef9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.StructType * * Rows in each group are passed to the python worker as a Arrow record batch. * The python worker turns the record batch to a pandas.DataFrame, invoke the - * use-defined function, and passes the resulting pandas.DataFrame + * user-defined function, and passes the resulting pandas.DataFrame * as a Arrow record batch. Finally, each record batch is turned to * Iterator[InternalRow] using ColumnarBatch. 
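In pyarrow terms, the per-group round trip described above boils down to roughly the following. This is a conceptual sketch of the data flow only, not the actual worker code; `run_group` and its arguments are made-up names.

    import pyarrow as pa

    def run_group(batch, user_func):
        # Arrow record batch -> pandas.DataFrame holding one whole group
        pdf = batch.to_pandas()
        # Invoke the user-defined function on the group
        result = user_func(pdf)
        # pandas.DataFrame -> Arrow record batch for the trip back to the JVM
        return pa.RecordBatch.from_pandas(result, preserve_index=False)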
* From 87edfccda2155e61fc8621573e1d861b256c6e07 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 6 Oct 2017 16:27:28 -0700 Subject: [PATCH 28/34] changed wrapping to be in one place --- python/pyspark/sql/group.py | 18 +++++++++-- python/pyspark/worker.py | 61 +++++++++++++------------------------ 2 files changed, 36 insertions(+), 43 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index fc7e3937ca06..fae8e9063e06 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -235,6 +235,7 @@ def apply(self, udf): """ from pyspark.sql.functions import pandas_udf + from pyspark.sql.types import to_arrow_type # Columns are special because hasattr always return True if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized: @@ -246,14 +247,25 @@ def apply(self, udf): func = udf.func returnType = udf.returnType - # The python executors expects the function to take a list of pd.Series as input + # The python executors expects the function to use pd.Series as input and output # So we to create a wrapper function that turns that to a pd.DataFrame before passing - # down to the user function + # down to the user function, then turn the result pd.DataFrame back into pd.Series columns = df.columns + arrow_return_types = [to_arrow_type(field.dataType) for field in returnType] def wrapped(*cols): import pandas as pd - return func(pd.concat(cols, axis=1, keys=columns)) + result = func(pd.concat(cols, axis=1, keys=columns)) + if not isinstance(result, pd.DataFrame): + raise TypeError("Return type of the user-defined function should be " + "Pandas.DataFrame, but is {}".format(type(result))) + if not len(result.columns) == len(arrow_return_types): + raise RuntimeError( + "Number of columns of the returned Pandas.DataFrame " + "doesn't match specified schema. " + "Expected: {} Actual: {}".format(len(arrow_return_types), len(result.columns))) + return [(result[result.columns[i]], arrow_return_types[i]) + for i in range(len(arrow_return_types))] wrapped_udf_obj = pandas_udf(wrapped, returnType) udf_column = wrapped_udf_obj(*[df[col] for col in df.columns]) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 2601e987593c..b2c2d07f53d9 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -74,47 +74,19 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): - # If the return_type is a StructType, it indicates this is a groupby apply udf, - # otherwise, it's a vectorized column udf. - # We can distinguish these two by return type because in groupby apply, we always specify - # returnType as a StructType, and in vectorized column udf, StructType is not supported. - # - # TODO: This logic is a bit hacky and might not work for future pandas udfs. Need refactoring. - if isinstance(return_type, StructType): - arrow_return_types = [to_arrow_type(field.dataType) for field in return_type] - - # Verify the return type and number of columns in result - def verify_result_type(*a): - import pandas as pd - result = f(*a) - if not isinstance(result, pd.DataFrame): - raise TypeError("Return type of the user-defined function should be " - "Pandas.DataFrame, but is {}".format(type(result))) - if not len(result.columns) == len(arrow_return_types): - raise RuntimeError( - "Number of columns of the returned Pandas.DataFrame " - "doesn't match specified schema. 
" - "Expected: {} Actual: {}".format(len(arrow_return_types), len(result.columns))) - - return [(result[result.columns[i]], arrow_return_types[i]) - for i in range(len(arrow_return_types))] - - return verify_result_type + arrow_return_type = to_arrow_type(return_type) - else: - arrow_return_type = to_arrow_type(return_type) - - def verify_result_length(*a): - result = f(*a) - if not hasattr(result, "__len__"): - raise TypeError("Return type of the user-defined functon should be " - "Pandas.Series, but is {}".format(type(result))) - if len(result) != len(a[0]): - raise RuntimeError("Result vector from pandas_udf was not the required length: " - "expected %d, got %d" % (len(a[0]), len(result))) - return result + def verify_result_length(*a): + result = f(*a) + if not hasattr(result, "__len__"): + raise TypeError("Return type of the user-defined functon should be " + "Pandas.Series, but is {}".format(type(result))) + if len(result) != len(a[0]): + raise RuntimeError("Result vector from pandas_udf was not the required length: " + "expected %d, got %d" % (len(a[0]), len(result))) + return result - return lambda *a: (verify_result_length(*a), arrow_return_type) + return lambda *a: (verify_result_length(*a), arrow_return_type) def read_single_udf(pickleSer, infile, eval_type): @@ -129,7 +101,16 @@ def read_single_udf(pickleSer, infile, eval_type): row_func = chain(row_func, f) # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_PANDAS_UDF: - return arg_offsets, wrap_pandas_udf(row_func, return_type) + # If the return_type is a StructType, it indicates this is a groupby apply udf, + # and has already been wrapped under apply(), otherwise, it's a vectorized column udf. + # We can distinguish these two by return type because in groupby apply, we always specify + # returnType as a StructType, and in vectorized column udf, StructType is not supported. + # + # TODO: This logic is a bit hacky and might not work for future pandas udfs. Need refactoring. + if isinstance(return_type, StructType): + return arg_offsets, row_func + else: + return arg_offsets, wrap_pandas_udf(row_func, return_type) else: return arg_offsets, wrap_udf(row_func, return_type) From 4413ed439795b1a192a81cf157b99de251c33c40 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 9 Oct 2017 16:33:03 -0700 Subject: [PATCH 29/34] changed to pickle spark type instead of arrow types for wrapped func --- python/pyspark/sql/group.py | 10 ++++----- python/pyspark/worker.py | 43 ++++++++++++++++++------------------- 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index fae8e9063e06..1d2c12889d1c 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -251,7 +251,6 @@ def apply(self, udf): # So we to create a wrapper function that turns that to a pd.DataFrame before passing # down to the user function, then turn the result pd.DataFrame back into pd.Series columns = df.columns - arrow_return_types = [to_arrow_type(field.dataType) for field in returnType] def wrapped(*cols): import pandas as pd @@ -259,13 +258,14 @@ def wrapped(*cols): if not isinstance(result, pd.DataFrame): raise TypeError("Return type of the user-defined function should be " "Pandas.DataFrame, but is {}".format(type(result))) - if not len(result.columns) == len(arrow_return_types): + if not len(result.columns) == len(returnType): raise RuntimeError( "Number of columns of the returned Pandas.DataFrame " "doesn't match specified schema. 
" - "Expected: {} Actual: {}".format(len(arrow_return_types), len(result.columns))) - return [(result[result.columns[i]], arrow_return_types[i]) - for i in range(len(arrow_return_types))] + "Expected: {} Actual: {}".format(len(returnType), len(result.columns))) + arrow_return_types = (to_arrow_type(field.dataType) for field in returnType) + return [(result[result.columns[i]], arrow_type) + for i, arrow_type in enumerate(arrow_return_types)] wrapped_udf_obj = pandas_udf(wrapped, returnType) udf_column = wrapped_udf_obj(*[df[col] for col in df.columns]) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b2c2d07f53d9..ac53184e94a4 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -74,20 +74,28 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): - arrow_return_type = to_arrow_type(return_type) - - def verify_result_length(*a): - result = f(*a) - if not hasattr(result, "__len__"): - raise TypeError("Return type of the user-defined functon should be " - "Pandas.Series, but is {}".format(type(result))) - if len(result) != len(a[0]): - raise RuntimeError("Result vector from pandas_udf was not the required length: " - "expected %d, got %d" % (len(a[0]), len(result))) - return result + # If the return_type is a StructType, it indicates this is a groupby apply udf, + # and has already been wrapped under apply(), otherwise, it's a vectorized column udf. + # We can distinguish these two by return type because in groupby apply, we always specify + # returnType as a StructType, and in vectorized column udf, StructType is not supported. + # + # TODO: Look into refactoring use of StructType to be more flexible for future pandas_udfs + if isinstance(return_type, StructType): + return lambda *a: f(*a) + else: + arrow_return_type = to_arrow_type(return_type) - return lambda *a: (verify_result_length(*a), arrow_return_type) + def verify_result_length(*a): + result = f(*a) + if not hasattr(result, "__len__"): + raise TypeError("Return type of the user-defined functon should be " + "Pandas.Series, but is {}".format(type(result))) + if len(result) != len(a[0]): + raise RuntimeError("Result vector from pandas_udf was not the required length: " + "expected %d, got %d" % (len(a[0]), len(result))) + return result + return lambda *a: (verify_result_length(*a), arrow_return_type) def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) @@ -101,16 +109,7 @@ def read_single_udf(pickleSer, infile, eval_type): row_func = chain(row_func, f) # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_PANDAS_UDF: - # If the return_type is a StructType, it indicates this is a groupby apply udf, - # and has already been wrapped under apply(), otherwise, it's a vectorized column udf. - # We can distinguish these two by return type because in groupby apply, we always specify - # returnType as a StructType, and in vectorized column udf, StructType is not supported. - # - # TODO: This logic is a bit hacky and might not work for future pandas udfs. Need refactoring. 
- if isinstance(return_type, StructType): - return arg_offsets, row_func - else: - return arg_offsets, wrap_pandas_udf(row_func, return_type) + return arg_offsets, wrap_pandas_udf(row_func, return_type) else: return arg_offsets, wrap_udf(row_func, return_type) From a064b21b23d2c3dee9993c3b07d771fa8c09b8ba Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 9 Oct 2017 16:44:09 -0700 Subject: [PATCH 30/34] move import --- python/pyspark/sql/group.py | 2 +- python/pyspark/worker.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 1d2c12889d1c..ac42b752b8a2 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -235,7 +235,6 @@ def apply(self, udf): """ from pyspark.sql.functions import pandas_udf - from pyspark.sql.types import to_arrow_type # Columns are special because hasattr always return True if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized: @@ -253,6 +252,7 @@ def apply(self, udf): columns = df.columns def wrapped(*cols): + from pyspark.sql.types import to_arrow_type import pandas as pd result = func(pd.concat(cols, axis=1, keys=columns)) if not isinstance(result, pd.DataFrame): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index ac53184e94a4..eb6d48688dc0 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -97,6 +97,7 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] From a036f70f89b6fadbbc3b8d80feecf2c086e32d18 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 10 Oct 2017 10:39:38 -0400 Subject: [PATCH 31/34] Address CR comments --- python/pyspark/sql/functions.py | 6 +++++- python/pyspark/sql/tests.py | 10 +++++----- .../spark/sql/catalyst/plans/logical/object.scala | 3 +-- .../plans/logical/pythonLogicalOperators.scala | 6 +----- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 744b7b0f503b..1ffb059b6d92 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2222,7 +2222,7 @@ def pandas_udf(f=None, returnType=StringType()): 2. A `pandas.DataFrame` -> A `pandas.DataFrame` - This udf is used with :meth:`pyspark.sql.GroupedData.apply`. + This udf is only used with :meth:`pyspark.sql.GroupedData.apply`. The returnType should be a :class:`StructType` describing the schema of the returned `pandas.DataFrame`. @@ -2244,6 +2244,10 @@ def pandas_udf(f=None, returnType=StringType()): | 2| 1.1094003924504583| +---+-------------------+ + .. note:: This type of udf cannot be used with functions such as `withColumn` or `select` + because it defines a `DataFrame` transformation rather than `Column` + transformation. + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` .. note:: The user-defined function must be deterministic. 
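A side-by-side sketch of the two flavors may make the note above concrete; a DataFrame `df` with columns `id` (long) and `v` (double) is assumed for illustration and is not part of the patch.

    from pyspark.sql.functions import pandas_udf, col
    from pyspark.sql.types import LongType, DoubleType, StructType, StructField

    # 1. Series -> Series: a column transformation, usable with select/withColumn.
    double_id = pandas_udf(lambda s: s * 2, LongType())
    df.select(double_id(col("id")))

    # 2. DataFrame -> DataFrame: a DataFrame transformation, usable only with
    #    groupby().apply().
    @pandas_udf(StructType([StructField("id", LongType()), StructField("v", DoubleType())]))
    def demean(pdf):
        return pdf.assign(v=pdf.v - pdf.v.mean())

    df.groupby("id").apply(demean)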
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 70a75371c0da..9d1b429a5b7f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3414,7 +3414,7 @@ def test_simple(self): df = self.data foo_udf = pandas_udf( - lambda df: df.assign(v1=df.v * df.id * 1.0, v2=df.v + df.id), + lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), StructType( [StructField('id', LongType()), StructField('v', IntegerType()), @@ -3434,8 +3434,8 @@ def test_decorator(self): StructField('v', IntegerType()), StructField('v1', DoubleType()), StructField('v2', LongType())])) - def foo(df): - return df.assign(v1=df.v * df.id * 1.0, v2=df.v + df.id) + def foo(pdf): + return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) @@ -3446,7 +3446,7 @@ def test_coerce(self): df = self.data foo = pandas_udf( - lambda df: df, + lambda pdf: pdf, StructType([StructField('id', LongType()), StructField('v', DoubleType())])) result = df.groupby('id').apply(foo).sort('id').toPandas() @@ -3497,7 +3497,7 @@ def test_wrong_return_type(self): df = self.data foo = pandas_udf( - lambda df: df, + lambda pdf: pdf, StructType([StructField('id', LongType()), StructField('v', StringType())])) with QuietTest(self.sc): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 2535b80264eb..bfb70c2ef4c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode } import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -519,4 +519,3 @@ case class CoGroup( outputObjAttr: Attribute, left: LogicalPlan, right: LogicalPlan) extends BinaryNode with ObjectProducer - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index cf30a5dbacbc..8abab24bc9b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -20,11 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} /** - * Logical nodes specific to PySpark. - */ - -/** - * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame. + * FlatMap groups using an udf: pandas.Dataframe -> pandas.DataFrame. * This is used by DataFrame.groupby().apply(). 
*/ case class FlatMapGroupsInPandas( From b88a4d8fd2c805b883a88eb24100e360c198726a Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 10 Oct 2017 10:43:07 -0400 Subject: [PATCH 32/34] Minor: Fix typo in doc --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1ffb059b6d92..9bc12c3b7a16 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2245,7 +2245,7 @@ def pandas_udf(f=None, returnType=StringType()): +---+-------------------+ .. note:: This type of udf cannot be used with functions such as `withColumn` or `select` - because it defines a `DataFrame` transformation rather than `Column` + because it defines a `DataFrame` transformation rather than a `Column` transformation. .. seealso:: :meth:`pyspark.sql.GroupedData.apply` From 9c2b10e16da4690afaa72599299346a62a0da668 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 10 Oct 2017 11:11:25 -0400 Subject: [PATCH 33/34] Clean up imports in ExtractPythonUDFs.scala --- .../apache/spark/sql/execution/python/ExtractPythonUDFs.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 8dee59c3d61d..e3f952e221d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} @@ -172,7 +171,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val newPlan = extract(rewritten) if (newPlan.output != plan.output) { // Trim away the new UDF value if it was only used for filtering or something. - execution.ProjectExec(plan.output, newPlan) + ProjectExec(plan.output, newPlan) } else { newPlan } From dc1d4069cad71568017b39a1a675a71b7ca3b5ae Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 10 Oct 2017 14:32:31 -0400 Subject: [PATCH 34/34] Address comments about docs --- python/pyspark/sql/group.py | 9 +++++---- .../spark/sql/RelationalGroupedDataset.scala | 11 ++++++----- .../python/FlatMapGroupsInPandasExec.scala | 18 +++++++++--------- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ac42b752b8a2..817d0bc83bb7 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -203,14 +203,15 @@ def apply(self, udf): The user-defined function should take a `pandas.DataFrame` and return another `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` - to the user-function and the returned `pandas.DataFrame` are combined as a `DataFrame`. - The returned `pandas.DataFrame` can be arbitrary length and its schema must match the + to the user-function and the returned `pandas.DataFrame`s are combined as a + :class:`DataFrame`. + The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the returnType of the pandas udf. 
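To illustrate the arbitrary-length point, the user function may return fewer rows than it receives. A rough sketch, again assuming a DataFrame `df` with columns `id` (long) and `v` (double); the top-two selection is purely illustrative.

    from pyspark.sql.functions import pandas_udf
    from pyspark.sql.types import StructType, StructField, LongType, DoubleType

    schema = StructType([StructField("id", LongType()), StructField("v", DoubleType())])

    @pandas_udf(schema)
    def top_two(pdf):
        # Keep only the two largest values per group; the output has a
        # different length than the input, which apply() permits.
        return pdf.nlargest(2, "v")

    df.groupby("id").apply(top_two)

Note that the wrapper pairs the columns of the returned `pandas.DataFrame` with the returnType fields by position rather than by name, so the result columns should be produced in the same order as the schema.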
This function does not support partial aggregation, and requires shuffling all the data in - the `DataFrame`. + the :class:`DataFrame`. - :param udf: A wrapped udf function returned by :meth:`pyspark.sql.functions.pandas_udf` + :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` >>> from pyspark.sql.functions import pandas_udf >>> df = spark.createDataFrame( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 76e177e3d232..cd0ac1feffa5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -438,14 +438,15 @@ class RelationalGroupedDataset protected[sql]( /** * Applies a vectorized python user-defined function to each group of data. - * The user-defined function defines a transformation: `Pandas.DataFrame` -> `Pandas.DataFrame`. - * For each group, all elements in the group are passed as a `Pandas.DataFrame` and the results - * for all groups are combined into a new `DataFrame`. + * The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`. + * For each group, all elements in the group are passed as a `pandas.DataFrame` and the results + * for all groups are combined into a new [[DataFrame]]. * * This function does not support partial aggregation, and requires shuffling all the data in - * the `DataFrame`. + * the [[DataFrame]]. * - * This function uses `Arrow` as serialization format between JVM and python workers. + * This function uses Apache Arrow as serialization format between Java executors and Python + * workers. */ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { require(expr.vectorized, "Must pass a vectorized python udf") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index ba19ea96ef9d..b996b5bb38ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -31,18 +31,18 @@ import org.apache.spark.sql.types.StructType /** * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]] * - * Rows in each group are passed to the python worker as a Arrow record batch. - * The python worker turns the record batch to a pandas.DataFrame, invoke the - * user-defined function, and passes the resulting pandas.DataFrame - * as a Arrow record batch. Finally, each record batch is turned to + * Rows in each group are passed to the Python worker as an Arrow record batch. + * The Python worker turns the record batch to a `pandas.DataFrame`, invoke the + * user-defined function, and passes the resulting `pandas.DataFrame` + * as an Arrow record batch. Finally, each record batch is turned to * Iterator[InternalRow] using ColumnarBatch. * * Note on memory usage: - * Both the python worker and the java executor need to have enough memory to - * hold the largest group. The memory on the java side is used to construct the - * record batch (off heap memory). The memory on the python side is used for - * holding the pandas.DataFrame. 
It's possible to further split one group into - * multiple record batches to reduce the memory footprint on the java side, this + * Both the Python worker and the Java executor need to have enough memory to + * hold the largest group. The memory on the Java side is used to construct the + * record batch (off heap memory). The memory on the Python side is used for + * holding the `pandas.DataFrame`. It's possible to further split one group into + * multiple record batches to reduce the memory footprint on the Java side, this * is left as future work. */ case class FlatMapGroupsInPandasExec(
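Since an entire group is materialized on both the Java and the Python side, it can help to gauge the largest group before calling apply(); a minimal sketch, assuming a grouping column named `id`:

    # Show the size of the biggest group up front; very large groups may call for a
    # finer-grained grouping key (or more executor/worker memory).
    df.groupBy("id").count().orderBy("count", ascending=False).show(1)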