Skip to content

Commit 95724c6

Browse files
committed
Update
1 parent 8218d0a commit 95724c6

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,6 +1416,8 @@ def between(self, lowerBound, upperBound):
14161416
def when(self, whenExpr, thenExpr):
14171417
if isinstance(whenExpr, Column):
14181418
jc = self._jc.when(whenExpr._jc, thenExpr)
1419+
else:
1420+
raise TypeError("whenExpr should be Column")
14191421
return Column(jc)
14201422

14211423
@ignore_unicode_prefix

python/pyspark/sql/functions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
2828
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
2929
from pyspark.sql.types import StringType
30-
from pyspark.sql.dataframe import Column, _to_java_column, _to_seq, _create_column_from_literal
30+
from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
3131

3232

3333
__all__ = [
@@ -152,9 +152,14 @@ def when(whenExpr, thenExpr):
152152
[Row(age=3), Row(age=4)]
153153
>>> df.select(when(df.age == 2, 3).alias("age")).collect()
154154
[Row(age=3), Row(age=None)]
155+
>>> df.select(when(df.age == 2, 3==3).alias("age")).collect()
156+
[Row(age=True), Row(age=None)]
155157
"""
156158
sc = SparkContext._active_spark_context
157-
jc = sc._jvm.functions.when(whenExpr._jc, thenExpr)
159+
if isinstance(whenExpr, Column):
160+
jc = sc._jvm.functions.when(whenExpr._jc, thenExpr)
161+
else:
162+
raise TypeError("whenExpr should be Column")
158163
return Column(jc)
159164

160165
def rand(seed=None):

0 commit comments

Comments
 (0)