 from py4j.java_collections import ListConverter, MapConverter

 from pyspark.context import SparkContext
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _prepare_for_python_RDD
 from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
     CloudPickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
@@ -1274,28 +1274,15 @@ def registerFunction(self, name, f, returnType=StringType()):
         [Row(c0=4)]
         """
         func = lambda _, it: imap(lambda x: f(*x), it)
-        command = (func, None,
-                   AutoBatchedSerializer(PickleSerializer()),
-                   AutoBatchedSerializer(PickleSerializer()))
-        ser = CloudPickleSerializer()
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            broadcast = self._sc.broadcast(pickled_command)
-            pickled_command = ser.dumps(broadcast)
-        broadcast_vars = ListConverter().convert(
-            [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
-            self._sc._gateway._gateway_client)
-        self._sc._pickled_broadcast_vars.clear()
-        env = MapConverter().convert(self._sc.environment,
-                                     self._sc._gateway._gateway_client)
-        includes = ListConverter().convert(self._sc._python_includes,
-                                           self._sc._gateway._gateway_client)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
         self._ssql_ctx.udf().registerPython(name,
-                                            bytearray(pickled_command),
+                                            bytearray(pickled_cmd),
                                             env,
                                             includes,
                                             self._sc.pythonExec,
-                                            broadcast_vars,
+                                            bvars,
                                             self._sc._javaAccumulator,
                                             returnType.json())

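Note: the boilerplate removed above (cloudpickling the command, spilling it into a broadcast once it exceeds 1 MB, and converting the broadcast variables, environment, and includes for Py4J) now lives in `_prepare_for_python_RDD`, imported from pyspark.rdd in the first hunk. A minimal sketch of what that helper consolidates, reassembled from the removed lines rather than copied from the upstream definition:

    from py4j.java_collections import ListConverter, MapConverter
    from pyspark.serializers import CloudPickleSerializer

    def _prepare_for_python_RDD(sc, command, obj=None):
        # pickle the command tuple with cloudpickle
        ser = CloudPickleSerializer()
        pickled_command = ser.dumps(command)
        if len(pickled_command) > (1 << 20):  # over 1M: ship the command via a broadcast
            broadcast = sc.broadcast(pickled_command)
            pickled_command = ser.dumps(broadcast)
            if obj is not None:
                obj._broadcast = broadcast  # caller can unpersist it later (cf. __del__ below)
        broadcast_vars = ListConverter().convert(
            [x._jbroadcast for x in sc._pickled_broadcast_vars],
            sc._gateway._gateway_client)
        sc._pickled_broadcast_vars.clear()
        env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
        includes = ListConverter().convert(sc._python_includes,
                                           sc._gateway._gateway_client)
        return pickled_command, broadcast_vars, env, includes
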
@@ -2077,9 +2064,9 @@ def dtypes(self):
         """Return all column names and their data types as a list.

         >>> df.dtypes
-        [(u'age', 'IntegerType'), (u'name', 'StringType')]
+        [('age', 'integer'), ('name', 'string')]
         """
-        return [(f.name, str(f.dataType)) for f in self.schema().fields]
+        return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]

     @property
     def columns(self):
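Note: the rewritten `dtypes` relies on `DataType.jsonValue()`, which for the primitive types is just the lowercase simple name, hence the doctest output changing from class names to 'integer'/'string'. A quick illustration, assuming the usual pyspark.sql type classes are in scope:

    >>> IntegerType().jsonValue()
    'integer'
    >>> StringType().jsonValue()
    'string'
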
@@ -2194,7 +2181,7 @@ def select(self, *cols):
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         >>> df.select('name', 'age').collect()
         [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
-        >>> df.select(df.name, (df.age + 10).As('age')).collect()
+        >>> df.select(df.name, (df.age + 10).alias('age')).collect()
         [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
         """
         if not cols:
@@ -2295,25 +2282,13 @@ def subtract(self, other):
         """
         return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)

-    def sample(self, withReplacement, fraction, seed=None):
-        """ Return a new DataFrame by sampling a fraction of rows.
-
-        >>> df.sample(False, 0.5, 10).collect()
-        [Row(age=2, name=u'Alice')]
-        """
-        if seed is None:
-            jdf = self._jdf.sample(withReplacement, fraction)
-        else:
-            jdf = self._jdf.sample(withReplacement, fraction, seed)
-        return DataFrame(jdf, self.sql_ctx)
-
     def addColumn(self, colName, col):
         """ Return a new :class:`DataFrame` by adding a column.

         >>> df.addColumn('age2', df.age + 2).collect()
         [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
         """
-        return self.select('*', col.As(colName))
+        return self.select('*', col.alias(colName))


 # Having SchemaRDD for backward compatibility (for docs)
@@ -2408,28 +2383,6 @@ def sum(self):
         group."""


-SCALA_METHOD_MAPPINGS = {
-    '=': '$eq',
-    '>': '$greater',
-    '<': '$less',
-    '+': '$plus',
-    '-': '$minus',
-    '*': '$times',
-    '/': '$div',
-    '!': '$bang',
-    '@': '$at',
-    '#': '$hash',
-    '%': '$percent',
-    '^': '$up',
-    '&': '$amp',
-    '~': '$tilde',
-    '?': '$qmark',
-    '|': '$bar',
-    '\\': '$bslash',
-    ':': '$colon',
-}
-
-
 def _create_column_from_literal(literal):
     sc = SparkContext._active_spark_context
     return sc._jvm.Dsl.lit(literal)
@@ -2448,23 +2401,18 @@ def _to_java_column(col):
     return jcol


-def _scalaMethod(name):
-    """ Translate operators into methodName in Scala
-
-    >>> _scalaMethod('+')
-    '$plus'
-    >>> _scalaMethod('>=')
-    '$greater$eq'
-    >>> _scalaMethod('cast')
-    'cast'
-    """
-    return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
-
-
 def _unary_op(name, doc="unary operator"):
     """ Create a method for given unary operator """
     def _(self):
-        jc = getattr(self._jc, _scalaMethod(name))()
+        jc = getattr(self._jc, name)()
+        return Column(jc, self.sql_ctx)
+    _.__doc__ = doc
+    return _
+
+
+def _dsl_op(name, doc=''):
+    def _(self):
+        jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
         return Column(jc, self.sql_ctx)
     _.__doc__ = doc
     return _
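Note: with the `_scalaMethod` name mangling gone, `_unary_op` and `_bin_op` look methods up on the Java Column by their plain names, while the new `_dsl_op` handles the few operators (negate, not) that exist only as functions on the JVM `Dsl` object rather than as Column methods. Roughly, the two dispatch paths compare like this (illustrative comments, not code from the patch):

    __add__ = _bin_op("plus")    # df.age + x  ->  self._jc.plus(jc)
    __neg__ = _dsl_op("negate")  # -df.age     ->  sc._jvm.Dsl.negate(self._jc)
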
@@ -2475,7 +2423,7 @@ def _bin_op(name, doc="binary operator"):
     """
     def _(self, other):
         jc = other._jc if isinstance(other, Column) else other
-        njc = getattr(self._jc, _scalaMethod(name))(jc)
+        njc = getattr(self._jc, name)(jc)
         return Column(njc, self.sql_ctx)
     _.__doc__ = doc
     return _
@@ -2486,7 +2434,7 @@ def _reverse_op(name, doc="binary operator"):
     """
     def _(self, other):
         jother = _create_column_from_literal(other)
-        jc = getattr(jother, _scalaMethod(name))(self._jc)
+        jc = getattr(jother, name)(self._jc)
         return Column(jc, self.sql_ctx)
     _.__doc__ = doc
     return _
@@ -2513,34 +2461,33 @@ def __init__(self, jc, sql_ctx=None):
         super(Column, self).__init__(jc, sql_ctx)

     # arithmetic operators
-    __neg__ = _unary_op("unary_-")
-    __add__ = _bin_op("+")
-    __sub__ = _bin_op("-")
-    __mul__ = _bin_op("*")
-    __div__ = _bin_op("/")
-    __mod__ = _bin_op("%")
-    __radd__ = _bin_op("+")
-    __rsub__ = _reverse_op("-")
-    __rmul__ = _bin_op("*")
-    __rdiv__ = _reverse_op("/")
-    __rmod__ = _reverse_op("%")
-    __abs__ = _unary_op("abs")
+    __neg__ = _dsl_op("negate")
+    __add__ = _bin_op("plus")
+    __sub__ = _bin_op("minus")
+    __mul__ = _bin_op("multiply")
+    __div__ = _bin_op("divide")
+    __mod__ = _bin_op("mod")
+    __radd__ = _bin_op("plus")
+    __rsub__ = _reverse_op("minus")
+    __rmul__ = _bin_op("multiply")
+    __rdiv__ = _reverse_op("divide")
+    __rmod__ = _reverse_op("mod")

     # logistic operators
-    __eq__ = _bin_op("===")
-    __ne__ = _bin_op("!==")
-    __lt__ = _bin_op("<")
-    __le__ = _bin_op("<=")
-    __ge__ = _bin_op(">=")
-    __gt__ = _bin_op(">")
+    __eq__ = _bin_op("equalTo")
+    __ne__ = _bin_op("notEqual")
+    __lt__ = _bin_op("lt")
+    __le__ = _bin_op("leq")
+    __ge__ = _bin_op("geq")
+    __gt__ = _bin_op("gt")

     # `and`, `or`, `not` cannot be overloaded in Python,
     # so use bitwise operators as boolean operators
-    __and__ = _bin_op('&&')
-    __or__ = _bin_op('||')
-    __invert__ = _unary_op('unary_!')
-    __rand__ = _bin_op("&&")
-    __ror__ = _bin_op("||")
+    __and__ = _bin_op('and')
+    __or__ = _bin_op('or')
+    __invert__ = _dsl_op('not')
+    __rand__ = _bin_op("and")
+    __ror__ = _bin_op("or")

     # container operators
     __contains__ = _bin_op("contains")
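Note: the renames above change only the JVM lookup names; Python-side behavior is unchanged. A doctest-style sketch against the `df` fixture used elsewhere in this file (ages 2 and 5), with the expected rows inferred rather than taken from the patch:

    >>> df.select((df.age > 3).alias('gt3'), (~(df.age > 3)).alias('not3')).collect()
    [Row(gt3=False, not3=True), Row(gt3=True, not3=False)]
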
@@ -2582,24 +2529,20 @@ def substr(self, startPos, length):
     isNull = _unary_op("isNull", "True if the current expression is null.")
     isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")

-    # `as` is keyword
     def alias(self, alias):
         """Return a alias for this column

-        >>> df.age.As("age2").collect()
-        [Row(age2=2), Row(age2=5)]
         >>> df.age.alias("age2").collect()
         [Row(age2=2), Row(age2=5)]
         """
         return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
-    As = alias

     def cast(self, dataType):
         """ Convert the column into type `dataType`

-        >>> df.select(df.age.cast("string").As('ages')).collect()
+        >>> df.select(df.age.cast("string").alias('ages')).collect()
         [Row(ages=u'2'), Row(ages=u'5')]
-        >>> df.select(df.age.cast(StringType()).As('ages')).collect()
+        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
         [Row(ages=u'2'), Row(ages=u'5')]
         """
         if self.sql_ctx is None:
@@ -2626,6 +2569,40 @@ def _(col):
     return staticmethod(_)


+class UserDefinedFunction(object):
+    def __init__(self, func, returnType):
+        self.func = func
+        self.returnType = returnType
+        self._broadcast = None
+        self._judf = self._create_judf()
+
+    def _create_judf(self):
+        f = self.func  # put it in closure `func`
+        func = lambda _, it: imap(lambda x: f(*x), it)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        sc = SparkContext._active_spark_context
+        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+        ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+        jdt = ssql_ctx.parseDataType(self.returnType.json())
+        judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
+                                                 includes, sc.pythonExec, broadcast_vars,
+                                                 sc._javaAccumulator, jdt)
+        return judf
+
+    def __del__(self):
+        if self._broadcast is not None:
+            self._broadcast.unpersist()
+            self._broadcast = None
+
+    def __call__(self, *cols):
+        sc = SparkContext._active_spark_context
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        sc._gateway._gateway_client)
+        jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+        return Column(jc)
+
+
 class Dsl(object):
     """
     A collections of builtin aggregators
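Note: `UserDefinedFunction` pairs a Python callable with its return type and builds a JVM `UserDefinedPythonFunction`, so a UDF becomes an ordinary callable that yields Columns. A minimal usage sketch mirroring the `Dsl.udf` doctest below (assumes an active SparkContext and the doctest `df` with a `name` column):

    slen = UserDefinedFunction(lambda s: len(s), IntegerType())
    df.select(slen(df.name).alias('slen')).collect()
    # [Row(slen=5), Row(slen=3)]
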
@@ -2659,7 +2636,7 @@ def countDistinct(col, *cols):
         """ Return a new Column for distinct count of (col, *cols)

         >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.countDistinct(df.age, df.name).As('c')).collect()
+        >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
         [Row(c=2)]
         """
         sc = SparkContext._active_spark_context
@@ -2674,7 +2651,7 @@ def approxCountDistinct(col, rsd=None):
         """ Return a new Column for approxiate distinct count of (col, *cols)

         >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.approxCountDistinct(df.age).As('c')).collect()
+        >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
         [Row(c=2)]
         """
         sc = SparkContext._active_spark_context
@@ -2684,6 +2661,16 @@ def approxCountDistinct(col, rsd=None):
             jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
         return Column(jc)

+    @staticmethod
+    def udf(f, returnType=StringType()):
+        """Create a user defined function (UDF)
+
+        >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
+        >>> df.select(slen(df.name).alias('slen')).collect()
+        [Row(slen=5), Row(slen=3)]
+        """
+        return UserDefinedFunction(f, returnType)
+

 def _test():
     import doctest