Closed
72 changes: 67 additions & 5 deletions python/pyspark/sql/readwriter.py
@@ -20,6 +20,7 @@
if sys.version >= '3':
basestring = unicode = str

import logging
from py4j.java_gateway import JavaClass

from pyspark import RDD, since, keyword_only
@@ -370,7 +371,7 @@ def orc(self, path):

>>> df = spark.read.orc('python/test_support/sql/orc_partitioned')
>>> df.dtypes
[('a', 'bigint'), ('b', 'int'), ('c', 'int')]
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
return self._df(self._jreader.orc(path))

@@ -501,6 +502,46 @@ def partitionBy(self, *cols):
self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
return self

@since(2.1)
def bucketBy(self, numBuckets, *cols):
"""Buckets the output by the given columns on the file system.

:param numBuckets: the number of buckets to save
:param cols: name of columns

>>> (df.write.format('parquet')
... .bucketBy(100, 'year', 'month')
... .saveAsTable('bucketed_data'))
"""
if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
cols = cols[0]

col = cols[0]
cols = cols[1:]

self._jwrite = self._jwrite.bucketBy(numBuckets, col, _to_seq(self._spark._sc, cols))
return self

@since(2.1)
def sortBy(self, *cols):
"""Sorts the output in each bucket by the given columns on the file system.

:param cols: name of columns

>>> (df.write.format('parquet')
... .bucketBy(100, 'year', 'month')
... .sortBy('day')
... .saveAsTable('sorted_data'))
"""
if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
cols = cols[0]

col = cols[0]
cols = cols[1:]

self._jwrite = self._jwrite.sortBy(col, _to_seq(self._spark._sc, cols))
return self
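
A hedged usage sketch (not part of this diff) of how the new bucketBy and sortBy writers chain together; the table and column names are illustrative, following the doctests above, and the bucketed output is written with saveAsTable as in those doctests:

    # Sketch only: assumes a DataFrame `df` with 'year', 'month' and 'day' columns,
    # as in the doctests above; 'bucketed_sorted_data' is a hypothetical table name.
    (df.write
        .format('parquet')
        .bucketBy(100, 'year', 'month')  # hash rows into 100 buckets by (year, month)
        .sortBy('day')                   # sort rows within each bucket by 'day'
        .mode('overwrite')
        .saveAsTable('bucketed_sorted_data'))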

@since(1.4)
def save(self, path=None, format=None, mode=None, partitionBy=None, **options):
"""Saves the contents of the :class:`DataFrame` to a data source.
@@ -562,6 +603,8 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options):
:param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
:param partitionBy: names of partitioning columns
:param options: all other string options

>>> df.write.saveAsTable('my_table')
"""
self.mode(mode).options(**options)
if partitionBy is not None:
@@ -693,8 +736,7 @@ def orc(self, path, mode=None, partitionBy=None, compression=None):
This will override ``orc.compress``. If None is set, it uses the
default value, ``snappy``.

>>> orc_df = spark.read.orc('python/test_support/sql/orc_partitioned')
>>> orc_df.write.orc(os.path.join(tempfile.mkdtemp(), 'data'))
>>> df.write.orc(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode)
if partitionBy is not None:
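
A hedged sketch (not in the diff) of calling the ORC writer documented above with an explicit codec; the output directory and the 'zlib' choice are illustrative assumptions:

    # Sketch only: write the doctest DataFrame as ORC into a throwaway temp directory,
    # partitioned by 'year' and compressed with zlib instead of the default snappy.
    # Assumes os and tempfile are available, as in the doctest environment.
    out_dir = os.path.join(tempfile.mkdtemp(), 'orc_out')
    df.write.orc(out_dir, mode='overwrite', partitionBy='year', compression='zlib')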
@@ -734,11 +776,22 @@ def _test():
import os
import tempfile
import py4j
import shutil
from random import Random
from time import time
from pyspark.context import SparkContext
from pyspark.sql import SparkSession, Row
import pyspark.sql.readwriter

os.chdir(os.environ["SPARK_HOME"])
spark_home = os.path.realpath(os.environ["SPARK_HOME"])

test_dir = tempfile.mkdtemp()
os.chdir(test_dir)

path = lambda x, y, z: os.path.join(x, y, z)

shutil.copytree(path(spark_home, 'python', 'test_support'),
path(test_dir, 'python', 'test_support'))

globs = pyspark.sql.readwriter.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
@@ -747,16 +800,25 @@ def _test():
except py4j.protocol.Py4JError:
spark = SparkSession(sc)

seed = int(time() * 1000)
Review thread on this line:

Contributor: It's better to have a deterministic test; testing with parquet should be enough.

Author: I have been really busy with work of late, but I will try to sort this out today.

Member: @GregBowyer Any progress on this? :)

Member: @GregBowyer ping. Let me propose to close this after a week.

HyukjinKwon (Member), Feb 27, 2017: cc @zero323, would you maybe be interested in taking this over? I was thinking of taking it over if no one goes for it, assuming it looks quite close to being merged.

Member: @HyukjinKwon By all means. I prepared a bunch of tests (7d911c647f21ada7fb429fd7c1c5f15934ff8847) and extended a bit the code provided by @GregBowyer (72c04a3f196da5223ebb44725aa88cffa81036e4). I think we can skip the low-level tests (direct access to the files), which are already present in the Scala test base.

Member: @zero323, Good to know. Then, please go ahead if you are ready :).

rng = Random(seed)

base_df_format = rng.choice(('orc', 'parquet'))
loader = getattr(spark.read, base_df_format)
path = os.path.join(test_dir, 'python/test_support/sql/%s_partitioned' % base_df_format)
df = loader(path)

globs['tempfile'] = tempfile
globs['os'] = os
globs['sc'] = sc
globs['spark'] = spark
globs['df'] = spark.read.parquet('python/test_support/sql/parquet_partitioned')
globs['df'] = df
(failure_count, test_count) = doctest.testmod(
pyspark.sql.readwriter, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
sc.stop()
if failure_count:
logging.error('Random seed for test: %d', seed)
exit(-1)
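
A hedged sketch of the deterministic alternative the first review comment asks for (always load the parquet fixture instead of picking a format at random); the variable names mirror the diff, but the snippet itself is illustrative:

    # Sketch only: fix the doctest DataFrame to the parquet fixture so runs are reproducible.
    df = spark.read.parquet(
        os.path.join(test_dir, 'python', 'test_support', 'sql', 'parquet_partitioned'))
    globs['df'] = df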


[Numerous binary test-support files changed; contents not shown.]
Empty file modified: python/test_support/sql/orc_partitioned/_SUCCESS (100755 → 100644)
Binary file removed: python/test_support/sql/parquet_partitioned/_metadata