diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index c4e5b844299a6..190c5bc416940 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -57,20 +57,70 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val trueEval = trueValue.gen(ctx) val falseEval = falseValue.gen(ctx) - s""" - ${condEval.code} - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${condEval.isNull} && ${condEval.value}) { - ${trueEval.code} - ${ev.isNull} = ${trueEval.isNull}; - ${ev.value} = ${trueEval.value}; - } else { - ${falseEval.code} - ${ev.isNull} = ${falseEval.isNull}; - ${ev.value} = ${falseEval.value}; - } - """ + // place generated code of condition, true value and false value in separate methods if + // their code combined is large + val combinedLength = condEval.code.length + trueEval.code.length + falseEval.code.length + if (combinedLength > 1024) { + val (condFuncName, condGlobalIsNull, condGlobalValue) = + createAndAddFunction(ctx, condEval, predicate.dataType, "evalIfCondExpr") + val (trueFuncName, trueGlobalIsNull, trueGlobalValue) = + createAndAddFunction(ctx, trueEval, trueValue.dataType, "evalIfTrueExpr") + val (falseFuncName, falseGlobalIsNull, falseGlobalValue) = + createAndAddFunction(ctx, falseEval, falseValue.dataType, "evalIfFalseExpr") + s""" + $condFuncName(${ctx.INPUT_ROW}); + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!$condGlobalIsNull && $condGlobalValue) { + $trueFuncName(${ctx.INPUT_ROW}); + ${ev.isNull} = $trueGlobalIsNull; + ${ev.value} = $trueGlobalValue; + } else { + $falseFuncName(${ctx.INPUT_ROW}); + ${ev.isNull} = $falseGlobalIsNull; + ${ev.value} = $falseGlobalValue; + } + """ + } + else { + s""" + ${condEval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${condEval.isNull} && ${condEval.value}) { + ${trueEval.code} + ${ev.isNull} = ${trueEval.isNull}; + ${ev.value} = ${trueEval.value}; + } else { + ${falseEval.code} + ${ev.isNull} = ${falseEval.isNull}; + ${ev.value} = ${falseEval.value}; + } + """ + } + } + + private def createAndAddFunction( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + dataType: DataType, + baseFuncName: String): (String, String, String) = { + val globalIsNull = ctx.freshName("isNull") + ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") + val globalValue = ctx.freshName("value") + ctx.addMutableState(ctx.javaType(dataType), globalValue, + s"$globalValue = ${ctx.defaultValue(dataType)};") + val funcName = ctx.freshName(baseFuncName) + val funcBody = + s""" + |private void $funcName(InternalRow ${ctx.INPUT_ROW}) { + | ${ev.code.trim} + | $globalIsNull = ${ev.isNull}; + | $globalValue = ${ev.value}; + |} + """.stripMargin + ctx.addNewFunction(funcName, funcBody) + (funcName, globalIsNull, globalValue) } override def toString: String = s"if ($predicate) $trueValue else $falseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index f4b0cdc4c7b74..aaa5b859d76f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -414,6 +414,8 @@ case class MapObjects private( override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) val elementJavaType = ctx.javaType(loopVar.dataType) + ctx.addMutableState("boolean", loopVar.isNull, "") + ctx.addMutableState(elementJavaType, loopVar.value, "") val genInputData = inputData.gen(ctx) val genFunction = lambdaFunction.gen(ctx) val dataLength = ctx.freshName("dataLength") @@ -434,9 +436,9 @@ case class MapObjects private( } val loopNullCheck = if (primitiveElement) { - s"boolean ${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" + s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" } else { - s"boolean ${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" + s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" } s""" @@ -452,7 +454,7 @@ case class MapObjects private( int $loopIndex = 0; while ($loopIndex < $dataLength) { - $elementJavaType ${loopVar.value} = + ${loopVar.value} = ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; $loopNullCheck diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index e35a1b2d7c9a4..030c99c5489e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -80,6 +80,21 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(actual(0) == cases) } + test("SPARK-18091: split large if expressions into blocks due to JVM code size limit") { + var strExpr: Expression = Literal("abc") + for (_ <- 1 to 100) { + strExpr = Decode(Encode(strExpr, "utf-8"), "utf-8") + } + + val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr)) + val plan = GenerateMutableProjection.generate(expressions)() + val actual = plan(null).toSeq(expressions.map(_.dataType)) + val expected = Seq(UTF8String.fromString("abc")) + + if (!checkResult(actual, expected)) { + fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") + } + } test("test generated safe and unsafe projection") { val schema = new StructType(Array( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index baa258ad26152..f376c2b7a4848 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1464,4 +1464,18 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } + test("wide nested json table") { + val nested = (1 to 100).map { i => + s""" + |"c$i": $i + """.stripMargin + }.mkString(", ") + val json = s""" + |{"a": [{$nested}], "b": [{$nested}]} + """.stripMargin + val rdd = sqlContext.sparkContext.makeRDD(Seq(json)) + val df = sqlContext.read.json(rdd) + assert(df.schema.size === 2) + df.collect() + } }