diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 5197a9e004610..e904b2948bdca 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -245,7 +245,8 @@ def _inferSchema(self, rdd, samplingRatio=None):
 
     @since(1.3)
     @ignore_unicode_prefix
-    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True,
+                        numSlices=None):
         """
         Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
 
@@ -276,6 +277,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
             We can also use ``int`` as a short name for :class:`pyspark.sql.types.IntegerType`.
         :param samplingRatio: the sample ratio of rows used for inferring
         :param verifySchema: verify data types of every row against schema.
+        :param numSlices: specify as :class:`int` the number of slices (partitions) to distribute
+            ``data`` across. Applies to ``data`` of :class:`list` or :class:`pandas.DataFrame`.
+            Defaults to `self.sparkContext.defaultParallelism`.
         :return: :class:`DataFrame`
 
         .. versionchanged:: 2.0
@@ -334,7 +338,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
             ...
         Py4JJavaError: ...
         """
-        return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema)
+        return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema,
+                                                 numSlices)
 
     @since(1.3)
     def registerDataFrameAsTable(self, df, tableName):
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index c1bf2bd76fb7c..938453457727b 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -388,7 +388,7 @@ def _createFromRDD(self, rdd, schema, samplingRatio):
             rdd = rdd.map(schema.toInternal)
         return rdd, schema
 
-    def _createFromLocal(self, data, schema):
+    def _createFromLocal(self, data, schema, numSlices=None):
         """
         Create an RDD for DataFrame from a list or pandas.DataFrame, returns
         the RDD and schema.
@@ -412,11 +412,12 @@ def _createFromLocal(self, data, schema):
 
         # convert python objects to sql data
         data = [schema.toInternal(row) for row in data]
-        return self._sc.parallelize(data), schema
+        return self._sc.parallelize(data, numSlices=numSlices), schema
 
     @since(2.0)
     @ignore_unicode_prefix
-    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True,
+                        numSlices=None):
         """
         Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
 
@@ -446,6 +447,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
             ``int`` as a short name for ``IntegerType``.
         :param samplingRatio: the sample ratio of rows used for inferring
         :param verifySchema: verify data types of every row against schema.
+        :param numSlices: specify as :class:`int` the number of slices (partitions) to distribute
+            ``data`` across. Applies to ``data`` of :class:`list` or :class:`pandas.DataFrame`.
+            Defaults to `self.sparkContext.defaultParallelism`.
         :return: :class:`DataFrame`
 
         .. versionchanged:: 2.1
@@ -534,7 +538,8 @@ def prepare(obj):
         if isinstance(data, RDD):
            rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
         else:
-            rdd, schema = self._createFromLocal(map(prepare, data), schema)
+            rdd, schema = self._createFromLocal(map(prepare, data), schema, numSlices=numSlices)
+
         jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
         jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
         df = DataFrame(jdf, self._wrapped)
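
For reference, a minimal usage sketch of the patched API (hypothetical: it assumes this diff is applied; the local-mode session setup and the choice of 8 slices are only illustrative):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[4]").getOrCreate()

    # Without numSlices, the local list would be parallelized into
    # spark.sparkContext.defaultParallelism partitions; here we ask for 8 explicitly.
    df = spark.createDataFrame([(i, str(i)) for i in range(100)], ["a", "b"], numSlices=8)
    print(df.rdd.getNumPartitions())  # expected: 8

Since numSlices is only threaded through to SparkContext.parallelize in _createFromLocal, it affects list and pandas.DataFrame inputs; for RDD inputs the existing partitioning is kept.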