@@ -20,6 +20,7 @@
if sys.version >= '3':
    basestring = unicode = str

+import logging
from py4j.java_gateway import JavaClass

from pyspark import RDD, since, keyword_only
@@ -370,7 +371,7 @@ def orc(self, path):

        >>> df = spark.read.orc('python/test_support/sql/orc_partitioned')
        >>> df.dtypes
-        [('a', 'bigint'), ('b', 'int'), ('c', 'int')]
+        [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
        """
        return self._df(self._jreader.orc(path))

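A note on the new expected dtypes: in a partitioned layout, the partition columns ('year', 'month', 'day' in the updated fixture) come from directory names such as year=2017/month=1/day=5 and are appended after the columns stored in the files. A minimal sketch of reading such a dataset back, with the path assumed rather than taken from this patch:

# Sketch: read a partitioned ORC dataset and inspect its dtypes.
# '/tmp/orc_partitioned' is an assumed path; the partition columns are
# inferred from directory names, not stored inside the ORC files.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[1]').appName('orc_read_sketch').getOrCreate()
df = spark.read.orc('/tmp/orc_partitioned')
print(df.dtypes)  # file columns first, then the inferred partition columns
spark.stop()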
@@ -510,7 +511,7 @@ def bucketBy(self, numBuckets, *cols):

        >>> (df.write.format('parquet')
        ...     .bucketBy(100, 'year', 'month')
-        ...     .saveAsTable(os.path.join(tempfile.mkdtemp(), 'bucketed_table')))
+        ...     .saveAsTable('bucketed_data'))
        """
        if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
            cols = cols[0]
@@ -530,7 +531,7 @@ def sortBy(self, *cols):
        >>> (df.write.format('parquet')
        ...     .bucketBy(100, 'year', 'month')
        ...     .sortBy('day')
-        ...     .saveAsTable(os.path.join(tempfile.mkdtemp(), 'sorted_bucketed_table')))
+        ...     .saveAsTable('sorted_data'))
        """
        if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
            cols = cols[0]
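The switch from a filesystem path to a bare table name in the two doctests above reflects how bucketing works: bucketBy and sortBy only take effect together with saveAsTable, which registers the result in the session catalog rather than writing to an arbitrary path. A sketch of the full pattern the doctests exercise (data, bucket count, and names are illustrative):

# Sketch of the bucketed-write pattern: hash rows into buckets by
# (year, month), sort each bucket by day, and register a catalog table.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[1]').appName('bucket_sketch').getOrCreate()
df = spark.createDataFrame([(2017, 1, 5), (2017, 2, 6)], ['year', 'month', 'day'])

(df.write.format('parquet')
    .bucketBy(4, 'year', 'month')   # 4 buckets keyed on (year, month)
    .sortBy('day')                  # rows sorted by day within each bucket
    .mode('overwrite')
    .saveAsTable('sorted_bucketed_sketch'))
spark.stop()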
@@ -602,6 +603,8 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options)
        :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
        :param partitionBy: names of partitioning columns
        :param options: all other string options
+
+        >>> df.write.saveAsTable('my_table')
        """
        self.mode(mode).options(**options)
        if partitionBy is not None:
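Since the docstring lists the four save modes, a short sketch of how they behave around the new doctest's table (the table name mirrors the doctest; AnalysisException is pyspark's real exception type for an existing table):

# Sketch of saveAsTable save modes; 'my_table' mirrors the doctest above.
from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException

spark = SparkSession.builder.master('local[1]').appName('save_modes_sketch').getOrCreate()
df = spark.createDataFrame([(1, 'a')], ['id', 'val'])

df.write.saveAsTable('my_table')                 # first write: table is created
df.write.mode('append').saveAsTable('my_table')  # 'append' adds rows instead of failing
try:
    df.write.saveAsTable('my_table')             # default mode 'error' raises if it exists
except AnalysisException as exc:
    print('default mode refused to overwrite:', exc)
spark.stop()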
@@ -733,8 +736,7 @@ def orc(self, path, mode=None, partitionBy=None, compression=None):
            This will override ``orc.compress``. If None is set, it uses the
            default value, ``snappy``.

-        >>> orc_df = spark.read.orc('python/test_support/sql/orc_partitioned')
-        >>> orc_df.write.orc(os.path.join(tempfile.mkdtemp(), 'data'))
+        >>> df.write.orc(os.path.join(tempfile.mkdtemp(), 'data'))
        """
        self.mode(mode)
        if partitionBy is not None:
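For the compression option documented above, a hedged sketch of passing an explicit codec (assumes a Spark build where the ORC source is available; path and data are illustrative):

# Sketch: write ORC with an explicit codec. Per the docstring, leaving
# compression=None falls back to 'snappy'; 'zlib' is used here only to
# show overriding orc.compress.
import os
import tempfile
from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[1]').appName('orc_write_sketch').getOrCreate()
df = spark.createDataFrame([(1, 'a')], ['id', 'val'])
df.write.orc(os.path.join(tempfile.mkdtemp(), 'data'), compression='zlib')
spark.stop()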
@@ -774,11 +776,22 @@ def _test():
    import os
    import tempfile
    import py4j
+    import shutil
+    from random import Random
+    from time import time
    from pyspark.context import SparkContext
    from pyspark.sql import SparkSession, Row
    import pyspark.sql.readwriter

-    os.chdir(os.environ["SPARK_HOME"])
+    spark_home = os.path.realpath(os.environ["SPARK_HOME"])
+
+    test_dir = tempfile.mkdtemp()
+    os.chdir(test_dir)
+
+    path = lambda x, y, z: os.path.join(x, y, z)  # join all three path components
+
+    shutil.copytree(path(spark_home, 'python', 'test_support'),
+                    path(test_dir, 'python', 'test_support'))

    globs = pyspark.sql.readwriter.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
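The chdir/copytree change above makes the doctest run hermetic: relative fixture paths in the examples resolve under a scratch copy of test_support instead of writing into the checkout. The same pattern in isolation (names are illustrative, not from this patch):

# Standalone sketch of the hermetic-fixture pattern used in _test():
# copy the fixture tree into a fresh temp dir and chdir there, so any
# files the doctests create never land in the source checkout.
import os
import shutil
import tempfile

source_root = os.path.realpath(os.environ['SPARK_HOME'])  # assumes SPARK_HOME is set
scratch = tempfile.mkdtemp()
shutil.copytree(os.path.join(source_root, 'python', 'test_support'),
                os.path.join(scratch, 'python', 'test_support'))
os.chdir(scratch)  # relative fixture paths now resolve under the scratch dir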
@@ -787,16 +800,25 @@ def _test():
    except py4j.protocol.Py4JError:
        spark = SparkSession(sc)

+    seed = int(time() * 1000)
+    rng = Random(seed)
+
+    base_df_format = rng.choice(('orc', 'parquet'))
+    loader = getattr(spark.read, base_df_format)
+    path = os.path.join(test_dir, 'python/test_support/sql/%s_partitioned' % base_df_format)
+    df = loader(path)
+
    globs['tempfile'] = tempfile
    globs['os'] = os
    globs['sc'] = sc
    globs['spark'] = spark
-    globs['df'] = spark.read.parquet('python/test_support/sql/parquet_partitioned')
+    globs['df'] = df
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.readwriter, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    sc.stop()
    if failure_count:
+        logging.error('Random seed for test: %d', seed)
        exit(-1)

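Logging the seed only on failure is what makes the random format choice reproducible after the fact. A hypothetical replay helper (the REPRO_SEED variable is illustrative, not part of this patch):

# Sketch: replay a failed run. Seeding Random with the value that
# logging.error printed reproduces the same base-DataFrame format choice.
import os
from random import Random

seed = int(os.environ['REPRO_SEED'])  # hypothetical: paste the logged seed here
rng = Random(seed)
print('base df format was: %s' % rng.choice(('orc', 'parquet')))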
802824