@@ -564,12 +564,13 @@ def partitionBy(self, *cols):
564564 return self
565565
566566 @since (2.3 )
567- def bucketBy (self , numBuckets , * cols ):
567+ def bucketBy (self , numBuckets , col , * cols ):
568568 """Buckets the output by the given columns.If specified,
569569 the output is laid out on the file system similar to Hive's bucketing scheme.
570570
571571 :param numBuckets: the number of buckets to save
572- :param cols: name of columns
572+ :param col: a name of a column, or a list of names.
573+ :param cols: additional names (optional). If `col` is a list it should be empty.
573574
574575 .. note:: Applicable for file-based data sources in combination with
575576 :py:meth:`DataFrameWriter.saveAsTable`.
@@ -579,41 +580,42 @@ def bucketBy(self, numBuckets, *cols):
579580 ... .mode("overwrite")
580581 ... .saveAsTable('bucketed_table'))
581582 """
582- if len (cols ) == 1 and isinstance (cols [0 ], (list , tuple )):
583- cols = cols [0 ]
584-
585583 if not isinstance (numBuckets , int ):
586584 raise TypeError ("numBuckets should be an int, got {0}." .format (type (numBuckets )))
587585
588- if not all (isinstance (c , basestring ) for c in cols ):
589- raise TypeError ("cols argument should be a string or a sequence of strings." )
586+ if isinstance (col , (list , tuple )):
587+ if cols :
588+ raise ValueError ("col is a {0} but cols are not empty" .format (type (col )))
590589
591- col = cols [0 ]
592- cols = cols [1 :]
590+ col , cols = col [0 ], col [1 :]
591+
592+ if not all (isinstance (c , basestring ) for c in cols ) or not (isinstance (col , basestring )):
593+ raise TypeError ("all names should be `str`" )
593594
594595 self ._jwrite = self ._jwrite .bucketBy (numBuckets , col , _to_seq (self ._spark ._sc , cols ))
595596 return self
596597
597598 @since (2.3 )
598- def sortBy (self , * cols ):
599+ def sortBy (self , col , * cols ):
599600 """Sorts the output in each bucket by the given columns on the file system.
600601
601- :param cols: name of columns
602+ :param col: a name of a column, or a list of names.
603+ :param cols: additional names (optional). If `col` is a list it should be empty.
602604
603605 >>> (df.write.format('parquet')
604606 ... .bucketBy(100, 'year', 'month')
605607 ... .sortBy('day')
606608 ... .mode("overwrite")
607609 ... .saveAsTable('sorted_bucketed_table'))
608610 """
609- if len (cols ) == 1 and isinstance (cols [0 ], (list , tuple )):
610- cols = cols [0 ]
611+ if isinstance (col , (list , tuple )):
612+ if cols :
613+ raise ValueError ("col is a {0} but cols are not empty" .format (type (col )))
611614
612- if not all (isinstance (c , basestring ) for c in cols ):
613- raise TypeError ("cols argument should be a string or a sequence of strings." )
615+ col , cols = col [0 ], col [1 :]
614616
615- col = cols [ 0 ]
616- cols = cols [ 1 :]
617+ if not all ( isinstance ( c , basestring ) for c in cols ) or not ( isinstance ( col , basestring )):
618+ raise TypeError ( "all names should be `str`" )
617619
618620 self ._jwrite = self ._jwrite .sortBy (col , _to_seq (self ._spark ._sc , cols ))
619621 return self
0 commit comments