From 76d63461d2c512f5a6519d25dcaa14cfa8ec6468 Mon Sep 17 00:00:00 2001
From: kaka1992
Date: Wed, 6 May 2015 11:20:01 +0800
Subject: [PATCH 1/7] [SPARK-7321][SQL] Add Column expression for conditional
 statements (if, case)

---
 python/pyspark/sql/dataframe.py                  | 18 +++++++++++++
 .../scala/org/apache/spark/sql/Column.scala      | 27 +++++++++++++++++++
 .../spark/sql/ColumnExpressionSuite.scala        | 18 +++++++++++++
 3 files changed, 63 insertions(+)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 24f370543def4..17eef7070eab2 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1412,6 +1412,24 @@ def between(self, lowerBound, upperBound):
         """
         return (self >= lowerBound) & (self <= upperBound)

+    @ignore_unicode_prefix
+    def when(self, whenExpr, thenExpr):
+        """A case-when-otherwise expression.
+        >>> df.select(df.age.when(2, 3).otherwise(4).alias("age")).collect()
+        [Row(age=3), Row(age=4)]
+        >>> df.select(df.age.when(2, 3).alias("age")).collect()
+        [Row(age=3), Row(age=None)]
+        >>> df.select(df.age.otherwise(4).alias("age")).collect()
+        [Row(age=4), Row(age=4)]
+        """
+        jc = self._jc.when(whenExpr, thenExpr)
+        return Column(jc)
+
+    @ignore_unicode_prefix
+    def otherwise(self, elseExpr):
+        jc = self._jc.otherwise(elseExpr)
+        return Column(jc)
+
     def __repr__(self):
         return 'Column<%s>' % self._jc.toString().encode('utf8')

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index c0503bf047052..afe0193a56f1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -295,6 +295,33 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    */
   def eqNullSafe(other: Any): Column = this <=> other

+  /**
+   * Case When Otherwise.
+   * {{{
+   *   people.select( people("age").when(18, "SELECTED").otherwise("IGNORED") )
+   * }}}
+   *
+   * @group expr_ops
+   */
+  def when(whenExpr: Any, thenExpr: Any): Column = {
+    this.expr match {
+      case CaseWhen(branches: Seq[Expression]) =>
+        val caseExpr = branches.head.asInstanceOf[EqualNullSafe].left
+        CaseWhen(branches ++ Seq((caseExpr <=> whenExpr).expr, lit(thenExpr).expr))
+      case _ =>
+        CaseWhen(Seq((this <=> whenExpr).expr, lit(thenExpr).expr))
+    }
+  }
+
+  def otherwise(elseExpr: Any): Column = {
+    this.expr match {
+      case CaseWhen(branches: Seq[Expression]) =>
+        CaseWhen(branches :+ lit(elseExpr).expr)
+      case _ =>
+        CaseWhen(Seq(lit(true).expr, lit(elseExpr).expr))
+    }
+  }
+
   /**
    * True if the current column is between the lower bound and upper bound, inclusive.
   *
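Reviewer note: both bindings above build Catalyst's CaseWhen expression, which stores its branches as one flat sequence, [cond1, val1, cond2, val2, ..., optional else], rather than as (condition, value) pairs. That is why the chained form recovers the original column from the first branch via branches.head.asInstanceOf[EqualNullSafe].left: every when(v, r) on a column desugars into a null-safe equality test against that same column. The sketch below is a minimal standalone model of that flat layout for intuition only; it is not Spark code, and every name in it is illustrative.

```python
# Illustrative model of CaseWhen's flat branch list; not Spark code.
def eval_case_when(branches):
    """Return the value of the first true condition, the trailing else value
    if the list has odd length, or None (SQL NULL) when nothing matches."""
    n_pairs = len(branches) // 2
    for i in range(n_pairs):
        if branches[2 * i]:              # condition slot
            return branches[2 * i + 1]   # value slot
    return branches[-1] if len(branches) % 2 == 1 else None

# df.age.when(2, 3).otherwise(4) for a row with age = 5 flattens to:
print(eval_case_when([False, 3, 4]))  # -> 4
# ...and without the otherwise branch the row falls through to NULL:
print(eval_case_when([False, 3]))     # -> None
```

The even/odd length distinction in this layout is worth remembering: patch 7 at the end of this series uses exactly that parity to reject a second otherwise().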
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 3c1ad656fc855..26997c39224c8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -257,6 +257,24 @@ class ColumnExpressionSuite extends QueryTest {
       Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil)
   }

+  test("SPARK-7321 case") {
+    val testData = (1 to 3).map(i => TestData(i, i.toString)).toDF()
+    checkAnswer(
+      testData.select($"key".when(1, -1).when(2, -2).otherwise(0)),
+      Seq(Row(-1), Row(-2), Row(0))
+    )
+
+    checkAnswer(
+      testData.select($"key".when(1, -1).when(2, -2)),
+      Seq(Row(-1), Row(-2), Row(null))
+    )
+
+    checkAnswer(
+      testData.select($"key".otherwise(0)),
+      Seq(Row(0), Row(0), Row(0))
+    )
+  }
+
   test("sqrt") {
     checkAnswer(
       testData.select(sqrt('key)).orderBy('key.asc),

From 801009e798dc3c82f549241e2764e300ad1295da Mon Sep 17 00:00:00 2001
From: kaka1992
Date: Wed, 6 May 2015 22:55:29 +0800
Subject: [PATCH 2/7] Update

---
 python/pyspark/sql/dataframe.py                  | 14 ++------------
 python/pyspark/sql/functions.py                  | 10 ++++++++++
 .../main/scala/org/apache/spark/sql/Column.scala |  7 +++----
 .../scala/org/apache/spark/sql/functions.scala   | 12 ++++++++++++
 .../apache/spark/sql/ColumnExpressionSuite.scala | 11 +++--------
 5 files changed, 30 insertions(+), 24 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 17eef7070eab2..bbfda34a045f3 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1414,21 +1414,11 @@ def between(self, lowerBound, upperBound):

     @ignore_unicode_prefix
     def when(self, whenExpr, thenExpr):
-        """A case-when-otherwise expression.
-        >>> df.select(df.age.when(2, 3).otherwise(4).alias("age")).collect()
-        [Row(age=3), Row(age=4)]
-        >>> df.select(df.age.when(2, 3).alias("age")).collect()
-        [Row(age=3), Row(age=None)]
-        >>> df.select(df.age.otherwise(4).alias("age")).collect()
-        [Row(age=4), Row(age=4)]
-        """
-        jc = self._jc.when(whenExpr, thenExpr)
-        return Column(jc)
+        return self._jc.when(whenExpr, thenExpr)

     @ignore_unicode_prefix
     def otherwise(self, elseExpr):
-        jc = self._jc.otherwise(elseExpr)
-        return Column(jc)
+        return self._jc.otherwise(elseExpr)

     def __repr__(self):
         return 'Column<%s>' % self._jc.toString().encode('utf8')

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 641220a264295..a2ba9375cf9be 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -146,6 +146,16 @@ def monotonicallyIncreasingId():
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.functions.monotonicallyIncreasingId())

+def when(whenExpr, thenExpr):
+    """A case-when-otherwise expression.
+    >>> df.select(when(df.age == 2, 3).otherwise(4).alias("age")).collect()
+    [Row(age=3), Row(age=4)]
+    >>> df.select(when(df.age == 2, 3).alias("age")).collect()
+    [Row(age=3), Row(age=None)]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.when(whenExpr, thenExpr)
+    return Column(jc)

 def rand(seed=None):
     """Generates a random column with i.i.d. samples from U[0.0, 1.0].
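Reviewer note: this patch changes the call shape. Patch 1's receiver form, df.age.when(2, 3), could only express a null-safe equality between the receiver column and a value; the free function when(condition, value) takes an explicit boolean Column, so any predicate can drive a branch. A hedged pyspark sketch of the difference, assuming a 1.4-era shell where df is a DataFrame whose age column holds 2 and 5 (the fixture the doctests above assume):

```python
from pyspark.sql.functions import when

# Receiver form from patch 1 (equality only):  df.age.when(2, 3).otherwise(4)
# Free-function form from this patch onward (any boolean expression):
df.select(when(df.age == 2, 3).otherwise(4).alias("age")).collect()
# [Row(age=3), Row(age=4)]  -- matches the doctest above

# Compound predicates now work too, e.g. a range check:
df.select(when((df.age >= 2) & (df.age < 5), "young").otherwise("other")).collect()
```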
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index afe0193a56f1e..402e346ae030f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -298,7 +298,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   /**
    * Case When Otherwise.
    * {{{
-   *   people.select( people("age").when(18, "SELECTED").otherwise("IGNORED") )
+   *   people.select( when(people("age") === 18, "SELECTED").otherwise("IGNORED") )
    * }}}
    *
    * @group expr_ops
@@ -306,10 +306,9 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    */
   def when(whenExpr: Any, thenExpr: Any): Column = {
     this.expr match {
       case CaseWhen(branches: Seq[Expression]) =>
-        val caseExpr = branches.head.asInstanceOf[EqualNullSafe].left
-        CaseWhen(branches ++ Seq((caseExpr <=> whenExpr).expr, lit(thenExpr).expr))
+        CaseWhen(branches ++ Seq(lit(whenExpr).expr, lit(thenExpr).expr))
       case _ =>
-        CaseWhen(Seq((this <=> whenExpr).expr, lit(thenExpr).expr))
+        CaseWhen(Seq(lit(whenExpr).expr, lit(thenExpr).expr))
     }
   }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 7e283393d0563..951a4c09b1e8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -363,6 +363,18 @@ object functions {
    */
   def not(e: Column): Column = !e

+  /**
+   * Case When Otherwise.
+   * {{{
+   *   people.select( when(people("age") === 18, "SELECTED").otherwise("IGNORED") )
+   * }}}
+   *
+   * @group normal_funcs
+   */
+  def when(whenExpr: Any, thenExpr: Any): Column = {
+    CaseWhen(Seq(lit(whenExpr).expr, lit(thenExpr).expr))
+  }
+
   /**
    * Generate a random column with i.i.d. samples from U[0.0, 1.0].
   *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 26997c39224c8..69a6bc4aebb41 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -257,22 +257,17 @@ class ColumnExpressionSuite extends QueryTest {
       Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil)
   }

-  test("SPARK-7321 case") {
+  test("SPARK-7321 case when otherwise") {
     val testData = (1 to 3).map(i => TestData(i, i.toString)).toDF()
     checkAnswer(
-      testData.select($"key".when(1, -1).when(2, -2).otherwise(0)),
+      testData.select(when($"key" === 1, -1).when($"key" === 2, -2).otherwise(0)),
       Seq(Row(-1), Row(-2), Row(0))
     )

     checkAnswer(
-      testData.select($"key".when(1, -1).when(2, -2)),
+      testData.select(when($"key" === 1, -1).when($"key" === 2, -2)),
       Seq(Row(-1), Row(-2), Row(null))
     )
-
-    checkAnswer(
-      testData.select($"key".otherwise(0)),
-      Seq(Row(0), Row(0), Row(0))
-    )
   }

   test("sqrt") {

From 8218d0acc287565a62259691803fd13c84f651ba Mon Sep 17 00:00:00 2001
From: kaka1992
Date: Tue, 12 May 2015 10:28:49 +0800
Subject: [PATCH 3/7] Update

---
 python/pyspark/sql/dataframe.py | 7 +++++--
 python/pyspark/sql/functions.py | 4 ++--
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index bbfda34a045f3..b64cff0fbd550 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1414,11 +1414,14 @@ def between(self, lowerBound, upperBound):

     @ignore_unicode_prefix
     def when(self, whenExpr, thenExpr):
-        return self._jc.when(whenExpr, thenExpr)
+        if isinstance(whenExpr, Column):
+            jc = self._jc.when(whenExpr._jc, thenExpr)
+        return Column(jc)

     @ignore_unicode_prefix
     def otherwise(self, elseExpr):
-        return self._jc.otherwise(elseExpr)
+        jc = self._jc.otherwise(elseExpr)
+        return Column(jc)

     def __repr__(self):
         return 'Column<%s>' % self._jc.toString().encode('utf8')

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index a2ba9375cf9be..b70005d6ed4a2 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -27,7 +27,7 @@
 from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
-from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
+from pyspark.sql.dataframe import Column, _to_java_column, _to_seq, _create_column_from_literal

 __all__ = [
@@ -154,7 +154,7 @@ def when(whenExpr, thenExpr):
     [Row(age=3), Row(age=None)]
     """
     sc = SparkContext._active_spark_context
-    jc = sc._jvm.functions.when(whenExpr, thenExpr)
+    jc = sc._jvm.functions.when(whenExpr._jc, thenExpr)
     return Column(jc)
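Reviewer note: patches 3 and 4 are about the Py4J boundary. A Python Column is a thin wrapper around a JVM Column, and only the wrapped Java object (._jc) can be handed across the gateway, hence the unwrapping above. Patch 3's guard has no else branch, though: a non-Column argument skips the assignment and the return then touches an unbound jc. Patch 4 below turns that into an explicit TypeError. A pure-Python illustration of the failure mode, using a stand-in class rather than Spark:

```python
class Column(object):          # stand-in for pyspark's wrapper; not Spark code
    def __init__(self, jc):
        self._jc = jc

def when_patch3(whenExpr, thenExpr):
    if isinstance(whenExpr, Column):
        jc = ("when", whenExpr._jc, thenExpr)   # stands in for the Py4J call
    return Column(jc)          # 'jc' is unbound when the guard was skipped

try:
    when_patch3(42, 3)         # not a Column
except UnboundLocalError as e: # a NameError subclass, and a confusing one
    print(e)                   # e.g. local variable 'jc' referenced before assignment
```

Raising TypeError("whenExpr should be Column") up front, as patch 4 does, gives the caller an actionable message instead of an internal-looking failure.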
From 95724c6375e3f0fda4bef4f2d8c6a62811a196cc Mon Sep 17 00:00:00 2001
From: kaka1992
Date: Tue, 12 May 2015 10:38:21 +0800
Subject: [PATCH 4/7] Update

---
 python/pyspark/sql/dataframe.py | 2 ++
 python/pyspark/sql/functions.py | 9 +++++++--
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index b64cff0fbd550..272d05a911b6f 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1416,6 +1416,8 @@ def between(self, lowerBound, upperBound):
     def when(self, whenExpr, thenExpr):
         if isinstance(whenExpr, Column):
             jc = self._jc.when(whenExpr._jc, thenExpr)
+        else:
+            raise TypeError("whenExpr should be Column")
         return Column(jc)

     @ignore_unicode_prefix

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b70005d6ed4a2..e8dbbe1ddb30c 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -27,7 +27,7 @@
 from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
-from pyspark.sql.dataframe import Column, _to_java_column, _to_seq, _create_column_from_literal
+from pyspark.sql.dataframe import Column, _to_java_column, _to_seq

 __all__ = [
@@ -152,9 +152,14 @@ def when(whenExpr, thenExpr):
     [Row(age=3), Row(age=4)]
     >>> df.select(when(df.age == 2, 3).alias("age")).collect()
     [Row(age=3), Row(age=None)]
+    >>> df.select(when(df.age == 2, 3==3).alias("age")).collect()
+    [Row(age=True), Row(age=None)]
     """
     sc = SparkContext._active_spark_context
-    jc = sc._jvm.functions.when(whenExpr._jc, thenExpr)
+    if isinstance(whenExpr, Column):
+        jc = sc._jvm.functions.when(whenExpr._jc, thenExpr)
+    else:
+        raise TypeError("whenExpr should be Column")
     return Column(jc)

 def rand(seed=None):
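Reviewer note: patch 5 below consolidates the API: Column.when and Column.otherwise get real docstrings and argument checks, functions.when carries the canonical examples, and the chaining semantics are spelled out. The first matching condition wins, and unmatched rows come back as NULL unless otherwise() closes the chain. A hedged pyspark sketch of the resulting behavior, assuming a DataFrame df with key values 1, 2, 3 like the suite's fixture:

```python
from pyspark.sql.functions import col, lit, when

# First matching condition wins; otherwise() supplies the default.
df.select(when(col("key") == 1, -1).when(col("key") == 2, -2).otherwise(0)).collect()
# expected: rows -1, -2, 0 (mirrors the Scala test in patch 5)

# Without otherwise(), the unmatched row surfaces as None (SQL NULL),
# and a branch value may itself be an expression rather than a literal:
df.select(when(col("key") == 1, lit(0) - col("key")).when(col("key") == 2, -2)).collect()
# expected: rows -1, -2, None
```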
From bfb9d9fcececbc682f2c5bb2c0586e56ea499adf Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Mon, 11 May 2015 22:24:43 -0700
Subject: [PATCH 5/7] Updated documentation and test cases.

---
 python/pyspark/sql/__init__.py                   |  2 +
 python/pyspark/sql/dataframe.py                  | 31 ++++++---
 python/pyspark/sql/functions.py                  | 41 +++++++-----
 python/run-tests                                 |  8 +--
 .../scala/org/apache/spark/sql/Column.scala      | 62 ++++++++++++-----
 .../org/apache/spark/sql/functions.scala         | 20 ++++--
 .../spark/sql/ColumnExpressionSuite.scala        | 13 +++-
 7 files changed, 126 insertions(+), 51 deletions(-)

diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index b60b991dd4d8b..7192c89b3dc7f 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -32,6 +32,8 @@
     Aggregation methods, returned by :func:`DataFrame.groupBy`.
     - L{DataFrameNaFunctions}
       Methods for handling missing data (null values).
+    - L{DataFrameStatFunctions}
+      Methods for statistics functionality.
     - L{functions}
       List of built-in functions available for :class:`DataFrame`.
     - L{types}

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 605b9e44e1d93..ad58d7ed9da66 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1462,16 +1462,33 @@ def between(self, lowerBound, upperBound):
         return (self >= lowerBound) & (self <= upperBound)

     @ignore_unicode_prefix
-    def when(self, whenExpr, thenExpr):
-        if isinstance(whenExpr, Column):
-            jc = self._jc.when(whenExpr._jc, thenExpr)
-        else:
-            raise TypeError("whenExpr should be Column")
+    def when(self, condition, value):
+        """Evaluates a list of conditions and returns one of multiple possible result expressions.
+        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
+
+        See :func:`pyspark.sql.functions.when` for example usage.
+
+        :param condition: a boolean :class:`Column` expression.
+        :param value: a literal value, or a :class:`Column` expression.
+        """
+        if not isinstance(condition, Column):
+            raise TypeError("condition should be a Column")
+        v = value._jc if isinstance(value, Column) else value
+        jc = self._jc.when(condition._jc, v)
         return Column(jc)

     @ignore_unicode_prefix
-    def otherwise(self, elseExpr):
-        jc = self._jc.otherwise(elseExpr)
+    def otherwise(self, value):
+        """Evaluates a list of conditions and returns one of multiple possible result expressions.
+        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
+
+        See :func:`pyspark.sql.functions.when` for example usage.
+
+        :param value: a literal value, or a :class:`Column` expression.
+        """
+        v = value._jc if isinstance(value, Column) else value
+        jc = self._jc.otherwise(v)
         return Column(jc)

     def __repr__(self):
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b603143062387..d91265ee0bec8 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -32,13 +32,14 @@

 __all__ = [
     'approxCountDistinct',
+    'coalesce',
     'countDistinct',
     'monotonicallyIncreasingId',
     'rand',
     'randn',
     'sparkPartitionId',
-    'coalesce',
-    'udf']
+    'udf',
+    'when']


 def _create_function(name, doc=""):
@@ -237,21 +238,6 @@ def monotonicallyIncreasingId():
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.functions.monotonicallyIncreasingId())

-def when(whenExpr, thenExpr):
-    """A case-when-otherwise expression.
-    >>> df.select(when(df.age == 2, 3).otherwise(4).alias("age")).collect()
-    [Row(age=3), Row(age=4)]
-    >>> df.select(when(df.age == 2, 3).alias("age")).collect()
-    [Row(age=3), Row(age=None)]
-    >>> df.select(when(df.age == 2, 3==3).alias("age")).collect()
-    [Row(age=True), Row(age=None)]
-    """
-    sc = SparkContext._active_spark_context
-    if isinstance(whenExpr, Column):
-        jc = sc._jvm.functions.when(whenExpr._jc, thenExpr)
-    else:
-        raise TypeError("whenExpr should be Column")
-    return Column(jc)

 def rand(seed=None):
     """Generates a random column with i.i.d. samples from U[0.0, 1.0].
@@ -306,6 +292,27 @@ def struct(*cols):
     return Column(jc)


+def when(condition, value):
+    """Evaluates a list of conditions and returns one of multiple possible result expressions.
+    If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
+
+    :param condition: a boolean :class:`Column` expression.
+    :param value: a literal value, or a :class:`Column` expression.
+
+    >>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
+    [Row(age=3), Row(age=4)]
+
+    >>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect()
+    [Row(age=3), Row(age=None)]
+    """
+    sc = SparkContext._active_spark_context
+    if not isinstance(condition, Column):
+        raise TypeError("condition should be a Column")
+    v = value._jc if isinstance(value, Column) else value
+    jc = sc._jvm.functions.when(condition._jc, v)
+    return Column(jc)
+
+
 class UserDefinedFunction(object):
     """
     User defined function in Python

diff --git a/python/run-tests b/python/run-tests
index f9ca26467f17e..f235e7b80d646 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -136,11 +136,11 @@ fi
 echo "Testing with Python version:"
 $PYSPARK_PYTHON --version

-run_core_tests
+#run_core_tests
 run_sql_tests
-run_mllib_tests
-run_ml_tests
-run_streaming_tests
+#run_mllib_tests
+#run_ml_tests
+#run_streaming_tests

 # Try to test with Python 3
 if [ $(which python3.4) ]; then
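Reviewer note: the run-tests change above temporarily narrows CI to the SQL suite while iterating; patch 6 restores it. The Column.scala rewrite that follows also documents evaluation order: conditions are tried top to bottom and the first match wins, as in SQL's CASE WHEN, so overlapping branches are order-sensitive. Reusing eval_case_when from the earlier sketch (illustrative only, not Spark code):

```python
# when(key > 0, "pos").when(key == 1, "one") can never produce "one":
# a row with key == 1 already satisfied key > 0, the earlier branch.
print(eval_case_when([True, "pos", True, "one"]))  # -> 'pos', first match wins
```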
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 8fbd78b70b4a2..3b1d741f453d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -309,29 +309,59 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   def eqNullSafe(other: Any): Column = this <=> other

   /**
-   * Case When Otherwise.
+   * Evaluates a list of conditions and returns one of multiple possible result expressions.
+   * If otherwise is not defined at the end, null is returned for unmatched conditions.
+   *
    * {{{
-   *   people.select( when(people("age") === 18, "SELECTED").otherwise("IGNORED") )
+   *   // Example: encoding gender string column into integer.
+   *
+   *   // Scala:
+   *   people.select(when(people("gender") === "male", 0)
+   *     .when(people("gender") === "female", 1)
+   *     .otherwise(2))
+   *
+   *   // Java:
+   *   people.select(when(col("gender").equalTo("male"), 0)
+   *     .when(col("gender").equalTo("female"), 1)
+   *     .otherwise(2))
    * }}}
    *
    * @group expr_ops
    */
-  def when(whenExpr: Any, thenExpr: Any): Column = {
-    this.expr match {
-      case CaseWhen(branches: Seq[Expression]) =>
-        CaseWhen(branches ++ Seq(lit(whenExpr).expr, lit(thenExpr).expr))
-      case _ =>
-        CaseWhen(Seq(lit(whenExpr).expr, lit(thenExpr).expr))
-    }
+  def when(condition: Column, value: Any): Column = this.expr match {
+    case CaseWhen(branches: Seq[Expression]) =>
+      CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr))
+    case _ =>
+      throw new IllegalArgumentException(
+        "when() can only be applied on a Column previously generated by when() function")
   }

-  def otherwise(elseExpr: Any): Column = {
-    this.expr match {
-      case CaseWhen(branches: Seq[Expression]) =>
-        CaseWhen(branches :+ lit(elseExpr).expr)
-      case _ =>
-        CaseWhen(Seq(lit(true).expr, lit(elseExpr).expr))
-    }
+  /**
+   * Evaluates a list of conditions and returns one of multiple possible result expressions.
+   * If otherwise is not defined at the end, null is returned for unmatched conditions.
+   *
+   * {{{
+   *   // Example: encoding gender string column into integer.
+   *
+   *   // Scala:
+   *   people.select(when(people("gender") === "male", 0)
+   *     .when(people("gender") === "female", 1)
+   *     .otherwise(2))
+   *
+   *   // Java:
+   *   people.select(when(col("gender").equalTo("male"), 0)
+   *     .when(col("gender").equalTo("female"), 1)
+   *     .otherwise(2))
+   * }}}
+   *
+   * @group expr_ops
+   */
+  def otherwise(value: Any): Column = this.expr match {
+    case CaseWhen(branches: Seq[Expression]) =>
+      CaseWhen(branches :+ lit(value).expr)
+    case _ =>
+      throw new IllegalArgumentException(
+        "otherwise() can only be applied on a Column previously generated by when() function")
   }

   /**

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 5cccf62d755b1..e6297581c3438 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -382,15 +382,27 @@ object functions {
   def not(e: Column): Column = !e

   /**
-   * Case When Otherwise.
+   * Evaluates a list of conditions and returns one of multiple possible result expressions.
+   * If otherwise is not defined at the end, null is returned for unmatched conditions.
+   *
    * {{{
-   *   people.select( when(people("age") === 18, "SELECTED").otherwise("IGNORED") )
+   *   // Example: encoding gender string column into integer.
+   *
+   *   // Scala:
+   *   people.select(when(people("gender") === "male", 0)
+   *     .when(people("gender") === "female", 1)
+   *     .otherwise(2))
+   *
+   *   // Java:
+   *   people.select(when(col("gender").equalTo("male"), 0)
+   *     .when(col("gender").equalTo("female"), 1)
+   *     .otherwise(2))
    * }}}
    *
    * @group normal_funcs
    */
-  def when(whenExpr: Any, thenExpr: Any): Column = {
-    CaseWhen(Seq(lit(whenExpr).expr, lit(thenExpr).expr))
+  def when(condition: Column, value: Any): Column = {
+    CaseWhen(Seq(condition.expr, lit(value).expr))
   }

   /**

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 8d79f46396247..c10cd036fc729 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -255,17 +255,24 @@ class ColumnExpressionSuite extends QueryTest {
       Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil)
   }

-  test("SPARK-7321 case when otherwise") {
-    val testData = (1 to 3).map(i => TestData(i, i.toString)).toDF()
+  test("SPARK-7321 when conditional statements") {
+    val testData = (1 to 3).map(i => (i, i.toString)).toDF("key", "value")
+
     checkAnswer(
       testData.select(when($"key" === 1, -1).when($"key" === 2, -2).otherwise(0)),
       Seq(Row(-1), Row(-2), Row(0))
     )

+    // Without the ending otherwise, return null for unmatched conditions.
+    // Also test putting a non-literal value in the expression.
     checkAnswer(
-      testData.select(when($"key" === 1, -1).when($"key" === 2, -2)),
+      testData.select(when($"key" === 1, lit(0) - $"key").when($"key" === 2, -2)),
       Seq(Row(-1), Row(-2), Row(null))
     )
+
+    intercept[IllegalArgumentException] {
+      $"key".when($"key" === 1, -1)
+    }
   }

   test("sqrt") {
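Reviewer note: after patch 5, when() and otherwise() on a Column are pure extension points: each pattern-matches for an existing CaseWhen and rejects anything else with IllegalArgumentException. Failing fast is the design point here; patch 1's fallback silently rewrote such calls into a fresh CASE WHEN (for otherwise(), one guarded by lit(true), so it always returned the default), which masked mistakes. A hedged pyspark sketch of the calls that now fail, assuming a 1.4-era session where df has a key column; the JVM's IllegalArgumentException is surfaced to Python through Py4J:

```python
from pyspark.sql.functions import col, when

expr = when(col("key") == 1, -1).otherwise(0)   # fine: when() starts the chain

# Both of these now fail fast instead of building a surprising expression:
for bad in (lambda: col("key").otherwise(0),
            lambda: col("key").when(col("key") == 1, -1)):
    try:
        bad()
    except Exception as e:     # Py4J wraps the java.lang.IllegalArgumentException
        print(type(e).__name__)
```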
From 0455edae68e461897420a5553a3db755ef913c2f Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Mon, 11 May 2015 22:27:55 -0700
Subject: [PATCH 6/7] Reset run-tests.

---
 python/run-tests | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/run-tests b/python/run-tests
index f235e7b80d646..f9ca26467f17e 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -136,11 +136,11 @@ fi
 echo "Testing with Python version:"
 $PYSPARK_PYTHON --version

-#run_core_tests
+run_core_tests
 run_sql_tests
-#run_mllib_tests
-#run_ml_tests
-#run_streaming_tests
+run_mllib_tests
+run_ml_tests
+run_streaming_tests

 # Try to test with Python 3
 if [ $(which python3.4) ]; then

From 8f49201d7f57376bdf052e2620fd09e9d3c8d889 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 12 May 2015 12:08:54 -0700
Subject: [PATCH 7/7] Throw exception if otherwise is applied twice.

---
 .../src/main/scala/org/apache/spark/sql/Column.scala | 9 +++++++--
 .../org/apache/spark/sql/ColumnExpressionSuite.scala | 7 ++++---
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 3b1d741f453d3..5685923910e42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -358,10 +358,15 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    */
   def otherwise(value: Any): Column = this.expr match {
     case CaseWhen(branches: Seq[Expression]) =>
-      CaseWhen(branches :+ lit(value).expr)
+      if (branches.size % 2 == 0) {
+        CaseWhen(branches :+ lit(value).expr)
+      } else {
+        throw new IllegalArgumentException(
+          "otherwise() can only be applied once on a Column previously generated by when()")
+      }
     case _ =>
       throw new IllegalArgumentException(
-        "otherwise() can only be applied on a Column previously generated by when() function")
+        "otherwise() can only be applied on a Column previously generated by when()")
   }

   /**

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index c10cd036fc729..269e185543059 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -270,9 +270,10 @@ class ColumnExpressionSuite extends QueryTest {
       Seq(Row(-1), Row(-2), Row(null))
     )

-    intercept[IllegalArgumentException] {
-      $"key".when($"key" === 1, -1)
-    }
+    // Test error handling for invalid expressions.
+    intercept[IllegalArgumentException] { $"key".when($"key" === 1, -1) }
+    intercept[IllegalArgumentException] { $"key".otherwise(-1) }
+    intercept[IllegalArgumentException] { when($"key" === 1, -1).otherwise(-1).otherwise(-1) }
  }

  test("sqrt") {
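Reviewer note: the parity check in patch 7 falls straight out of the flat branch layout described at the start of this series: (condition, value) pairs contribute an even number of entries, so an even branches.size means no else has been attached yet and otherwise() may append one, while an odd size means an else is already present and a second otherwise() must be rejected. A minimal sketch of the invariant; this is an illustration, not Spark code:

```python
def otherwise(branches, default):
    """Append the else value to a flat CASE WHEN branch list:
    [cond1, val1, cond2, val2, ...]. Even length means no else yet."""
    if len(branches) % 2 != 0:
        raise ValueError("otherwise() can only be applied once")
    return branches + [default]

chain = otherwise([True, -1], 0)   # ok: [True, -1, 0]
try:
    otherwise(chain, 0)            # a second else: now rejected
except ValueError as e:
    print(e)                       # otherwise() can only be applied once
```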