
Commit dc101b0

Davies Liu authored and rxin committed
[SPARK-5577] Python udf for DataFrame
Author: Davies Liu <[email protected]>

Closes apache#4351 from davies/python_udf and squashes the following commits:

d250692 [Davies Liu] fix conflict
34234d4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
440f769 [Davies Liu] address comments
f0a3121 [Davies Liu] track life cycle of broadcast
f99b2e1 [Davies Liu] address comments
462b334 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
7bccc3b [Davies Liu] python udf
58dee20 [Davies Liu] clean up
1 parent e0490e2 commit dc101b0
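
In rough terms, this is what the patch enables from Python. A minimal sketch, not part of the commit: it assumes a running SparkContext, a SQLContext bound to `sqlCtx`, and a DataFrame `df` with `name`/`age` columns matching the doctests in the diff below; the result comment is copied from those doctests, and the registered name `strLen` is hypothetical.

from pyspark.sql import Dsl, IntegerType

# a Python UDF usable directly in DataFrame expressions (added by this commit)
slen = Dsl.udf(lambda s: len(s), IntegerType())
df.select(slen(df.name).alias('slen')).collect()
# [Row(slen=5), Row(slen=3)]

# SQL registration goes through the same code path (SQLContext.registerFunction)
sqlCtx.registerFunction("strLen", lambda s: len(s), IntegerType())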

File tree

4 files changed (+157, -122 lines)


python/pyspark/rdd.py

Lines changed: 22 additions & 16 deletions
@@ -2162,6 +2162,25 @@ def toLocalIterator(self):
                 yield row
 
 
+def _prepare_for_python_RDD(sc, command, obj=None):
+    # the serialized command will be compressed by broadcast
+    ser = CloudPickleSerializer()
+    pickled_command = ser.dumps(command)
+    if len(pickled_command) > (1 << 20):  # 1M
+        broadcast = sc.broadcast(pickled_command)
+        pickled_command = ser.dumps(broadcast)
+        # tracking the life cycle by obj
+        if obj is not None:
+            obj._broadcast = broadcast
+    broadcast_vars = ListConverter().convert(
+        [x._jbroadcast for x in sc._pickled_broadcast_vars],
+        sc._gateway._gateway_client)
+    sc._pickled_broadcast_vars.clear()
+    env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
+    includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
+    return pickled_command, broadcast_vars, env, includes
+
+
 class PipelinedRDD(RDD):
 
     """
@@ -2228,25 +2247,12 @@ def _jrdd(self):
 
         command = (self.func, profiler, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
-        # the serialized command will be compressed by broadcast
-        ser = CloudPickleSerializer()
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            self._broadcast = self.ctx.broadcast(pickled_command)
-            pickled_command = ser.dumps(self._broadcast)
-        broadcast_vars = ListConverter().convert(
-            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
-            self.ctx._gateway._gateway_client)
-        self.ctx._pickled_broadcast_vars.clear()
-        env = MapConverter().convert(self.ctx.environment,
-                                     self.ctx._gateway._gateway_client)
-        includes = ListConverter().convert(self.ctx._python_includes,
-                                           self.ctx._gateway._gateway_client)
+        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-                                             bytearray(pickled_command),
+                                             bytearray(pickled_cmd),
                                              env, includes, self.preservesPartitioning,
                                              self.ctx.pythonExec,
-                                             broadcast_vars, self.ctx._javaAccumulator)
+                                             bvars, self.ctx._javaAccumulator)
         self._jrdd_val = python_rdd.asJavaRDD()
 
         if profiler:
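
The heart of `_prepare_for_python_RDD` above: pickle the command, and if the pickle exceeds 1 MB ship it once as a broadcast variable, hand ownership of that broadcast to `obj` for later cleanup, and pickle only the small broadcast handle. A Spark-free sketch of the idea follows; `FakeBroadcast`, `prepare_command`, and `Owner` are illustrative stand-ins, not pyspark APIs.

import pickle

ONE_MB = 1 << 20   # same 1M threshold as the patch
_store = {}        # stands in for driver-side broadcast storage


class FakeBroadcast(object):
    """Illustrative stand-in for pyspark.Broadcast: pickles as a tiny handle."""
    def __init__(self, value):
        self.bid = len(_store)
        _store[self.bid] = value          # the big payload never enters the pickle

    def unpersist(self):
        _store.pop(self.bid, None)


def prepare_command(command, obj=None):
    """Broadcast oversized commands; let obj own the broadcast for cleanup."""
    pickled = pickle.dumps(command)
    if len(pickled) > ONE_MB:
        broadcast = FakeBroadcast(pickled)
        pickled = pickle.dumps(broadcast)  # only the small handle travels with tasks
        if obj is not None:
            obj._broadcast = broadcast     # "tracking the life cycle by obj"
    return pickled


class Owner(object):
    _broadcast = None


owner = Owner()
print(len(prepare_command(("small", None), owner)))             # stays inline
print(len(prepare_command(("x" * (2 * ONE_MB), None), owner)))  # shrinks to a handle
if owner._broadcast is not None:
    owner._broadcast.unpersist()   # explicit cleanup, as __del__ does in sql.py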

python/pyspark/sql.py

Lines changed: 91 additions & 104 deletions
@@ -51,7 +51,7 @@
 from py4j.java_collections import ListConverter, MapConverter
 
 from pyspark.context import SparkContext
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _prepare_for_python_RDD
 from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
     CloudPickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
@@ -1274,28 +1274,15 @@ def registerFunction(self, name, f, returnType=StringType()):
         [Row(c0=4)]
         """
         func = lambda _, it: imap(lambda x: f(*x), it)
-        command = (func, None,
-                   AutoBatchedSerializer(PickleSerializer()),
-                   AutoBatchedSerializer(PickleSerializer()))
-        ser = CloudPickleSerializer()
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            broadcast = self._sc.broadcast(pickled_command)
-            pickled_command = ser.dumps(broadcast)
-        broadcast_vars = ListConverter().convert(
-            [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
-            self._sc._gateway._gateway_client)
-        self._sc._pickled_broadcast_vars.clear()
-        env = MapConverter().convert(self._sc.environment,
-                                     self._sc._gateway._gateway_client)
-        includes = ListConverter().convert(self._sc._python_includes,
-                                           self._sc._gateway._gateway_client)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
         self._ssql_ctx.udf().registerPython(name,
-                                            bytearray(pickled_command),
+                                            bytearray(pickled_cmd),
                                             env,
                                             includes,
                                             self._sc.pythonExec,
-                                            broadcast_vars,
+                                            bvars,
                                             self._sc._javaAccumulator,
                                             returnType.json())
 
@@ -2077,9 +2064,9 @@ def dtypes(self):
         """Return all column names and their data types as a list.
 
         >>> df.dtypes
-        [(u'age', 'IntegerType'), (u'name', 'StringType')]
+        [('age', 'integer'), ('name', 'string')]
         """
-        return [(f.name, str(f.dataType)) for f in self.schema().fields]
+        return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]
 
     @property
     def columns(self):
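
The dtypes doctest changes because `str(f.dataType)` (which renders as 'IntegerType') is replaced by `f.dataType.jsonValue()` (which renders as 'integer'). A quick check, only needing the type classes from this module and no SparkContext; the expected values are taken from the old and new doctests above:

from pyspark.sql import IntegerType, StringType

print(str(IntegerType()))         # IntegerType  (old dtypes format)
print(IntegerType().jsonValue())  # integer      (new dtypes format)
print(StringType().jsonValue())   # string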
@@ -2194,7 +2181,7 @@ def select(self, *cols):
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         >>> df.select('name', 'age').collect()
         [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
-        >>> df.select(df.name, (df.age + 10).As('age')).collect()
+        >>> df.select(df.name, (df.age + 10).alias('age')).collect()
         [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
         """
         if not cols:
@@ -2295,25 +2282,13 @@ def subtract(self, other):
         """
         return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
 
-    def sample(self, withReplacement, fraction, seed=None):
-        """ Return a new DataFrame by sampling a fraction of rows.
-
-        >>> df.sample(False, 0.5, 10).collect()
-        [Row(age=2, name=u'Alice')]
-        """
-        if seed is None:
-            jdf = self._jdf.sample(withReplacement, fraction)
-        else:
-            jdf = self._jdf.sample(withReplacement, fraction, seed)
-        return DataFrame(jdf, self.sql_ctx)
-
     def addColumn(self, colName, col):
         """ Return a new :class:`DataFrame` by adding a column.
 
         >>> df.addColumn('age2', df.age + 2).collect()
         [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
         """
-        return self.select('*', col.As(colName))
+        return self.select('*', col.alias(colName))
 
 
 # Having SchemaRDD for backward compatibility (for docs)
@@ -2408,28 +2383,6 @@ def sum(self):
         group."""
 
 
-SCALA_METHOD_MAPPINGS = {
-    '=': '$eq',
-    '>': '$greater',
-    '<': '$less',
-    '+': '$plus',
-    '-': '$minus',
-    '*': '$times',
-    '/': '$div',
-    '!': '$bang',
-    '@': '$at',
-    '#': '$hash',
-    '%': '$percent',
-    '^': '$up',
-    '&': '$amp',
-    '~': '$tilde',
-    '?': '$qmark',
-    '|': '$bar',
-    '\\': '$bslash',
-    ':': '$colon',
-}
-
-
 def _create_column_from_literal(literal):
     sc = SparkContext._active_spark_context
     return sc._jvm.Dsl.lit(literal)
@@ -2448,23 +2401,18 @@ def _to_java_column(col):
     return jcol
 
 
-def _scalaMethod(name):
-    """ Translate operators into methodName in Scala
-
-    >>> _scalaMethod('+')
-    '$plus'
-    >>> _scalaMethod('>=')
-    '$greater$eq'
-    >>> _scalaMethod('cast')
-    'cast'
-    """
-    return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
-
-
 def _unary_op(name, doc="unary operator"):
     """ Create a method for given unary operator """
     def _(self):
-        jc = getattr(self._jc, _scalaMethod(name))()
+        jc = getattr(self._jc, name)()
+        return Column(jc, self.sql_ctx)
+    _.__doc__ = doc
+    return _
+
+
+def _dsl_op(name, doc=''):
+    def _(self):
+        jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
         return Column(jc, self.sql_ctx)
     _.__doc__ = doc
     return _
@@ -2475,7 +2423,7 @@ def _bin_op(name, doc="binary operator"):
     """
     def _(self, other):
         jc = other._jc if isinstance(other, Column) else other
-        njc = getattr(self._jc, _scalaMethod(name))(jc)
+        njc = getattr(self._jc, name)(jc)
         return Column(njc, self.sql_ctx)
     _.__doc__ = doc
     return _
@@ -2486,7 +2434,7 @@ def _reverse_op(name, doc="binary operator"):
     """
     def _(self, other):
         jother = _create_column_from_literal(other)
-        jc = getattr(jother, _scalaMethod(name))(self._jc)
+        jc = getattr(jother, name)(self._jc)
         return Column(jc, self.sql_ctx)
     _.__doc__ = doc
     return _
@@ -2513,34 +2461,33 @@ def __init__(self, jc, sql_ctx=None):
         super(Column, self).__init__(jc, sql_ctx)
 
     # arithmetic operators
-    __neg__ = _unary_op("unary_-")
-    __add__ = _bin_op("+")
-    __sub__ = _bin_op("-")
-    __mul__ = _bin_op("*")
-    __div__ = _bin_op("/")
-    __mod__ = _bin_op("%")
-    __radd__ = _bin_op("+")
-    __rsub__ = _reverse_op("-")
-    __rmul__ = _bin_op("*")
-    __rdiv__ = _reverse_op("/")
-    __rmod__ = _reverse_op("%")
-    __abs__ = _unary_op("abs")
+    __neg__ = _dsl_op("negate")
+    __add__ = _bin_op("plus")
+    __sub__ = _bin_op("minus")
+    __mul__ = _bin_op("multiply")
+    __div__ = _bin_op("divide")
+    __mod__ = _bin_op("mod")
+    __radd__ = _bin_op("plus")
+    __rsub__ = _reverse_op("minus")
+    __rmul__ = _bin_op("multiply")
+    __rdiv__ = _reverse_op("divide")
+    __rmod__ = _reverse_op("mod")
 
     # logistic operators
-    __eq__ = _bin_op("===")
-    __ne__ = _bin_op("!==")
-    __lt__ = _bin_op("<")
-    __le__ = _bin_op("<=")
-    __ge__ = _bin_op(">=")
-    __gt__ = _bin_op(">")
+    __eq__ = _bin_op("equalTo")
+    __ne__ = _bin_op("notEqual")
+    __lt__ = _bin_op("lt")
+    __le__ = _bin_op("leq")
+    __ge__ = _bin_op("geq")
+    __gt__ = _bin_op("gt")
 
     # `and`, `or`, `not` cannot be overloaded in Python,
     # so use bitwise operators as boolean operators
-    __and__ = _bin_op('&&')
-    __or__ = _bin_op('||')
-    __invert__ = _unary_op('unary_!')
-    __rand__ = _bin_op("&&")
-    __ror__ = _bin_op("||")
+    __and__ = _bin_op('and')
+    __or__ = _bin_op('or')
+    __invert__ = _dsl_op('not')
+    __rand__ = _bin_op("and")
+    __ror__ = _bin_op("or")
 
     # container operators
     __contains__ = _bin_op("contains")
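
The table above stops routing operators through `_scalaMethod` ('+' becoming '$plus', '===' for equality) and instead calls Java-friendly Column methods by name (plus, equalTo, gt, ...). A Spark-free sketch of the `_bin_op` dispatch; `MockJavaColumn` is hypothetical and only mimics the Py4J proxy, and `Column`/`_bin_op` here are trimmed versions of what the patch defines:

class MockJavaColumn(object):
    """Hypothetical stand-in for the Py4J proxy of the JVM Column."""
    def __init__(self, expr):
        self.expr = expr

    def plus(self, other):
        return MockJavaColumn("(%s + %s)" % (self.expr, other))

    def gt(self, other):
        return MockJavaColumn("(%s > %s)" % (self.expr, other))


def _bin_op(name):
    # same shape as the helper above: look the Java method up by name
    def _(self, other):
        jc = other._jc if isinstance(other, Column) else other
        return Column(getattr(self._jc, name)(jc))
    return _


class Column(object):
    def __init__(self, jc):
        self._jc = jc

    __add__ = _bin_op("plus")
    __gt__ = _bin_op("gt")


age = Column(MockJavaColumn("age"))
print((age + 10)._jc.expr)   # (age + 10)
print((age > 3)._jc.expr)    # (age > 3)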
@@ -2582,24 +2529,20 @@ def substr(self, startPos, length):
     isNull = _unary_op("isNull", "True if the current expression is null.")
     isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
 
-    # `as` is keyword
     def alias(self, alias):
         """Return a alias for this column
 
-        >>> df.age.As("age2").collect()
-        [Row(age2=2), Row(age2=5)]
         >>> df.age.alias("age2").collect()
         [Row(age2=2), Row(age2=5)]
         """
         return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
-    As = alias
 
     def cast(self, dataType):
         """ Convert the column into type `dataType`
 
-        >>> df.select(df.age.cast("string").As('ages')).collect()
+        >>> df.select(df.age.cast("string").alias('ages')).collect()
         [Row(ages=u'2'), Row(ages=u'5')]
-        >>> df.select(df.age.cast(StringType()).As('ages')).collect()
+        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
         [Row(ages=u'2'), Row(ages=u'5')]
         """
         if self.sql_ctx is None:
@@ -2626,6 +2569,40 @@ def _(col):
     return staticmethod(_)
 
 
+class UserDefinedFunction(object):
+    def __init__(self, func, returnType):
+        self.func = func
+        self.returnType = returnType
+        self._broadcast = None
+        self._judf = self._create_judf()
+
+    def _create_judf(self):
+        f = self.func  # put it in closure `func`
+        func = lambda _, it: imap(lambda x: f(*x), it)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        sc = SparkContext._active_spark_context
+        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+        ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+        jdt = ssql_ctx.parseDataType(self.returnType.json())
+        judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
+                                                 includes, sc.pythonExec, broadcast_vars,
+                                                 sc._javaAccumulator, jdt)
+        return judf
+
+    def __del__(self):
+        if self._broadcast is not None:
+            self._broadcast.unpersist()
+            self._broadcast = None
+
+    def __call__(self, *cols):
+        sc = SparkContext._active_spark_context
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        sc._gateway._gateway_client)
+        jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+        return Column(jc)
+
+
 class Dsl(object):
     """
     A collections of builtin aggregators
@@ -2659,7 +2636,7 @@ def countDistinct(col, *cols):
         """ Return a new Column for distinct count of (col, *cols)
 
         >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.countDistinct(df.age, df.name).As('c')).collect()
+        >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
         [Row(c=2)]
         """
         sc = SparkContext._active_spark_context
@@ -2674,7 +2651,7 @@ def approxCountDistinct(col, rsd=None):
         """ Return a new Column for approxiate distinct count of (col, *cols)
 
        >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.approxCountDistinct(df.age).As('c')).collect()
+        >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
         [Row(c=2)]
         """
         sc = SparkContext._active_spark_context
@@ -2684,6 +2661,16 @@ def approxCountDistinct(col, rsd=None):
         jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
         return Column(jc)
 
+    @staticmethod
+    def udf(f, returnType=StringType()):
+        """Create a user defined function (UDF)
+
+        >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
+        >>> df.select(slen(df.name).alias('slen')).collect()
+        [Row(slen=5), Row(slen=3)]
+        """
+        return UserDefinedFunction(f, returnType)
+
 
 def _test():
     import doctest
