import random
import os
from tempfile import NamedTemporaryFile
- from itertools import imap

from py4j.java_collections import ListConverter, MapConverter

from pyspark.context import SparkContext
- from pyspark.rdd import RDD, _prepare_for_python_RDD
- from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
-     UTF8Deserializer
+ from pyspark.rdd import RDD
+ from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import *
from pyspark.sql.types import _create_cls, _parse_datatype_json_string


- __all__ = ["DataFrame", "GroupedData", "Column", "Dsl", "SchemaRDD"]
+ __all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD"]


class DataFrame(object):
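Note that `Dsl` drops out of `__all__`: the module's public surface is now just the DataFrame-side classes, and the helpers formerly on `Dsl` are imported from a separate `functions` module, as the hunks below reference. A minimal sketch of the intended import pattern (assuming `pyspark.sql.functions` exposes the same helpers this diff calls, e.g. `col` and `min`):

>>> from pyspark.sql import functions as F   # F replaces the old Dsl class
>>> F.col('age')                             # column helpers now come from the functions module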
@@ -310,8 +308,9 @@ def take(self, num):
        return self.limit(num).collect()

    def map(self, f):
-         """ Return a new RDD by applying a function to each Row, it's a
-         shorthand for df.rdd.map()
+         """ Return a new RDD by applying a function to each Row
+
+         It's a shorthand for df.rdd.map()

        >>> df.map(lambda p: p.name).collect()
        [u'Alice', u'Bob']
@@ -586,8 +585,8 @@ def agg(self, *exprs):

        >>> df.agg({"age": "max"}).collect()
        [Row(MAX(age#0)=5)]
-         >>> from pyspark.sql import Dsl
-         >>> df.agg(Dsl.min(df.age)).collect()
+         >>> from pyspark.sql import functions as F
+         >>> df.agg(F.min(df.age)).collect()
        [Row(MIN(age#0)=2)]
        """
        return self.groupBy().agg(*exprs)
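The module-based style composes the same way the old `Dsl` calls did; a small illustrative sketch against the doctest `df` (assuming `functions` also exposes `max`, output omitted):

>>> from pyspark.sql import functions as F
>>> df.agg(F.min(df.age), F.max(df.age))   # several aggregates in a single pass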
@@ -616,18 +615,18 @@ def subtract(self, other):
        """
        return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)

-     def addColumn(self, colName, col):
+     def withColumn(self, colName, col):
        """ Return a new :class:`DataFrame` by adding a column.

-         >>> df.addColumn('age2', df.age + 2).collect()
+         >>> df.withColumn('age2', df.age + 2).collect()
        [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
        """
        return self.select('*', col.alias(colName))

-     def renameColumn(self, existing, new):
+     def withColumnRenamed(self, existing, new):
        """ Rename an existing column to a new name

-         >>> df.renameColumn('age', 'age2').collect()
+         >>> df.withColumnRenamed('age', 'age2').collect()
        [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
        """
        cols = [Column(_to_java_column(c), self.sql_ctx).alias(new)
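The two renamed methods chain naturally; a rough usage sketch against the doctest `df` used throughout this file (results not verified here):

>>> df.withColumn('age2', df.age + 2).withColumnRenamed('name', 'first_name').collect()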
@@ -689,8 +688,9 @@ def agg(self, *exprs):
        >>> gdf = df.groupBy(df.name)
        >>> gdf.agg({"age": "max"}).collect()
        [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)]
-         >>> from pyspark.sql import Dsl
-         >>> gdf.agg(Dsl.min(df.age)).collect()
+
+         >>> from pyspark.sql import functions as F
+         >>> gdf.agg(F.min(df.age)).collect()
        [Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
        """
        assert exprs, "exprs should not be empty"
@@ -742,12 +742,12 @@ def sum(self):

def _create_column_from_literal(literal):
    sc = SparkContext._active_spark_context
-     return sc._jvm.Dsl.lit(literal)
+     return sc._jvm.functions.lit(literal)


def _create_column_from_name(name):
    sc = SparkContext._active_spark_context
-     return sc._jvm.Dsl.col(name)
+     return sc._jvm.functions.col(name)


def _to_java_column(col):
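These two helpers back implicit conversions such as `df.age + 2`; the equivalent explicit calls on the Python side would look roughly like this (assuming `pyspark.sql.functions` exposes `col` and `lit`, as the rest of the diff implies):

>>> from pyspark.sql import functions as F
>>> df.where(F.col('age') > F.lit(2)).collect()   # same as df.where(df.age > 2)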
@@ -767,9 +767,9 @@ def _(self):
    return _


- def _dsl_op(name, doc=''):
+ def _func_op(name, doc=''):
    def _(self):
-         jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
+         jc = getattr(self._sc._jvm.functions, name)(self._jc)
        return Column(jc, self.sql_ctx)
    _.__doc__ = doc
    return _
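Renaming `_dsl_op` to `_func_op` only changes which JVM object the unary wrapper is looked up on; user code keeps using the Python operators. An illustrative sketch of what routes through this path, per the operator assignments in the hunks below:

>>> df.select((-df.age).alias('neg_age'))   # __neg__  -> functions.negate
>>> df.where(~(df.age > 3))                 # __invert__ -> functions.not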
@@ -818,7 +818,7 @@ def __init__(self, jc, sql_ctx=None):
        super(Column, self).__init__(jc, sql_ctx)

    # arithmetic operators
-     __neg__ = _dsl_op("negate")
+     __neg__ = _func_op("negate")
    __add__ = _bin_op("plus")
    __sub__ = _bin_op("minus")
    __mul__ = _bin_op("multiply")
@@ -842,7 +842,7 @@ def __init__(self, jc, sql_ctx=None):
    # so use bitwise operators as boolean operators
    __and__ = _bin_op('and')
    __or__ = _bin_op('or')
-     __invert__ = _dsl_op('not')
+     __invert__ = _func_op('not')
    __rand__ = _bin_op("and")
    __ror__ = _bin_op("or")

@@ -934,123 +934,6 @@ def to_pandas(self):
        return pd.Series(data)


- def _aggregate_func(name, doc=""):
-     """ Create a function for aggregator by name"""
-     def _(col):
-         sc = SparkContext._active_spark_context
-         jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
-         return Column(jc)
-     _.__name__ = name
-     _.__doc__ = doc
-     return staticmethod(_)
-
-
- class UserDefinedFunction(object):
-     def __init__(self, func, returnType):
-         self.func = func
-         self.returnType = returnType
-         self._broadcast = None
-         self._judf = self._create_judf()
-
-     def _create_judf(self):
-         f = self.func  # put it in closure `func`
-         func = lambda _, it: imap(lambda x: f(*x), it)
-         ser = AutoBatchedSerializer(PickleSerializer())
-         command = (func, None, ser, ser)
-         sc = SparkContext._active_spark_context
-         pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
-         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
-         jdt = ssql_ctx.parseDataType(self.returnType.json())
-         judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
-                                                  includes, sc.pythonExec, broadcast_vars,
-                                                  sc._javaAccumulator, jdt)
-         return judf
-
-     def __del__(self):
-         if self._broadcast is not None:
-             self._broadcast.unpersist()
-             self._broadcast = None
-
-     def __call__(self, *cols):
-         sc = SparkContext._active_spark_context
-         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                         sc._gateway._gateway_client)
-         jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
-         return Column(jc)
-
-
- class Dsl(object):
-     """
-     A collections of builtin aggregators
-     """
-     DSLS = {
-         'lit': 'Creates a :class:`Column` of literal value.',
-         'col': 'Returns a :class:`Column` based on the given column name.',
-         'column': 'Returns a :class:`Column` based on the given column name.',
-         'upper': 'Converts a string expression to upper case.',
-         'lower': 'Converts a string expression to upper case.',
-         'sqrt': 'Computes the square root of the specified float value.',
-         'abs': 'Computes the absolutle value.',
-
-         'max': 'Aggregate function: returns the maximum value of the expression in a group.',
-         'min': 'Aggregate function: returns the minimum value of the expression in a group.',
-         'first': 'Aggregate function: returns the first value in a group.',
-         'last': 'Aggregate function: returns the last value in a group.',
-         'count': 'Aggregate function: returns the number of items in a group.',
-         'sum': 'Aggregate function: returns the sum of all values in the expression.',
-         'avg': 'Aggregate function: returns the average of the values in a group.',
-         'mean': 'Aggregate function: returns the average of the values in a group.',
-         'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
-     }
-
-     for _name, _doc in DSLS.items():
-         locals()[_name] = _aggregate_func(_name, _doc)
-     del _name, _doc
-
-     @staticmethod
-     def countDistinct(col, *cols):
-         """ Return a new Column for distinct count of (col, *cols)
-
-         >>> from pyspark.sql import Dsl
-         >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
-         [Row(c=2)]
-
-         >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect()
-         [Row(c=2)]
-         """
-         sc = SparkContext._active_spark_context
-         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                         sc._gateway._gateway_client)
-         jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
-                                        sc._jvm.PythonUtils.toSeq(jcols))
-         return Column(jc)
-
-     @staticmethod
-     def approxCountDistinct(col, rsd=None):
-         """ Return a new Column for approxiate distinct count of (col, *cols)
-
-         >>> from pyspark.sql import Dsl
-         >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
-         [Row(c=2)]
-         """
-         sc = SparkContext._active_spark_context
-         if rsd is None:
-             jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
-         else:
-             jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
-         return Column(jc)
-
-     @staticmethod
-     def udf(f, returnType=StringType()):
-         """Create a user defined function (UDF)
-
-         >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
-         >>> df.select(slen(df.name).alias('slen')).collect()
-         [Row(slen=5), Row(slen=3)]
-         """
-         return UserDefinedFunction(f, returnType)
-
-
def _test():
    import doctest
    from pyspark.context import SparkContext
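Everything removed above — `_aggregate_func`, `UserDefinedFunction`, and the `Dsl` class — is relocated rather than dropped: the call sites elsewhere in this diff point at a `pyspark.sql.functions` module. A rough sketch of the equivalent calls after the move (assuming `functions` exposes `udf`, `countDistinct`, and `approxCountDistinct` with the same behavior as the old static methods):

>>> from pyspark.sql import functions as F
>>> from pyspark.sql.types import IntegerType
>>> slen = F.udf(lambda s: len(s), IntegerType())          # was Dsl.udf(...)
>>> df.select(slen(df.name).alias('slen')).collect()
>>> df.agg(F.countDistinct(df.age, df.name).alias('c')).collect()
>>> df.agg(F.approxCountDistinct(df.age).alias('c')).collect()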