 from py4j.java_collections import ListConverter, MapConverter
 
 from pyspark.context import SparkContext
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _prepare_for_python_RDD
 from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
     CloudPickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
@@ -1274,22 +1274,9 @@ def registerFunction(self, name, f, returnType=StringType()):
         [Row(c0=4)]
         """
         func = lambda _, it: imap(lambda x: f(*x), it)
-        command = (func, None,
-                   AutoBatchedSerializer(PickleSerializer()),
-                   AutoBatchedSerializer(PickleSerializer()))
-        ser = CloudPickleSerializer()
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            broadcast = self._sc.broadcast(pickled_command)
-            pickled_command = ser.dumps(broadcast)
-        broadcast_vars = ListConverter().convert(
-            [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
-            self._sc._gateway._gateway_client)
-        self._sc._pickled_broadcast_vars.clear()
-        env = MapConverter().convert(self._sc.environment,
-                                     self._sc._gateway._gateway_client)
-        includes = ListConverter().convert(self._sc._python_includes,
-                                           self._sc._gateway._gateway_client)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(self._sc, command)
         self._ssql_ctx.udf().registerPython(name,
                                             bytearray(pickled_command),
                                             env,
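Both UDF paths touched by this commit replace the same sixteen lines of inline pickling with one call to `_prepare_for_python_RDD` from `pyspark.rdd`. A minimal sketch of what that helper consolidates, reconstructed from the removed lines above (the 1 MB broadcast threshold and the py4j conversions are taken verbatim from them; the actual implementation lives in `pyspark/rdd.py`):

```python
from py4j.java_collections import ListConverter, MapConverter
from pyspark.serializers import CloudPickleSerializer

def _prepare_for_python_RDD(sc, command):
    # Serialize the command tuple built at the call site
    # (here: (func, None, ser, ser)).
    ser = CloudPickleSerializer()
    pickled_command = ser.dumps(command)
    if len(pickled_command) > (1 << 20):  # commands over 1 MB ship as a broadcast
        broadcast = sc.broadcast(pickled_command)
        pickled_command = ser.dumps(broadcast)
    # Hand the accumulated broadcast variables, environment, and includes
    # to the JVM through py4j's collection converters.
    broadcast_vars = ListConverter().convert(
        [x._jbroadcast for x in sc._pickled_broadcast_vars],
        sc._gateway._gateway_client)
    sc._pickled_broadcast_vars.clear()
    env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
    includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
    return pickled_command, broadcast_vars, env, includes
```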
@@ -2187,7 +2174,7 @@ def select(self, *cols):
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         >>> df.select('name', 'age').collect()
         [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
-        >>> df.select(df.name, (df.age + 10).As('age')).collect()
+        >>> df.select(df.name, (df.age + 10).alias('age')).collect()
         [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
         """
         if not cols:
@@ -2268,7 +2255,7 @@ def addColumn(self, colName, col):
         >>> df.addColumn('age2', df.age + 2).collect()
         [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
         """
-        return self.select('*', col.As(colName))
+        return self.select('*', col.alias(colName))
 
 
 # Having SchemaRDD for backward compatibility (for docs)
@@ -2509,24 +2496,20 @@ def substr(self, startPos, length):
     isNull = _unary_op("isNull", "True if the current expression is null.")
     isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
 
-    # `as` is keyword
     def alias(self, alias):
         """Return an alias for this column
 
-        >>> df.age.As("age2").collect()
-        [Row(age2=2), Row(age2=5)]
         >>> df.age.alias("age2").collect()
         [Row(age2=2), Row(age2=5)]
         """
         return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
-    As = alias
 
     def cast(self, dataType):
         """ Convert the column into type `dataType`
 
-        >>> df.select(df.age.cast("string").As('ages')).collect()
+        >>> df.select(df.age.cast("string").alias('ages')).collect()
         [Row(ages=u'2'), Row(ages=u'5')]
-        >>> df.select(df.age.cast(StringType()).As('ages')).collect()
+        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
         [Row(ages=u'2'), Row(ages=u'5')]
         """
         if self.sql_ctx is None:
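The surviving `alias` body keeps the subtlety that the removed `# `as` is keyword` comment documented: the underlying JVM `Column` method is named `as`, which is a reserved word in Python, so it cannot be reached with normal attribute syntax on the py4j proxy. A two-line illustration of the workaround the kept code uses:

```python
# `self._jc.as(alias)` would be a SyntaxError, because `as` is a Python keyword;
# getattr() looks the py4j proxy method up by its string name instead:
jcol = getattr(self._jc, "as")(alias)  # same as calling jc.as(alias) on the Scala side
```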
@@ -2560,24 +2543,12 @@ def __init__(self, func, returnType):
         self._judf = self._create_judf()
 
     def _create_judf(self):
-        f = self.func
-        sc = SparkContext._active_spark_context
-        # TODO(davies): refactor
+        f = self.func  # put it in closure `func`
         func = lambda _, it: imap(lambda x: f(*x), it)
-        command = (func, None,
-                   AutoBatchedSerializer(PickleSerializer()),
-                   AutoBatchedSerializer(PickleSerializer()))
-        ser = CloudPickleSerializer()
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            broadcast = sc.broadcast(pickled_command)
-            pickled_command = ser.dumps(broadcast)
-        broadcast_vars = ListConverter().convert(
-            [x._jbroadcast for x in sc._pickled_broadcast_vars],
-            sc._gateway._gateway_client)
-        sc._pickled_broadcast_vars.clear()
-        env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
-        includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        sc = SparkContext._active_spark_context
+        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
         judf = sc._jvm.Dsl.pythonUDF(f.__name__, bytearray(pickled_command), env, includes,
@@ -2625,7 +2596,7 @@ def countDistinct(col, *cols):
26252596 """ Return a new Column for distinct count of (col, *cols)
26262597
26272598 >>> from pyspark.sql import Dsl
2628- >>> df.agg(Dsl.countDistinct(df.age, df.name).As ('c')).collect()
2599+ >>> df.agg(Dsl.countDistinct(df.age, df.name).alias ('c')).collect()
26292600 [Row(c=2)]
26302601 """
26312602 sc = SparkContext ._active_spark_context
@@ -2640,7 +2611,7 @@ def approxCountDistinct(col, rsd=None):
26402611 """ Return a new Column for approxiate distinct count of (col, *cols)
26412612
26422613 >>> from pyspark.sql import Dsl
2643- >>> df.agg(Dsl.approxCountDistinct(df.age).As ('c')).collect()
2614+ >>> df.agg(Dsl.approxCountDistinct(df.age).alias ('c')).collect()
26442615 [Row(c=2)]
26452616 """
26462617 sc = SparkContext ._active_spark_context
@@ -2655,7 +2626,7 @@ def udf(f, returnType=StringType()):
26552626 """Create a user defined function (UDF)
26562627
26572628 >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
2658- >>> df.select(slen(df.name).As ('slen')).collect()
2629+ >>> df.select(slen(df.name).alias ('slen')).collect()
26592630 [Row(slen=5), Row(slen=3)]
26602631 """
26612632 return UserDefinedFunction (f , returnType )
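Taken together, the updated doctests describe the column API after the rename from `As` to `alias`. A short end-to-end sketch, assuming a DataFrame `df` with `name` and `age` columns like the one in this file's doctest fixtures:

```python
from pyspark.sql import Dsl, IntegerType, StringType

# The keyword-safe alias() replaces the removed As() everywhere:
slen = Dsl.udf(lambda s: len(s), IntegerType())
df.select(slen(df.name).alias('slen')).collect()
# [Row(slen=5), Row(slen=3)]

df.select(df.age.cast(StringType()).alias('ages')).collect()
# [Row(ages=u'2'), Row(ages=u'5')]
```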