diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index c41a10c7b0f87..1bd70acbe1b7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -262,37 +262,97 @@ case class CaseWhenCodegen( // } // } // } + + val isNull = ctx.freshName("caseWhenIsNull") + val value = ctx.freshName("caseWhenValue") + val cases = branches.map { case (condExpr, valueExpr) => - val cond = condExpr.genCode(ctx) - val res = valueExpr.genCode(ctx) + val (condFunc, condIsNull, condValue) = genCodeForExpression(ctx, condExpr) + val (resFunc, resIsNull, resValue) = genCodeForExpression(ctx, valueExpr) s""" - ${cond.code} - if (!${cond.isNull} && ${cond.value}) { - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; + ${condFunc} + if (!${condIsNull} && ${condValue}) { + ${resFunc} + $isNull = ${resIsNull}; + $value = ${resValue}; } """ } - var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n") + var isGlobalVariable = false + val (generatedIfThenElse, numBrankets) = if (cases.map(s => s.length).sum <= 1024) { + (cases.mkString("", "\nelse {\n", "\nelse {\n"), cases.length) + } else { + var numIfThen = 0 + var code = "" + cases.foreach { ifThen => + code += ifThen + "\nelse {\n" + numIfThen += 1 + + if (code.length > 1024 && + // Split these expressions only if they are created from a row object + (ctx.INPUT_ROW != null && ctx.currentVars == null)) { + val flag = "flag" + code += s" $flag = false;\n" + "}\n" * numIfThen + val funcName = ctx.freshName("caseWhenNestedIf") + val funcBody = + s""" + |private boolean $funcName(InternalRow ${ctx.INPUT_ROW}) { + | boolean $flag = true; + | $code + | return $flag; + |} + """.stripMargin + val fullFuncName = ctx.addNewFunction(funcName, funcBody) + isGlobalVariable = true + + code = s"if ($fullFuncName(${ctx.INPUT_ROW})) {\n// do nothing\n} else {\n" + numIfThen = 1 + } + } + (code, numIfThen) + } + var generatedCode = generatedIfThenElse elseValue.foreach { elseExpr => - val res = elseExpr.genCode(ctx) - generatedCode += - s""" - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; - """ + val (resFunc, resIsNull, resValue) = genCodeForExpression(ctx, elseExpr) + generatedCode += s""" + ${resFunc} + $isNull = ${resIsNull}; + $value = ${resValue}; + """ } - generatedCode += "}\n" * cases.size + generatedCode += "}\n" * numBrankets - ev.copy(code = s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $generatedCode""") + if (!isGlobalVariable) { + ev.copy(s""" + boolean $isNull = true; + ${ctx.javaType(dataType)} $value = ${ctx.defaultValue(dataType)}; + $generatedCode + """, isNull, value) + } else { + ctx.addMutableState("boolean", isNull, s"$isNull = false;") + ctx.addMutableState(ctx.javaType(dataType), value, + s"$value = ${ctx.defaultValue(dataType)};") + ev.copy(code = s""" + $generatedCode + boolean ${ev.isNull} = $isNull; + ${ctx.javaType(dataType)} ${ev.value} = $value; + """) + } + } + + def genCodeForExpression(ctx: CodegenContext, expression: Expression): + (String, String, String) = { + val ev = expression.genCode(ctx) + if (ev.code.length > 1024 && (ctx.INPUT_ROW != null && ctx.currentVars == null)) { + val (funcName, globalIsNull, globalValue) = + ctx.createAndAddFunction(ev, expression.dataType, "caseWhenElseExpr") + (s"$funcName(${ctx.INPUT_ROW});", globalIsNull, globalValue) + } else { + (ev.code, ev.isNull, ev.value) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 8f6289f00571c..d4cc9a6943666 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -380,4 +380,68 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd") } } + + test("SPARK-21413: split large case when into blocks due to JVM code size limit") { + val expectedInt = -2 + var exprInt: Expression = BoundReference(0, IntegerType, true) + val expectedStr = UTF8String.fromString("abc") + val exprStr: Expression = BoundReference(0, StringType, true) + + // Code size of condition or then expression is large + var expr1 = exprInt + for (i <- 1 to 10) { + expr1 = CaseWhen(Seq((EqualTo(expr1, Literal(i)), Literal(-1))), expr1).toCodegen() + } + val plan1 = GenerateMutableProjection.generate(Seq(expr1)) + val row1 = new GenericInternalRow(Array[Any](1)) + row1.setInt(0, expectedInt) + val actual1 = plan1(row1).toSeq(Seq(expr1.dataType)) + assert(actual1.length == 1) + val result1 = actual1(0) + if (!checkResult(result1, expectedInt, expr1.dataType)) { + fail(s"Incorrect Evaluation: expressions: $expr1, actual: $result1, expected: $expectedInt") + } + + // Code size of else expression is large + var expr2 = exprStr + for (i <- 1 to 512) { + expr2 = CaseWhen(Seq((EqualTo(exprStr, Literal(s"def$i")), Literal(s"xyz$i"))), expr2) + .toCodegen() + } + val plan2 = GenerateMutableProjection.generate(Seq(expr2)) + val row2 = new GenericInternalRow(Array[Any](1)) + row2.update(0, expectedStr) + val actual2 = plan2(row2).toSeq(Seq(expr2.dataType)) + assert(actual2.length == 1) + val result2 = actual2(0) + if (!checkResult(result2, expectedStr, expr2.dataType)) { + fail(s"Incorrect Evaluation: expressions: $expr2, actual: $result2, expected: $expectedStr") + } + + // total code size of conditional branches is large + val cases = (1 to 512).map(i => (EqualTo(exprStr, Literal(s"def$i")), Literal(s"xyz$i"))) + val expr3 = CaseWhen(cases, exprStr).toCodegen() + val plan3 = GenerateMutableProjection.generate(Seq(expr3)) + val row3 = new GenericInternalRow(Array[Any](1)) + row3.update(0, expectedStr) + val actual3 = plan3(row3).toSeq(Seq(expr3.dataType)) + assert(actual3.length == 1) + val result3 = actual3(0) + if (!checkResult(result3, expectedStr, expr3.dataType)) { + fail(s"Incorrect Evaluation: expressions: $expr3, actual: $result3, expected: $expectedStr") + } + + // total code size is small + val cases4 = Seq((EqualTo(exprStr, Literal("def")), Literal("xyz"))) + val expr4 = CaseWhen(cases4, exprStr).toCodegen() + val plan4 = GenerateMutableProjection.generate(Seq(expr4)) + val row4 = new GenericInternalRow(Array[Any](1)) + row4.update(0, expectedStr) + val actual4 = plan4(row4).toSeq(Seq(expr4.dataType)) + assert(actual4.length == 1) + val result4 = actual4(0) + if (!checkResult(result4, expectedStr, expr4.dataType)) { + fail(s"Incorrect Evaluation: expressions: $expr4, actual: $result4, expected: $expectedStr") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 644e72c893ceb..f5d5fa1fbc25e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2158,4 +2158,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val mean = result.select("DecimalCol").where($"summary" === "mean") assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000"))) } + + // ignore end-to-end test since sbt test does not go to fallback path in whole-stage codegen + ignore("SPARK-21413: Multiple projections with CASE WHEN fails") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + var df = spark.createDataFrame(sparkContext.parallelize(Seq(Row(1))), schema) + for (i <- 1 to 10) { + df = df.withColumn("a", when($"a" === 0, null).otherwise($"a")) + } + checkAnswer(df, Row(1)) + } }