From 63ac57996498553505300a66edaaa3a41fdd42e7 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 8 May 2015 00:20:45 -0700 Subject: [PATCH 1/5] add na.replace in pyspark --- .../apache/spark/api/python/PythonUtils.scala | 7 ++ python/pyspark/sql/dataframe.py | 85 +++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index acbaba6791850..2761df1619296 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -53,4 +53,11 @@ private[spark] object PythonUtils { def toSeq[T](cols: JList[T]): Seq[T] = { cols.toList.toSeq } + + /** + * Convert java map of K, V into Map of K, V (for calling API with varargs) + */ + def toMap[K, V](jm: java.util.Map[K, V]): Map[K, V] = { + jm.toMap + } } diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 72180f6d05fbc..25a3b4d2cad18 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -578,6 +578,10 @@ def _jseq(self, cols, converter=None): """Return a JVM Seq of Columns from a list of Column or names""" return _to_seq(self.sql_ctx._sc, cols, converter) + def _jmap(self, jm): + """Return a JVM Map from a dict""" + return _to_map(self.sql_ctx._sc, jm) + def _jcols(self, *cols): """Return a JVM Seq of Columns from a list of Column or column names @@ -924,6 +928,77 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) + def replacena(self, to_replace, value, subset=None): + """Returns a new :class:`DataFrame` replacing a value with another value. + alias for ``na.replace()``. + + :param to_replace: int, long, float, string, or list. + Value to be replaced. + The replacement value must be an int, long, float, or string. + :param value: int, long, float, string, or list. 
+ Value to use to replace holes. + The replacement value must be an int, long, float, or string. + :param subset: optional list of column names to consider. + Columns specified in subset that do not have matching data type are ignored. + For example, if `value` is a string, and subset contains a non-string column, + then the non-string column is simply ignored. + >>> df4.na.replace(10, 20).show() + +----+------+-----+ + | age|height| name| + +----+------+-----+ + | 20| 80|Alice| + | 5| null| Bob| + |null| null| Tom| + |null| null| null| + +----+------+-----+ + + >>> df4.replacena(['Alice', 'Bob'], ['A', 'B'], 'name').show() + +----+------+----+ + | age|height|name| + +----+------+----+ + | 10| 80| A| + | 5| null| B| + |null| null| Tom| + |null| null|null| + +----+------+----+ + """ + if not isinstance(to_replace, (float, int, long, basestring, list, tuple)): + raise ValueError("to_replace should be a float, int, long, string, list, or tuple") + + if not isinstance(value, (float, int, long, basestring, list, tuple)): + raise ValueError("value should be a float, int, long, string, list, or tuple") + + if isinstance(to_replace, dict) and not isinstance(value, (list, tuple, dict)): + raise TypeError("") + + if isinstance(to_replace, (float, int, long, basestring)): + to_replace = [to_replace] + + if isinstance(value, (float, int, long, basestring)): + value = [value] + + if isinstance(to_replace, tuple): + to_replace = list(to_replace) + + if isinstance(value, tuple): + value = list(value) + + if isinstance(to_replace, list) and isinstance(value, list): + if len(to_replace) != len(value): + raise ValueError("to_replace and value lists should be of the same length") + + rep_dict = dict(zip(to_replace, value)) + if subset is None: + return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) + elif isinstance(subset, basestring): + subset = [subset] + + if not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column 
names") + + return DataFrame( + self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) + def corr(self, col1, col2, method=None): """ Calculates the correlation of two columns of a DataFrame as a double value. Currently only @@ -1225,6 +1300,11 @@ def _to_seq(sc, cols, converter=None): cols = [converter(c) for c in cols] return sc._jvm.PythonUtils.toSeq(cols) +def _to_map(sc, jm): + """ + Convert a dict into a JVM Map. + """ + return sc._jvm.PythonUtils.toMap(jm) def _unary_op(name, doc="unary operator"): """ Create a method for given unary operator """ @@ -1482,6 +1562,11 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ + def replace(self, to_replace, value, subset=None): + return self.df.replacena(to_replace=to_replace, value=value, subset=subset) + + replace.__doc__ = DataFrame.replacena.__doc__ + class DataFrameStatFunctions(object): """Functionality for statistic functions with :class:`DataFrame`. From af0268abeb918f43b2e536c16e14c2ef78a7f0c4 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 8 May 2015 00:53:43 -0700 Subject: [PATCH 2/5] remove na --- python/pyspark/sql/dataframe.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 25a3b4d2cad18..988c29217032d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -928,9 +928,8 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) - def replacena(self, to_replace, value, subset=None): + def replace(self, to_replace, value, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. - alias for ``na.replace()``. :param to_replace: int, long, float, string, or list. Value to be replaced. 
@@ -942,7 +941,7 @@ def replacena(self, to_replace, value, subset=None): Columns specified in subset that do not have matching data type are ignored. For example, if `value` is a string, and subset contains a non-string column, then the non-string column is simply ignored. - >>> df4.na.replace(10, 20).show() + >>> df4.replace(10, 20).show() +----+------+-----+ | age|height| name| +----+------+-----+ @@ -952,7 +951,7 @@ def replacena(self, to_replace, value, subset=None): |null| null| null| +----+------+-----+ - >>> df4.replacena(['Alice', 'Bob'], ['A', 'B'], 'name').show() + >>> df4.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| +----+------+----+ @@ -1300,12 +1299,14 @@ def _to_seq(sc, cols, converter=None): cols = [converter(c) for c in cols] return sc._jvm.PythonUtils.toSeq(cols) + def _to_map(sc, jm): """ Convert a dict into a JVM Map. """ return sc._jvm.PythonUtils.toMap(jm) + def _unary_op(name, doc="unary operator"): """ Create a method for given unary operator """ def _(self): @@ -1562,11 +1563,6 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ - def replace(self, to_replace, value, subset=None): - return self.df.replacena(to_replace=to_replace, value=value, subset=subset) - - replace.__doc__ = DataFrame.replacena.__doc__ - class DataFrameStatFunctions(object): """Functionality for statistic functions with :class:`DataFrame`. 
From 9e232e75ce79317cb1c7bcd41a841df4838a57b6 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 11 May 2015 03:58:53 -0700 Subject: [PATCH 3/5] rename scala map --- .../scala/org/apache/spark/api/python/PythonUtils.scala | 2 +- python/pyspark/sql/dataframe.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 2761df1619296..b5889a1b534ef 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -57,7 +57,7 @@ private[spark] object PythonUtils { /** * Convert java map of K, V into Map of K, V (for calling API with varargs) */ - def toMap[K, V](jm: java.util.Map[K, V]): Map[K, V] = { + def toScalaMap[K, V](jm: java.util.Map[K, V]): Map[K, V] = { jm.toMap } } diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 988c29217032d..7b87774a323a3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -579,8 +579,8 @@ def _jseq(self, cols, converter=None): return _to_seq(self.sql_ctx._sc, cols, converter) def _jmap(self, jm): - """Return a JVM Map from a dict""" - return _to_map(self.sql_ctx._sc, jm) + """Return a JVM Scala Map from a dict""" + return _to_scala_map(self.sql_ctx._sc, jm) def _jcols(self, *cols): """Return a JVM Seq of Columns from a list of Column or column names @@ -1300,11 +1300,11 @@ def _to_seq(sc, cols, converter=None): return sc._jvm.PythonUtils.toSeq(cols) -def _to_map(sc, jm): +def _to_scala_map(sc, jm): """ Convert a dict into a JVM Map. 
""" - return sc._jvm.PythonUtils.toMap(jm) + return sc._jvm.PythonUtils.toScalaMap(jm) def _unary_op(name, doc="unary operator"): From 4a148f7c567be89171b02c24609ae35a0ae5cf1f Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 11 May 2015 23:06:34 -0700 Subject: [PATCH 4/5] to_replace support dict, value support single value, and add full tests --- python/pyspark/sql/dataframe.py | 24 ++++++++++------- python/pyspark/sql/tests.py | 48 +++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 7b87774a323a3..11ccd08fbde50 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -933,10 +933,13 @@ def replace(self, to_replace, value, subset=None): :param to_replace: int, long, float, string, or list. Value to be replaced. - The replacement value must be an int, long, float, or string. + If `to_replace` is a dict, then `value` is ignored and `to_replace` must be a + mapping from each value to be replaced to its replacement. The value to be + replaced must be an int, long, float, or string. :param value: int, long, float, string, or list. Value to use to replace holes. - The replacement value must be an int, long, float, or string. + The replacement value must be an int, long, float, or string. If `value` is a + list or tuple, `value` should be the same length as `to_replace`. :param subset: optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored.
For example, if `value` is a string, and subset contains a non-string column, @@ -961,21 +964,18 @@ def replace(self, to_replace, value, subset=None): |null| null|null| +----+------+----+ """ - if not isinstance(to_replace, (float, int, long, basestring, list, tuple)): - raise ValueError("to_replace should be a float, int, long, string, list, or tuple") + if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)): + raise ValueError( + "to_replace should be a float, int, long, string, list, tuple, or dict") if not isinstance(value, (float, int, long, basestring, list, tuple)): raise ValueError("value should be a float, int, long, string, list, or tuple") - if isinstance(to_replace, dict) and not isinstance(value, (list, tuple, dict)): - raise TypeError("") + rep_dict = dict() if isinstance(to_replace, (float, int, long, basestring)): to_replace = [to_replace] - if isinstance(value, (float, int, long, basestring)): - value = [value] - if isinstance(to_replace, tuple): to_replace = list(to_replace) @@ -985,8 +985,12 @@ def replace(self, to_replace, value, subset=None): if isinstance(to_replace, list) and isinstance(value, list): if len(to_replace) != len(value): raise ValueError("to_replace and value lists should be of the same length") + rep_dict = dict(zip(to_replace, value)) + elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)): + rep_dict = {tr: value for tr in to_replace} + elif isinstance(to_replace, dict): + rep_dict = to_replace - rep_dict = dict(zip(to_replace, value)) if subset is None: return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) elif isinstance(subset, basestring): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7e63f4d6461f6..1922d03af61da 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -665,6 +665,54 @@ def test_bitwise_operations(self): result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict() 
self.assertEqual(~75, result['~b']) + def test_replace(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # replace with int + row = self.sqlCtx.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first() + self.assertEqual(row.age, 20) + self.assertEqual(row.height, 20.0) + + # replace with double + row = self.sqlCtx.createDataFrame( + [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first() + self.assertEqual(row.age, 82) + self.assertEqual(row.height, 82.1) + + # replace with string + row = self.sqlCtx.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first() + self.assertEqual(row.name, u"Ann") + self.assertEqual(row.age, 10) + + # replace with subset specified by a string of a column name w/ actual change + row = self.sqlCtx.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first() + self.assertEqual(row.age, 20) + + # replace with subset specified by a string of a column name w/o actual change + row = self.sqlCtx.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first() + self.assertEqual(row.age, 10) + + # replace with subset specified with one column replaced, another column not in subset + # stays unchanged. 
+ row = self.sqlCtx.createDataFrame( + [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first() + self.assertEqual(row.name, u'Alice') + self.assertEqual(row.age, 20) + self.assertEqual(row.height, 10.0) + + # replace with subset specified but no column will be replaced + row = self.sqlCtx.createDataFrame( + [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first() + self.assertEqual(row.name, u'Alice') + self.assertEqual(row.age, 10) + self.assertEqual(row.height, None) + class HiveContextSQLTests(ReusedPySparkTestCase): From 672efba1a97096280844b1c06a6de71fb8af79bc Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 12 May 2015 01:48:34 -0700 Subject: [PATCH 5/5] remove py2.7 feature --- python/pyspark/sql/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 11ccd08fbde50..078acfdf7e2df 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -987,7 +987,7 @@ def replace(self, to_replace, value, subset=None): raise ValueError("to_replace and value lists should be of the same length") rep_dict = dict(zip(to_replace, value)) elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)): - rep_dict = {tr: value for tr in to_replace} + rep_dict = dict([(tr, value) for tr in to_replace]) elif isinstance(to_replace, dict): rep_dict = to_replace