9 changes: 7 additions & 2 deletions python/pyspark/sql/context.py
@@ -245,7 +245,8 @@ def _inferSchema(self, rdd, samplingRatio=None):
 
     @since(1.3)
     @ignore_unicode_prefix
-    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True,
+                        numSlices=None):
         """
         Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
 
@@ -276,6 +277,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
             We can also use ``int`` as a short name for :class:`pyspark.sql.types.IntegerType`.
         :param samplingRatio: the sample ratio of rows used for inferring
         :param verifySchema: verify data types of every row against schema.
+        :param numSlices: the number of slices (partitions), as an :class:`int`, to distribute
+            ``data`` across; only applies when ``data`` is a :class:`list` or a
+            :class:`pandas.DataFrame`. Defaults to ``self.sparkContext.defaultParallelism``.
         :return: :class:`DataFrame`
 
         .. versionchanged:: 2.0
@@ -334,7 +338,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
         ...
         Py4JJavaError: ...
         """
-        return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema)
+        return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema,
+                                                 numSlices)
 
     @since(1.3)
     def registerDataFrameAsTable(self, df, tableName):
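
With the patch applied, callers can choose the partition count of a locally built DataFrame directly instead of always inheriting the default. A minimal usage sketch (the `spark` session and the `local[4]` master are assumptions for illustration, not part of this change):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[4]").getOrCreate()
data = [(i,) for i in range(100)]

# With numSlices omitted, local data is split into
# spark.sparkContext.defaultParallelism partitions, typically 4 on local[4].
df_default = spark.createDataFrame(data, ["id"])
print(df_default.rdd.getNumPartitions())

# With the new keyword, the caller picks the partition count explicitly.
df_ten = spark.createDataFrame(data, ["id"], numSlices=10)
print(df_ten.rdd.getNumPartitions())  # 10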
13 changes: 9 additions & 4 deletions python/pyspark/sql/session.py
@@ -388,7 +388,7 @@ def _createFromRDD(self, rdd, schema, samplingRatio):
         rdd = rdd.map(schema.toInternal)
         return rdd, schema
 
-    def _createFromLocal(self, data, schema):
+    def _createFromLocal(self, data, schema, numSlices=None):
         """
         Create an RDD for DataFrame from a list or pandas.DataFrame, returns
         the RDD and schema.
@@ -412,11 +412,12 @@ def _createFromLocal(self, data, schema):
 
         # convert python objects to sql data
         data = [schema.toInternal(row) for row in data]
-        return self._sc.parallelize(data), schema
+        return self._sc.parallelize(data, numSlices=numSlices), schema
 
     @since(2.0)
     @ignore_unicode_prefix
-    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True,
+                        numSlices=None):
         """
         Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
 
@@ -446,6 +447,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
             ``int`` as a short name for ``IntegerType``.
         :param samplingRatio: the sample ratio of rows used for inferring
         :param verifySchema: verify data types of every row against schema.
+        :param numSlices: the number of slices (partitions), as an :class:`int`, to distribute
+            ``data`` across; only applies when ``data`` is a :class:`list` or a
+            :class:`pandas.DataFrame`. Defaults to ``self.sparkContext.defaultParallelism``.
         :return: :class:`DataFrame`
 
         .. versionchanged:: 2.1
@@ -534,7 +538,8 @@ def prepare(obj):
         if isinstance(data, RDD):
             rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
         else:
-            rdd, schema = self._createFromLocal(map(prepare, data), schema)
+            rdd, schema = self._createFromLocal(map(prepare, data), schema, numSlices=numSlices)
+
         jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
         jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
         df = DataFrame(jdf, self._wrapped)
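
The session.py change simply threads the argument through to SparkContext.parallelize, which already accepts numSlices. A rough sketch of the equivalence, with the schema conversion elided and the `spark` session assumed as above:

rows = [(1, "a"), (2, "b"), (3, "c")]

# Roughly what _createFromLocal now does under the hood for a plain Python list:
rdd = spark.sparkContext.parallelize(rows, numSlices=2)
print(rdd.getNumPartitions())  # 2

# The RDD input path is untouched by this patch: when data is already an RDD,
# numSlices is not used, since the partitioning is fixed by the existing RDD.
df = spark.createDataFrame(rdd, ["id", "value"])
print(df.rdd.getNumPartitions())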