Skip to content

Commit f9c20be

Browse files
committed
review comments
1 parent 6225c8e commit f9c20be

File tree

2 files changed

+28
-17
lines changed

2 files changed

+28
-17
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,12 @@ case class CaseWhen(
233233
}
234234

235235
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
236+
// This variable represents whether the first successful condition is met or not.
237+
// It is initialized to `false` and it is set to `true` when the first condition which
238+
// evaluates to `true` is met and therefore is not needed to go on anymore on the computation
239+
// of the following conditions.
236240
val conditionMet = ctx.freshName("caseWhenConditionMet")
237-
ctx.addMutableState("boolean", ev.isNull, "")
241+
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull, "")
238242
ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
239243
val cases = branches.map { case (condExpr, valueExpr) =>
240244
val cond = condExpr.genCode(ctx)
@@ -266,21 +270,28 @@ case class CaseWhen(
266270
val allConditions = cases ++ elseCode
267271

268272
val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
269-
allConditions.mkString("\n")
270-
} else {
271-
ctx.splitExpressions(allConditions, "caseWhen",
272-
("InternalRow", ctx.INPUT_ROW) :: ("boolean", conditionMet) :: Nil, returnType = "boolean",
273-
makeSplitFunction = {
274-
func =>
275-
s"""
276-
$func
277-
return $conditionMet;
278-
"""
279-
},
280-
foldFunctions = { funcCalls =>
281-
funcCalls.map(funcCall => s"$conditionMet = $funcCall;").mkString("\n")
282-
})
283-
}
273+
allConditions.mkString("\n")
274+
} else {
275+
ctx.splitExpressions(allConditions, "caseWhen",
276+
("InternalRow", ctx.INPUT_ROW) :: (ctx.JAVA_BOOLEAN, conditionMet) :: Nil,
277+
returnType = ctx.JAVA_BOOLEAN,
278+
makeSplitFunction = {
279+
func =>
280+
s"""
281+
$func
282+
return $conditionMet;
283+
"""
284+
},
285+
foldFunctions = { funcCalls =>
286+
funcCalls.map { funcCall =>
287+
s"""
288+
$conditionMet = $funcCall;
289+
if ($conditionMet) {
290+
continue;
291+
}"""
292+
}.mkString("do {", "", "\n} while (false);")
293+
})
294+
}
284295

285296
ev.copy(code = s"""
286297
${ev.isNull} = true;

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2136,7 +2136,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
21362136
expr2 = expr2.when($"id" === lit(i + 10), i)
21372137
}
21382138
val df = spark.range(1).select(expr1, expr2.otherwise(0))
2139-
df.show
2139+
checkAnswer(df, Row(0, 10) :: Nil)
21402140
assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
21412141
}
21422142
}

0 commit comments

Comments
 (0)