38 changes: 22 additions & 16 deletions python/pyspark/rdd.py
@@ -2162,6 +2162,25 @@ def toLocalIterator(self):
yield row


def _prepare_for_python_RDD(sc, command, obj=None):
    # the serialized command will be compressed by broadcast
    ser = CloudPickleSerializer()
    pickled_command = ser.dumps(command)
    if len(pickled_command) > (1 << 20):  # 1M
        broadcast = sc.broadcast(pickled_command)
        pickled_command = ser.dumps(broadcast)
        # tracking the life cycle by obj
        if obj is not None:
            obj._broadcast = broadcast
    broadcast_vars = ListConverter().convert(
        [x._jbroadcast for x in sc._pickled_broadcast_vars],
        sc._gateway._gateway_client)
    sc._pickled_broadcast_vars.clear()
    env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
    includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
    return pickled_command, broadcast_vars, env, includes
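
This helper factors out the command-pickling logic previously duplicated in PipelinedRDD._jrdd (below) and SQLContext.registerFunction in sql.py: commands larger than 1 MB are shipped through a broadcast variable instead of being embedded in the job. A minimal caller-side sketch, assuming an active SparkContext sc, a picklable command tuple of (func, profiler, input_deserializer, output_deserializer), and placeholder names prev_jrdd and preserves_partitioning:

    pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(sc, command)
    python_rdd = sc._jvm.PythonRDD(prev_jrdd.rdd(), bytearray(pickled_cmd),
                                   env, includes, preserves_partitioning,
                                   sc.pythonExec, bvars, sc._javaAccumulator)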


class PipelinedRDD(RDD):

"""
@@ -2228,25 +2247,12 @@ def _jrdd(self):

command = (self.func, profiler, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
pickled_command = ser.dumps(command)
if len(pickled_command) > (1 << 20): # 1M
self._broadcast = self.ctx.broadcast(pickled_command)
pickled_command = ser.dumps(self._broadcast)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
env = MapConverter().convert(self.ctx.environment,
self.ctx._gateway._gateway_client)
includes = ListConverter().convert(self.ctx._python_includes,
self.ctx._gateway._gateway_client)
pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
bytearray(pickled_command),
bytearray(pickled_cmd),
env, includes, self.preservesPartitioning,
self.ctx.pythonExec,
broadcast_vars, self.ctx._javaAccumulator)
bvars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()

if profiler:
195 changes: 91 additions & 104 deletions python/pyspark/sql.py
@@ -51,7 +51,7 @@
from py4j.java_collections import ListConverter, MapConverter

from pyspark.context import SparkContext
from pyspark.rdd import RDD
from pyspark.rdd import RDD, _prepare_for_python_RDD
from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
CloudPickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
@@ -1274,28 +1274,15 @@ def registerFunction(self, name, f, returnType=StringType()):
[Row(c0=4)]
"""
func = lambda _, it: imap(lambda x: f(*x), it)
command = (func, None,
AutoBatchedSerializer(PickleSerializer()),
AutoBatchedSerializer(PickleSerializer()))
ser = CloudPickleSerializer()
pickled_command = ser.dumps(command)
if len(pickled_command) > (1 << 20): # 1M
broadcast = self._sc.broadcast(pickled_command)
pickled_command = ser.dumps(broadcast)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self._sc._pickled_broadcast_vars],
self._sc._gateway._gateway_client)
self._sc._pickled_broadcast_vars.clear()
env = MapConverter().convert(self._sc.environment,
self._sc._gateway._gateway_client)
includes = ListConverter().convert(self._sc._python_includes,
self._sc._gateway._gateway_client)
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, None, ser, ser)
pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
self._ssql_ctx.udf().registerPython(name,
bytearray(pickled_command),
bytearray(pickled_cmd),
env,
includes,
self._sc.pythonExec,
broadcast_vars,
bvars,
self._sc._javaAccumulator,
returnType.json())
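
An illustrative usage sketch of the registered-UDF path (the function name here is a placeholder; the unaliased result column is named c0, as in the doctest above):

    sqlCtx.registerFunction("strLen", lambda s: len(s), IntegerType())
    sqlCtx.sql("SELECT strLen('test')").collect()  # [Row(c0=4)]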

@@ -2077,9 +2064,9 @@ def dtypes(self):
"""Return all column names and their data types as a list.

>>> df.dtypes
[(u'age', 'IntegerType'), (u'name', 'StringType')]
[('age', 'integer'), ('name', 'string')]
"""
return [(f.name, str(f.dataType)) for f in self.schema().fields]
return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]
Review comment (Contributor): We should use simpleString (which isn't available yet); we can change it in the future.
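
For context, a hypothetical sketch of the suggested form once a simpleString equivalent exists in the Python API (the method name is an assumption carried over from the Scala DataType API and is not available in PySpark at this point):

    # hypothetical future form
    return [(str(f.name), f.dataType.simpleString()) for f in self.schema().fields]
    # would give e.g. [('age', 'int'), ('name', 'string')]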


@property
def columns(self):
@@ -2194,7 +2181,7 @@ def select(self, *cols):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.select('name', 'age').collect()
[Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
>>> df.select(df.name, (df.age + 10).As('age')).collect()
>>> df.select(df.name, (df.age + 10).alias('age')).collect()
[Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
"""
if not cols:
@@ -2295,25 +2282,13 @@ def subtract(self, other):
"""
return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)

def sample(self, withReplacement, fraction, seed=None):
""" Return a new DataFrame by sampling a fraction of rows.

>>> df.sample(False, 0.5, 10).collect()
[Row(age=2, name=u'Alice')]
"""
if seed is None:
jdf = self._jdf.sample(withReplacement, fraction)
else:
jdf = self._jdf.sample(withReplacement, fraction, seed)
return DataFrame(jdf, self.sql_ctx)

def addColumn(self, colName, col):
""" Return a new :class:`DataFrame` by adding a column.

>>> df.addColumn('age2', df.age + 2).collect()
[Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
"""
return self.select('*', col.As(colName))
return self.select('*', col.alias(colName))


# Having SchemaRDD for backward compatibility (for docs)
@@ -2408,28 +2383,6 @@ def sum(self):
group."""


SCALA_METHOD_MAPPINGS = {
'=': '$eq',
'>': '$greater',
'<': '$less',
'+': '$plus',
'-': '$minus',
'*': '$times',
'/': '$div',
'!': '$bang',
'@': '$at',
'#': '$hash',
'%': '$percent',
'^': '$up',
'&': '$amp',
'~': '$tilde',
'?': '$qmark',
'|': '$bar',
'\\': '$bslash',
':': '$colon',
}


def _create_column_from_literal(literal):
sc = SparkContext._active_spark_context
return sc._jvm.Dsl.lit(literal)
@@ -2448,23 +2401,18 @@ def _to_java_column(col):
return jcol


def _scalaMethod(name):
""" Translate operators into methodName in Scala

>>> _scalaMethod('+')
'$plus'
>>> _scalaMethod('>=')
'$greater$eq'
>>> _scalaMethod('cast')
'cast'
"""
return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)


def _unary_op(name, doc="unary operator"):
""" Create a method for given unary operator """
def _(self):
jc = getattr(self._jc, _scalaMethod(name))()
jc = getattr(self._jc, name)()
return Column(jc, self.sql_ctx)
_.__doc__ = doc
return _


def _dsl_op(name, doc=''):
    """ Create a method that calls the given static function on the JVM Dsl object """
    def _(self):
        jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
        return Column(jc, self.sql_ctx)
    _.__doc__ = doc
    return _
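
Unlike _unary_op, which calls an instance method on the underlying Java column, _dsl_op routes through a static function on the JVM Dsl object. A rough illustration of the difference, using mappings defined further down in this diff:

    # (-df.age)        -> sc._jvm.Dsl.negate(df.age._jc)   via __neg__ = _dsl_op("negate")
    # df.age.isNull()  -> df.age._jc.isNull()              via isNull = _unary_op("isNull")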
@@ -2475,7 +2423,7 @@ def _bin_op(name, doc="binary operator"):
"""
def _(self, other):
jc = other._jc if isinstance(other, Column) else other
njc = getattr(self._jc, _scalaMethod(name))(jc)
njc = getattr(self._jc, name)(jc)
return Column(njc, self.sql_ctx)
_.__doc__ = doc
return _
@@ -2486,7 +2434,7 @@ def _reverse_op(name, doc="binary operator"):
"""
def _(self, other):
jother = _create_column_from_literal(other)
jc = getattr(jother, _scalaMethod(name))(self._jc)
jc = getattr(jother, name)(self._jc)
return Column(jc, self.sql_ctx)
_.__doc__ = doc
return _
@@ -2513,34 +2461,33 @@ def __init__(self, jc, sql_ctx=None):
super(Column, self).__init__(jc, sql_ctx)

# arithmetic operators
__neg__ = _unary_op("unary_-")
__add__ = _bin_op("+")
__sub__ = _bin_op("-")
__mul__ = _bin_op("*")
__div__ = _bin_op("/")
__mod__ = _bin_op("%")
__radd__ = _bin_op("+")
__rsub__ = _reverse_op("-")
__rmul__ = _bin_op("*")
__rdiv__ = _reverse_op("/")
__rmod__ = _reverse_op("%")
__abs__ = _unary_op("abs")
__neg__ = _dsl_op("negate")
__add__ = _bin_op("plus")
__sub__ = _bin_op("minus")
__mul__ = _bin_op("multiply")
__div__ = _bin_op("divide")
__mod__ = _bin_op("mod")
__radd__ = _bin_op("plus")
__rsub__ = _reverse_op("minus")
__rmul__ = _bin_op("multiply")
__rdiv__ = _reverse_op("divide")
__rmod__ = _reverse_op("mod")

# logistic operators
__eq__ = _bin_op("===")
__ne__ = _bin_op("!==")
__lt__ = _bin_op("<")
__le__ = _bin_op("<=")
__ge__ = _bin_op(">=")
__gt__ = _bin_op(">")
__eq__ = _bin_op("equalTo")
__ne__ = _bin_op("notEqual")
__lt__ = _bin_op("lt")
__le__ = _bin_op("leq")
__ge__ = _bin_op("geq")
__gt__ = _bin_op("gt")

# `and`, `or`, `not` cannot be overloaded in Python,
# so use bitwise operators as boolean operators
__and__ = _bin_op('&&')
__or__ = _bin_op('||')
__invert__ = _unary_op('unary_!')
__rand__ = _bin_op("&&")
__ror__ = _bin_op("||")
__and__ = _bin_op('and')
__or__ = _bin_op('or')
__invert__ = _dsl_op('not')
__rand__ = _bin_op("and")
__ror__ = _bin_op("or")
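# usage note (illustrative): because & and | stand in for boolean and/or,
# compound conditions need parentheses, since bitwise operators bind tighter
# than comparisons, e.g. (df.age > 3) & (df.name == u'Alice')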

# container operators
__contains__ = _bin_op("contains")
Expand Down Expand Up @@ -2582,24 +2529,20 @@ def substr(self, startPos, length):
isNull = _unary_op("isNull", "True if the current expression is null.")
isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")

# `as` is keyword
def alias(self, alias):
"""Return a alias for this column

>>> df.age.As("age2").collect()
[Row(age2=2), Row(age2=5)]
>>> df.age.alias("age2").collect()
[Row(age2=2), Row(age2=5)]
"""
return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
As = alias

def cast(self, dataType):
""" Convert the column into type `dataType`

>>> df.select(df.age.cast("string").As('ages')).collect()
>>> df.select(df.age.cast("string").alias('ages')).collect()
[Row(ages=u'2'), Row(ages=u'5')]
>>> df.select(df.age.cast(StringType()).As('ages')).collect()
>>> df.select(df.age.cast(StringType()).alias('ages')).collect()
[Row(ages=u'2'), Row(ages=u'5')]
"""
if self.sql_ctx is None:
@@ -2626,6 +2569,40 @@ def _(col):
return staticmethod(_)


class UserDefinedFunction(object):
    def __init__(self, func, returnType):
        self.func = func
        self.returnType = returnType
        self._broadcast = None
        self._judf = self._create_judf()

    def _create_judf(self):
Review comment (Contributor): Can you add some inline comments explaining what's happening in this function?

        # Wrap the Python function so it maps over an iterator of rows, pickle
        # the resulting command (falling back to a broadcast if it is large),
        # and hand everything to the JVM-side UserDefinedPythonFunction.
        f = self.func  # capture in a local so the pickled closure does not hold `self`
        func = lambda _, it: imap(lambda x: f(*x), it)
        ser = AutoBatchedSerializer(PickleSerializer())
        command = (func, None, ser, ser)
        sc = SparkContext._active_spark_context
        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
        # translate the Python return type into a JVM DataType via its JSON form
        ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
        jdt = ssql_ctx.parseDataType(self.returnType.json())
        judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
                                                 includes, sc.pythonExec, broadcast_vars,
                                                 sc._javaAccumulator, jdt)
        return judf

    def __del__(self):
        # release the command broadcast (if any) when the UDF is garbage collected
        if self._broadcast is not None:
            self._broadcast.unpersist()
            self._broadcast = None

    def __call__(self, *cols):
        sc = SparkContext._active_spark_context
        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                        sc._gateway._gateway_client)
        jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
        return Column(jc)
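
An illustrative sketch of using the class directly (Dsl.udf below is the intended entry point; assumes an active SparkContext and the doctest DataFrame df):

    slen = UserDefinedFunction(lambda s: len(s), IntegerType())
    df.select(slen(df.name).alias('slen')).collect()  # [Row(slen=5), Row(slen=3)]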


class Dsl(object):
"""
A collection of builtin aggregators
Expand Down Expand Up @@ -2659,7 +2636,7 @@ def countDistinct(col, *cols):
""" Return a new Column for distinct count of (col, *cols)

>>> from pyspark.sql import Dsl
>>> df.agg(Dsl.countDistinct(df.age, df.name).As('c')).collect()
>>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
[Row(c=2)]
"""
sc = SparkContext._active_spark_context
@@ -2674,7 +2651,7 @@ def approxCountDistinct(col, rsd=None):
""" Return a new Column for approxiate distinct count of (col, *cols)

>>> from pyspark.sql import Dsl
>>> df.agg(Dsl.approxCountDistinct(df.age).As('c')).collect()
>>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
[Row(c=2)]
"""
sc = SparkContext._active_spark_context
@@ -2684,6 +2661,16 @@ def approxCountDistinct(col, rsd=None):
jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
return Column(jc)

    @staticmethod
    def udf(f, returnType=StringType()):
        """Create a user defined function (UDF)

        >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
        >>> df.select(slen(df.name).alias('slen')).collect()
        [Row(slen=5), Row(slen=3)]
        """
        return UserDefinedFunction(f, returnType)


def _test():
import doctest