From 0f4fe62ce18b9ed7179940947ee737acde9376bd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 23 May 2018 06:32:12 +0000 Subject: [PATCH 1/2] Initial pass to forbid string interpolation. --- .../catalyst/expressions/BoundAttribute.scala | 8 +- .../spark/sql/catalyst/expressions/Cast.scala | 368 ++++++------- .../sql/catalyst/expressions/Expression.scala | 49 +- .../MonotonicallyIncreasingID.scala | 12 +- .../sql/catalyst/expressions/ScalaUDF.scala | 40 +- .../sql/catalyst/expressions/SortOrder.scala | 28 +- .../expressions/SparkPartitionID.scala | 9 +- .../sql/catalyst/expressions/TimeWindow.scala | 5 +- .../sql/catalyst/expressions/arithmetic.scala | 107 ++-- .../expressions/bitwiseExpressions.scala | 10 +- .../expressions/codegen/CodeGenerator.scala | 341 +++++++----- .../expressions/codegen/CodegenFallback.scala | 15 +- .../codegen/GenerateMutableProjection.scala | 19 +- .../codegen/GenerateOrdering.scala | 47 +- .../codegen/GenerateSafeProjection.scala | 53 +- .../codegen/GenerateUnsafeProjection.scala | 114 ++-- .../codegen/GenerateUnsafeRowJoiner.scala | 51 +- .../expressions/codegen/javaCode.scala | 114 +++- .../expressions/collectionOperations.scala | 494 ++++++++++-------- .../expressions/complexTypeCreator.scala | 49 +- .../expressions/complexTypeExtractors.scala | 51 +- .../expressions/conditionalExpressions.scala | 30 +- .../expressions/datetimeExpressions.scala | 198 +++---- .../expressions/decimalExpressions.scala | 11 +- .../sql/catalyst/expressions/generators.scala | 8 +- .../spark/sql/catalyst/expressions/hash.scala | 257 ++++----- .../catalyst/expressions/inputFileBlock.scala | 22 +- .../expressions/mathExpressions.scala | 108 ++-- .../spark/sql/catalyst/expressions/misc.scala | 10 +- .../expressions/nullExpressions.scala | 62 ++- .../expressions/objects/objects.scala | 412 ++++++++------- .../sql/catalyst/expressions/predicates.scala | 62 +-- .../expressions/randomExpressions.scala | 12 +- .../expressions/regexpExpressions.scala | 91 ++-- .../expressions/stringExpressions.scala | 212 ++++---- .../catalyst/analysis/TypeCoercionSuite.scala | 6 +- .../expressions/codegen/CodeBlockSuite.scala | 23 +- .../sql/catalyst/trees/TreeNodeSuite.scala | 6 +- .../sql/execution/ColumnarBatchScan.scala | 18 +- .../sql/execution/DataSourceScanExec.scala | 4 +- .../spark/sql/execution/ExpandExec.scala | 6 +- .../spark/sql/execution/GenerateExec.scala | 45 +- .../sql/execution/WholeStageCodegenExec.scala | 8 +- .../aggregate/HashAggregateExec.scala | 38 +- .../aggregate/HashMapGenerator.scala | 8 +- .../aggregate/RowBasedHashMapGenerator.scala | 8 +- .../VectorizedHashMapGenerator.scala | 30 +- .../joins/BroadcastHashJoinExec.scala | 16 +- .../execution/joins/SortMergeJoinExec.scala | 52 +- .../spark/sql/GeneratorFunctionSuite.scala | 2 +- 50 files changed, 2073 insertions(+), 1676 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index df3ab05e02c7..13eb5982fe16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ @@ -53,8 +53,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) ev.copy(code = oev.code) } else { assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.") - val javaType = CodeGenerator.javaType(dataType) - val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) + val javaType = inline"${CodeGenerator.javaType(dataType)}" + val value = JavaCode.expression( + CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString), + dataType) if (nullable) { ev.copy(code = code""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 699ea53b5df0..f69d3da67394 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -625,25 +625,25 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) - ev.copy(code = + // Below the code comment including `eval.value` and `eval.isNull` is a trick. It makes the two + // expr values are referred by this code block. + ev.copy(code = eval.code + code""" - ${eval.code} - // This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull} ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)} """) } // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` // in parameter list, because the returned code will be put in null safe evaluation region. - private[this] type CastFunction = (String, String, String) => String + private[this] type CastFunction = (ExprValue, ExprValue, ExprValue) => Block private[this] def nullSafeCastFunction( from: DataType, to: DataType, ctx: CodegenContext): CastFunction = to match { - case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;" - case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;" + case _ if from == NullType => (c, evPrim, evNull) => code"$evNull = true;" + case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;" case StringType => castToStringCode(from, ctx) case BinaryType => castToBinaryCode(from) case DateType => castToDateCode(from, ctx) @@ -664,18 +664,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) case udt: UserDefinedType[_] if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => - (c, evPrim, evNull) => s"$evPrim = $c;" + (c, evPrim, evNull) => code"$evPrim = $c;" case _: UserDefinedType[_] => throw new SparkException(s"Cannot cast $from to $to.") } // Since we need to cast input expressions recursively inside ComplexTypes, such as Map's // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. - private[this] def castCode(ctx: CodegenContext, input: String, inputIsNull: String, - result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = { - s""" + private[this] def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: ExprValue, + result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: CastFunction): Block = { + val javaType = inline"${CodeGenerator.javaType(resultType)}" + code""" boolean $resultIsNull = $inputIsNull; - ${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)}; + $javaType $result = ${CodeGenerator.defaultValue(resultType)}; if (!$inputIsNull) { ${cast(input, result, resultIsNull)} } @@ -684,22 +685,24 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeArrayToStringBuilder( et: DataType, - array: String, - buffer: String, - ctx: CodegenContext): String = { + array: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { val elementToStringCode = castToStringCode(et, ctx) val funcName = ctx.freshName("elementToString") - val elementToStringFunc = ctx.addNewFunction(funcName, + val element = JavaCode.variable("element", et) + val elementStr = JavaCode.variable("elementStr", StringType) + val elementToStringFunc = inline"${ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${CodeGenerator.javaType(et)} element) { - | UTF8String elementStr = null; - | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} + |private UTF8String $funcName(${CodeGenerator.javaType(et)} $element) { + | UTF8String $elementStr = null; + | ${elementToStringCode(element, elementStr, null /* resultIsNull won't be used */)} | return elementStr; |} - """.stripMargin) + """.stripMargin)}" - val loopIndex = ctx.freshName("loopIndex") - s""" + val loopIndex = JavaCode.variable(ctx.freshName("loopIndex"), IntegerType) + code""" |$buffer.append("["); |if ($array.numElements() > 0) { | if (!$array.isNullAt(0)) { @@ -720,31 +723,36 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeMapToStringBuilder( kt: DataType, vt: DataType, - map: String, - buffer: String, - ctx: CodegenContext): String = { + map: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { def dataToStringFunc(func: String, dataType: DataType) = { val funcName = ctx.freshName(func) val dataToStringCode = castToStringCode(dataType, ctx) + val data = JavaCode.variable("data", dataType) + val dataStr = JavaCode.variable("dataStr", StringType) ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) { - | UTF8String dataStr = null; - | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} + |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} $data) { + | UTF8String $dataStr = null; + | ${dataToStringCode(data, dataStr, null /* resultIsNull won't be used */)} | return dataStr; |} """.stripMargin) } - val keyToStringFunc = dataToStringFunc("keyToString", kt) - val valueToStringFunc = dataToStringFunc("valueToString", vt) - val loopIndex = ctx.freshName("loopIndex") - val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0") - val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0") - val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex) - val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex) - s""" + val keyToStringFunc = inline"${dataToStringFunc("keyToString", kt)}" + val valueToStringFunc = inline"${dataToStringFunc("valueToString", vt)}" + val loopIndex = JavaCode.variable(ctx.freshName("loopIndex"), IntegerType) + val mapKeyArray = JavaCode.expression(s"$map.keyArray()", classOf[ArrayData]) + val mapValueArray = JavaCode.expression(s"$map.valueArray()", classOf[ArrayData]) + val getMapFirstKey = CodeGenerator.getValue(mapKeyArray, kt, JavaCode.literal("0", IntegerType)) + val getMapFirstValue = CodeGenerator.getValue(mapValueArray, vt, + JavaCode.literal("0", IntegerType)) + val getMapKeyArray = CodeGenerator.getValue(mapKeyArray, kt, loopIndex) + val getMapValueArray = CodeGenerator.getValue(mapValueArray, vt, loopIndex) + code""" |$buffer.append("["); |if ($map.numElements() > 0) { | $buffer.append($keyToStringFunc($getMapFirstKey)); @@ -769,20 +777,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeStructToStringBuilder( st: Seq[DataType], - row: String, - buffer: String, - ctx: CodegenContext): String = { + row: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { val structToStringCode = st.zipWithIndex.map { case (ft, i) => val fieldToStringCode = castToStringCode(ft, ctx) - val field = ctx.freshName("field") - val fieldStr = ctx.freshName("fieldStr") - s""" - |${if (i != 0) s"""$buffer.append(",");""" else ""} + val field = JavaCode.variable(ctx.freshName("field"), ft) + val fieldStr = JavaCode.variable(ctx.freshName("fieldStr"), StringType) + val javaType = inline"${CodeGenerator.javaType(ft)}" + code""" + |${if (i != 0) code"""$buffer.append(",");""" else EmptyBlock} |if (!$row.isNullAt($i)) { - | ${if (i != 0) s"""$buffer.append(" ");""" else ""} + | ${if (i != 0) code"""$buffer.append(" ");""" else EmptyBlock} | | // Append $i field into the string buffer - | ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")}; + | $javaType $field = ${CodeGenerator.getValue(row, ft, s"$i")}; | UTF8String $fieldStr = null; | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} | $buffer.append($fieldStr); @@ -793,9 +802,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val writeStructCode = ctx.splitExpressions( expressions = structToStringCode, funcName = "fieldToString", - arguments = ("InternalRow", row) :: (classOf[UTF8StringBuilder].getName, buffer) :: Nil) + arguments = (row) :: (buffer) :: Nil) - s""" + code""" |$buffer.append("["); |$writeStructCode |$buffer.append("]"); @@ -805,20 +814,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => - (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);" + (c, evPrim, evNull) => code"$evPrim = UTF8String.fromBytes($c);" case DateType => - (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));""" case TimestampType => - val tz = ctx.addReferenceObj("timeZone", timeZone) - (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) + (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" case ArrayType(et, _) => (c, evPrim, evNull) => { - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val buffer = JavaCode.variable(ctx.freshName("buffer"), classOf[UTF8StringBuilder]) + val bufferClass = inline"${classOf[UTF8StringBuilder].getName}" val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) - s""" + code""" |$bufferClass $buffer = new $bufferClass(); |$writeArrayElemCode; |$evPrim = $buffer.build(); @@ -826,10 +835,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case MapType(kt, vt, _) => (c, evPrim, evNull) => { - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val buffer = JavaCode.variable(ctx.freshName("buffer"), classOf[UTF8StringBuilder]) + val bufferClass = inline"${classOf[UTF8StringBuilder].getName}" val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) - s""" + code""" |$bufferClass $buffer = new $bufferClass(); |$writeMapElemCode; |$evPrim = $buffer.build(); @@ -837,11 +846,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case StructType(fields) => (c, evPrim, evNull) => { - val row = ctx.freshName("row") - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val row = JavaCode.variable(ctx.freshName("row"), classOf[InternalRow]) + val buffer = JavaCode.variable(ctx.freshName("buffer"), classOf[UTF8StringBuilder]) + val bufferClass = inline"${classOf[UTF8StringBuilder].getName}" val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) - s""" + code""" |InternalRow $row = $c; |$bufferClass $buffer = new $bufferClass(); |$writeStructCode @@ -850,26 +859,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx) case udt: UserDefinedType[_] => - val udtRef = ctx.addReferenceObj("udt", udt) + val udtRef = JavaCode.global(ctx.addReferenceObj("udt", udt), udt.sqlType) (c, evPrim, evNull) => { - s"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" + code"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" } case _ => - (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" + (c, evPrim, evNull) => code"$evPrim = UTF8String.fromString(String.valueOf($c));" } } private[this] def castToBinaryCode(from: DataType): CastFunction = from match { case StringType => - (c, evPrim, evNull) => s"$evPrim = $c.getBytes();" + (c, evPrim, evNull) => code"$evPrim = $c.getBytes();" } private[this] def castToDateCode( from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val intOpt = ctx.freshName("intOpt") - (c, evPrim, evNull) => s""" + val intOpt = JavaCode.variable(ctx.freshName("intOpt"), classOf[Option[Integer]]) + (c, evPrim, evNull) => code""" scala.Option $intOpt = org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c); if ($intOpt.isDefined()) { @@ -879,16 +888,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case TimestampType => - val tz = ctx.addReferenceObj("timeZone", timeZone) + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) (c, evPrim, evNull) => - s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);" + code"""$evPrim = + org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);""" case _ => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" } - private[this] def changePrecision(d: String, decimalType: DecimalType, - evPrim: String, evNull: String): String = - s""" + private[this] def changePrecision(d: ExprValue, decimalType: DecimalType, + evPrim: ExprValue, evNull: ExprValue): Block = + code""" if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { $evPrim = $d; } else { @@ -900,11 +910,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, target: DecimalType, ctx: CodegenContext): CastFunction = { - val tmp = ctx.freshName("tmpDecimal") + val tmp = JavaCode.variable(ctx.freshName("tmpDecimal"), classOf[Decimal]) from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); ${changePrecision(tmp, target, evPrim, evNull)} @@ -914,37 +924,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ case BooleanType => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = $c ? Decimal.apply(1) : Decimal.apply(0); ${changePrecision(tmp, target, evPrim, evNull)} """ case DateType => // date can't cast to decimal in Hive - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => // Note that we lose precision here. (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = Decimal.apply( scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); ${changePrecision(tmp, target, evPrim, evNull)} """ case DecimalType() => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = $c.clone(); ${changePrecision(tmp, target, evPrim, evNull)} """ case x: IntegralType => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = Decimal.apply((long) $c); ${changePrecision(tmp, target, evPrim, evNull)} """ case x: FractionalType => // All other numeric types can be represented precisely as Doubles (c, evPrim, evNull) => - s""" + code""" try { Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); ${changePrecision(tmp, target, evPrim, evNull)} @@ -959,10 +969,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val tz = ctx.addReferenceObj("timeZone", timeZone) - val longOpt = ctx.freshName("longOpt") + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) + val longOpt = JavaCode.variable(ctx.freshName("longOpt"), classOf[Option[Long]]) (c, evPrim, evNull) => - s""" + code""" scala.Option $longOpt = org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c, $tz); if ($longOpt.isDefined()) { @@ -972,18 +982,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1L : 0L;" case _: IntegralType => - (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${longToTimeStampCode(c)};" case DateType => - val tz = ctx.addReferenceObj("timeZone", timeZone) + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) (c, evPrim, evNull) => - s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;" + code"""$evPrim = + org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;""" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = ${decimalToTimestampCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${decimalToTimestampCode(c)};" case DoubleType => (c, evPrim, evNull) => - s""" + code""" if (Double.isNaN($c) || Double.isInfinite($c)) { $evNull = true; } else { @@ -992,7 +1003,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ case FloatType => (c, evPrim, evNull) => - s""" + code""" if (Float.isNaN($c) || Float.isInfinite($c)) { $evNull = true; } else { @@ -1004,7 +1015,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToIntervalCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s"""$evPrim = CalendarInterval.fromString($c.toString()); + code"""$evPrim = CalendarInterval.fromString($c.toString()); if(${evPrim} == null) { ${evNull} = true; } @@ -1012,18 +1023,22 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } - private[this] def decimalToTimestampCode(d: String): String = - s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" - private[this] def longToTimeStampCode(l: String): String = s"$l * 1000000L" - private[this] def timestampToIntegerCode(ts: String): String = - s"java.lang.Math.floor((double) $ts / 1000000L)" - private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0" + private[this] def decimalToTimestampCode(d: ExprValue): ExprValue = + JavaCode.expression( + s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()", + TimestampType) + private[this] def longToTimeStampCode(l: ExprValue): ExprValue = + JavaCode.expression(s"$l * 1000000L", TimestampType) + private[this] def timestampToIntegerCode(ts: ExprValue): ExprValue = + JavaCode.expression(s"java.lang.Math.floor((double) $ts / 1000000L)", IntegerType) + private[this] def timestampToDoubleCode(ts: ExprValue): ExprValue = + JavaCode.expression(s"$ts / 1000000.0", DoubleType) private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => - val stringUtils = StringUtils.getClass.getName.stripSuffix("$") + val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" (c, evPrim, evNull) => - s""" + code""" if ($stringUtils.isTrueString($c)) { $evPrim = true; } else if ($stringUtils.isFalseString($c)) { @@ -1033,21 +1048,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case TimestampType => - (c, evPrim, evNull) => s"$evPrim = $c != 0;" + (c, evPrim, evNull) => code"$evPrim = $c != 0;" case DateType => // Hive would return null when cast from date to boolean - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = !$c.isZero();" + (c, evPrim, evNull) => code"$evPrim = !$c.isZero();" case n: NumericType => - (c, evPrim, evNull) => s"$evPrim = $c != 0;" + (c, evPrim, evNull) => code"$evPrim = $c != 0;" } private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = JavaCode.variable(ctx.freshName("intWrapper"), classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toByte($wrapper)) { $evPrim = (byte) $wrapper.value; @@ -1057,24 +1072,24 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? (byte) 1 : (byte) 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (byte) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (byte) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toByte();" + (c, evPrim, evNull) => code"$evPrim = $c.toByte();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (byte) $c;" + (c, evPrim, evNull) => code"$evPrim = (byte) $c;" } private[this] def castToShortCode( from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = JavaCode.variable(ctx.freshName("intWrapper"), classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toShort($wrapper)) { $evPrim = (short) $wrapper.value; @@ -1084,22 +1099,22 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? (short) 1 : (short) 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (short) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (short) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toShort();" + (c, evPrim, evNull) => code"$evPrim = $c.toShort();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (short) $c;" + (c, evPrim, evNull) => code"$evPrim = (short) $c;" } private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = JavaCode.variable(ctx.freshName("intWrapper"), classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toInt($wrapper)) { $evPrim = $wrapper.value; @@ -1109,23 +1124,23 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1 : 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (int) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (int) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toInt();" + (c, evPrim, evNull) => code"$evPrim = $c.toInt();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (int) $c;" + (c, evPrim, evNull) => code"$evPrim = (int) $c;" } private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("longWrapper") + val wrapper = JavaCode.variable(ctx.freshName("longWrapper"), classOf[UTF8String.LongWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); if ($c.toLong($wrapper)) { $evPrim = $wrapper.value; @@ -1135,21 +1150,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1L : 0L;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (long) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (long) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toLong();" + (c, evPrim, evNull) => code"$evPrim = $c.toLong();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (long) $c;" + (c, evPrim, evNull) => code"$evPrim = (long) $c;" } private[this] def castToFloatCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { $evPrim = Float.valueOf($c.toString()); } catch (java.lang.NumberFormatException e) { @@ -1157,21 +1172,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1.0f : 0.0f;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (float) (${timestampToDoubleCode(c)});" + (c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toFloat();" + (c, evPrim, evNull) => code"$evPrim = $c.toFloat();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (float) $c;" + (c, evPrim, evNull) => code"$evPrim = (float) $c;" } private[this] def castToDoubleCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { $evPrim = Double.valueOf($c.toString()); } catch (java.lang.NumberFormatException e) { @@ -1179,31 +1194,32 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1.0d : 0.0d;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = ${timestampToDoubleCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toDouble();" + (c, evPrim, evNull) => code"$evPrim = $c.toDouble();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (double) $c;" + (c, evPrim, evNull) => code"$evPrim = (double) $c;" } private[this] def castArrayCode( fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = { val elementCast = nullSafeCastFunction(fromType, toType, ctx) - val arrayClass = classOf[GenericArrayData].getName - val fromElementNull = ctx.freshName("feNull") - val fromElementPrim = ctx.freshName("fePrim") - val toElementNull = ctx.freshName("teNull") - val toElementPrim = ctx.freshName("tePrim") - val size = ctx.freshName("n") - val j = ctx.freshName("j") - val values = ctx.freshName("values") + val arrayClass = inline"${classOf[GenericArrayData].getName}" + val fromElementNull = JavaCode.isNullVariable(ctx.freshName("feNull")) + val fromElementPrim = JavaCode.variable(ctx.freshName("fePrim"), fromType) + val toElementNull = JavaCode.isNullVariable(ctx.freshName("teNull")) + val toElementPrim = JavaCode.variable(ctx.freshName("tePrim"), toType) + val size = JavaCode.variable(ctx.freshName("n"), IntegerType) + val j = JavaCode.variable(ctx.freshName("j"), IntegerType) + val values = JavaCode.variable(ctx.freshName("values"), classOf[Array[Object]]) + val javaType = inline"${CodeGenerator.javaType(fromType)}" (c, evPrim, evNull) => - s""" + code""" final int $size = $c.numElements(); final Object[] $values = new Object[$size]; for (int $j = 0; $j < $size; $j ++) { @@ -1211,7 +1227,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $values[$j] = null; } else { boolean $fromElementNull = false; - ${CodeGenerator.javaType(fromType)} $fromElementPrim = + $javaType $fromElementPrim = ${CodeGenerator.getValue(c, fromType, j)}; ${castCode(ctx, fromElementPrim, fromElementNull, toElementPrim, toElementNull, toType, elementCast)} @@ -1230,23 +1246,24 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val keysCast = castArrayCode(from.keyType, to.keyType, ctx) val valuesCast = castArrayCode(from.valueType, to.valueType, ctx) - val mapClass = classOf[ArrayBasedMapData].getName + val mapClass = inline"${classOf[ArrayBasedMapData].getName}" - val keys = ctx.freshName("keys") - val convertedKeys = ctx.freshName("convertedKeys") - val convertedKeysNull = ctx.freshName("convertedKeysNull") + val keys = JavaCode.variable(ctx.freshName("keys"), ArrayType(from.keyType)) + val convertedKeys = JavaCode.variable(ctx.freshName("convertedKeys"), ArrayType(to.keyType)) + val convertedKeysNull = JavaCode.isNullVariable(ctx.freshName("convertedKeysNull")) - val values = ctx.freshName("values") - val convertedValues = ctx.freshName("convertedValues") - val convertedValuesNull = ctx.freshName("convertedValuesNull") + val values = JavaCode.variable(ctx.freshName("values"), ArrayType(from.valueType)) + val convertedValues = JavaCode.variable(ctx.freshName("convertedValues"), + ArrayType(to.valueType)) + val convertedValuesNull = JavaCode.isNullVariable(ctx.freshName("convertedValuesNull")) (c, evPrim, evNull) => - s""" + code""" final ArrayData $keys = $c.keyArray(); final ArrayData $values = $c.valueArray(); - ${castCode(ctx, keys, "false", + ${castCode(ctx, keys, FalseLiteral, convertedKeys, convertedKeysNull, ArrayType(to.keyType), keysCast)} - ${castCode(ctx, values, "false", + ${castCode(ctx, values, FalseLiteral, convertedValues, convertedValuesNull, ArrayType(to.valueType), valuesCast)} $evPrim = new $mapClass($convertedKeys, $convertedValues); @@ -1259,17 +1276,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val fieldsCasts = from.fields.zip(to.fields).map { case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } - val rowClass = classOf[GenericInternalRow].getName - val tmpResult = ctx.freshName("tmpResult") - val tmpInput = ctx.freshName("tmpInput") + val tmpResult = JavaCode.variable(ctx.freshName("tmpResult"), classOf[GenericInternalRow]) + val rowClass = inline"${classOf[GenericInternalRow].getName}" + val tmpInput = JavaCode.variable(ctx.freshName("tmpInput"), classOf[InternalRow]) val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => - val fromFieldPrim = ctx.freshName("ffp") - val fromFieldNull = ctx.freshName("ffn") - val toFieldPrim = ctx.freshName("tfp") - val toFieldNull = ctx.freshName("tfn") - val fromType = CodeGenerator.javaType(from.fields(i).dataType) - s""" + val fromFieldPrim = JavaCode.variable(ctx.freshName("ffp"), from.fields(i).dataType) + val fromFieldNull = JavaCode.isNullVariable(ctx.freshName("ffn")) + val toFieldPrim = JavaCode.variable(ctx.freshName("tfp"), to.fields(i).dataType) + val toFieldNull = JavaCode.isNullVariable(ctx.freshName("tfn")) + val fromType = inline"${CodeGenerator.javaType(from.fields(i).dataType)}" + val setColumn = CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim) + code""" boolean $fromFieldNull = $tmpInput.isNullAt($i); if ($fromFieldNull) { $tmpResult.setNullAt($i); @@ -1281,7 +1299,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String if ($toFieldNull) { $tmpResult.setNullAt($i); } else { - ${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; + $setColumn; } } """ @@ -1289,10 +1307,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val fieldsEvalCodes = ctx.splitExpressions( expressions = fieldsEvalCode, funcName = "castStruct", - arguments = ("InternalRow", tmpInput) :: (rowClass, tmpResult) :: Nil) + arguments = (tmpInput) :: (tmpResult) :: Nil) (input, result, resultIsNull) => - s""" + code""" final $rowClass $tmpResult = new $rowClass(${fieldsCasts.length}); final InternalRow $tmpInput = $input; $fieldsEvalCodes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 9b9fa41a47d0..c0a18ced22b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -130,8 +130,8 @@ abstract class Expression extends TreeNode[Expression] { "" } - val javaType = CodeGenerator.javaType(dataType) - val newValue = ctx.freshName("value") + val javaType = inline"${CodeGenerator.javaType(dataType)}" + val newValue = JavaCode.variable(ctx.freshName("value"), dataType) val funcName = ctx.freshName(nodeName) val funcFullName = ctx.addNewFunction(funcName, @@ -143,8 +143,8 @@ abstract class Expression extends TreeNode[Expression] { |} """.stripMargin) - eval.value = JavaCode.variable(newValue, dataType) - eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + eval.value = newValue + eval.code = code"$javaType $newValue = ${inline"$funcFullName"}(${ctx.INPUT_ROW});" } } @@ -416,9 +416,9 @@ abstract class UnaryExpression extends Expression { protected def defineCodeGen( ctx: CodegenContext, ev: ExprCode, - f: String => String): ExprCode = { + f: ExprValue => Block): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { - s"${ev.value} = ${f(eval)};" + code"${ev.value} = ${f(eval)};" }) } @@ -432,22 +432,23 @@ abstract class UnaryExpression extends Expression { protected def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, - f: String => String): ExprCode = { + f: ExprValue => Block): ExprCode = { val childGen = child.genCode(ctx) val resultCode = f(childGen.value) + val javaType = inline"${CodeGenerator.javaType(dataType)}" if (nullable) { val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) ev.copy(code = code""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { ev.copy(code = code""" ${childGen.code} - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = FalseLiteral) } } @@ -504,9 +505,9 @@ abstract class BinaryExpression extends Expression { protected def defineCodeGen( ctx: CodegenContext, ev: ExprCode, - f: (String, String) => String): ExprCode = { + f: (ExprValue, ExprValue) => Block): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s"${ev.value} = ${f(eval1, eval2)};" + code"${ev.value} = ${f(eval1, eval2)};" }) } @@ -521,16 +522,17 @@ abstract class BinaryExpression extends Expression { protected def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, - f: (String, String) => String): ExprCode = { + f: (ExprValue, ExprValue) => Block): ExprCode = { val leftGen = left.genCode(ctx) val rightGen = right.genCode(ctx) val resultCode = f(leftGen.value, rightGen.value) + val javaType = inline"${CodeGenerator.javaType(dataType)}" if (nullable) { val nullSafeEval = leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) { rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) { - s""" + code""" ${ev.isNull} = false; // resultCode could change nullability. $resultCode """ @@ -539,14 +541,14 @@ abstract class BinaryExpression extends Expression { ev.copy(code = code""" boolean ${ev.isNull} = true; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { ev.copy(code = code""" ${leftGen.code} ${rightGen.code} - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = FalseLiteral) } } @@ -568,9 +570,9 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { */ def inputType: AbstractDataType - def symbol: String + def symbol: JavaCode - def sqlOperator: String = symbol + def sqlOperator: String = symbol.code override def toString: String = s"($left $symbol $right)" @@ -644,9 +646,9 @@ abstract class TernaryExpression extends Expression { protected def defineCodeGen( ctx: CodegenContext, ev: ExprCode, - f: (String, String, String) => String): ExprCode = { + f: (ExprValue, ExprValue, ExprValue) => Block): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => { - s"${ev.value} = ${f(eval1, eval2, eval3)};" + code"${ev.value} = ${f(eval1, eval2, eval3)};" }) } @@ -661,18 +663,19 @@ abstract class TernaryExpression extends Expression { protected def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, - f: (String, String, String) => String): ExprCode = { + f: (ExprValue, ExprValue, ExprValue) => Block): ExprCode = { val leftGen = children(0).genCode(ctx) val midGen = children(1).genCode(ctx) val rightGen = children(2).genCode(ctx) val resultCode = f(leftGen.value, midGen.value, rightGen.value) + val javaType = inline"${CodeGenerator.javaType(dataType)}" if (nullable) { val nullSafeEval = leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) { midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) { rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) { - s""" + code""" ${ev.isNull} = false; // resultCode could change nullability. $resultCode """ @@ -682,14 +685,14 @@ abstract class TernaryExpression extends Expression { ev.copy(code = code""" boolean ${ev.isNull} = true; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { ev.copy(code = code""" ${leftGen.code} ${midGen.code} ${rightGen.code} - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index f1da592a7684..4182e215ab4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType} @@ -67,14 +67,16 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count") - val partitionMaskTerm = "partitionMask" - ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm) + val countTerm = JavaCode.variable( + ctx.addMutableState(CodeGenerator.JAVA_LONG, "count"), LongType) + val partitionMaskTerm = JavaCode.variable("partitionMask", LongType) + ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm.code) ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") + val javaType = inline"${CodeGenerator.javaType(dataType)}" ev.copy(code = code""" - final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; + final $javaType ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++;""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 3e7ca8824973..dcf47cf03428 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -997,9 +997,12 @@ case class ScalaUDF( val converters: Array[Any => Any] = children.map { c => CatalystTypeConverters.createToScalaConverter(c.dataType) }.toArray :+ CatalystTypeConverters.createToCatalystConverter(dataType) - val convertersTerm = ctx.addReferenceObj("converters", converters, s"$converterClassName[]") - val errorMsgTerm = ctx.addReferenceObj("errMsg", udfErrorMessage) - val resultTerm = ctx.freshName("result") + val convertersTerm = JavaCode.global( + ctx.addReferenceObj("converters", converters, s"$converterClassName[]"), + classOf[Array[Object]]) + val errorMsgTerm = JavaCode.global(ctx.addReferenceObj("errMsg", udfErrorMessage), + classOf[String]) + val resultTerm = JavaCode.variable(ctx.freshName("result"), dataType) // codegen for children expressions val evals = children.map(_.genCode(ctx)) @@ -1008,20 +1011,28 @@ case class ScalaUDF( // We need to get the boxedType of dataType's javaType here. Because for the dataType // such as IntegerType, its javaType is `int` and the returned type of user-defined // function is Object. Trying to convert an Object to `int` will cause casting exception. - val evalCode = evals.map(_.code).mkString("\n") + val evalCode = Blocks(evals.map(_.code)) val (funcArgs, initArgs) = evals.zipWithIndex.map { case (eval, i) => - val argTerm = ctx.freshName("arg") - val convert = s"$convertersTerm[$i].apply(${eval.value})" - val initArg = s"Object $argTerm = ${eval.isNull} ? null : $convert;" + val argTerm = JavaCode.variable(ctx.freshName("arg"), classOf[Object]) + val convert = code"$convertersTerm[$i].apply(${eval.value})" + val initArg = code"Object $argTerm = ${eval.isNull} ? null : $convert;" (argTerm, initArg) }.unzip - val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}") - val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})" - val resultConverter = s"$convertersTerm[${children.length}]" - val boxedType = CodeGenerator.boxedType(dataType) + val udf = JavaCode.global( + ctx.addReferenceObj("udf", function, s"scala.Function${children.length}"), classOf[Object]) + val funcArgBlock = funcArgs.foldLeft[Block](EmptyBlock) { (block, arg) => + if (block.length == 0) { + code"$arg" + } else { + code"$block, $arg" + } + } + val getFuncResult = code"$udf.apply($funcArgBlock)" + val resultConverter = code"$convertersTerm[${children.length}]" + val boxedType = inline"${CodeGenerator.boxedType(dataType)}" val callFunc = - s""" + code""" |$boxedType $resultTerm = null; |try { | $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult); @@ -1030,14 +1041,15 @@ case class ScalaUDF( |} """.stripMargin + val javaType = inline"${CodeGenerator.javaType(dataType)}" ev.copy(code = code""" |$evalCode - |${initArgs.mkString("\n")} + |${Blocks(initArgs)} |$callFunc | |boolean ${ev.isNull} = $resultTerm == null; - |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${ev.isNull}) { | ${ev.value} = $resultTerm; |} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 2ce9d072c71c..cbef40f0071b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, JavaCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ @@ -153,32 +153,32 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childCode = child.child.genCode(ctx) val input = childCode.value - val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName - val DoublePrefixCmp = classOf[DoublePrefixComparator].getName - val StringPrefixCmp = classOf[StringPrefixComparator].getName + val BinaryPrefixCmp = inline"${classOf[BinaryPrefixComparator].getName}" + val DoublePrefixCmp = inline"${classOf[DoublePrefixComparator].getName}" + val StringPrefixCmp = inline"${classOf[StringPrefixComparator].getName}" val prefixCode = child.child.dataType match { case BooleanType => - s"$input ? 1L : 0L" + code"$input ? 1L : 0L" case _: IntegralType => - s"(long) $input" + code"(long) $input" case DateType | TimestampType => - s"(long) $input" + code"(long) $input" case FloatType | DoubleType => - s"$DoublePrefixCmp.computePrefix((double)$input)" - case StringType => s"$StringPrefixCmp.computePrefix($input)" - case BinaryType => s"$BinaryPrefixCmp.computePrefix($input)" + code"$DoublePrefixCmp.computePrefix((double)$input)" + case StringType => code"$StringPrefixCmp.computePrefix($input)" + case BinaryType => code"$BinaryPrefixCmp.computePrefix($input)" case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => if (dt.precision <= Decimal.MAX_LONG_DIGITS) { - s"$input.toUnscaledLong()" + code"$input.toUnscaledLong()" } else { // reduce the scale to fit in a long val p = Decimal.MAX_LONG_DIGITS val s = p - (dt.precision - dt.scale) - s"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : ${Long.MinValue}L" + code"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : ${Long.MinValue}L" } case dt: DecimalType => - s"$DoublePrefixCmp.computePrefix($input.toDouble())" - case _ => "0L" + code"$DoublePrefixCmp.computePrefix($input.toDouble())" + case _ => code"0L" } ev.copy(code = childCode.code + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 9856b37e53fb..2aa6ca9e11d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -44,10 +44,11 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override protected def evalInternal(input: InternalRow): Int = partitionId override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val idTerm = "partitionId" - ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) + val idTerm = JavaCode.variable("partitionId", dataType) + ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm.code) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", + val javaType = inline"${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"final $javaType ${ev.value} = $idTerm;", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 84e38a8b2711..e4700ab372db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -22,7 +22,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -164,9 +164,10 @@ case class PreciseTimestampConversion( override def dataType: DataType = toType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) + val javaType = inline"${CodeGenerator.javaType(dataType)}" ev.copy(code = eval.code + code"""boolean ${ev.isNull} = ${eval.isNull}; - |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; + |$javaType ${ev.value} = ${eval.value}; """.stripMargin) } override def nullSafeEval(input: Any): Any = input diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index fe91e520169b..16a4d98f33b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -44,16 +44,17 @@ case class UnaryMinus(child: Expression) extends UnaryExpression private lazy val numeric = TypeUtils.getNumeric(dataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") + case _: DecimalType => defineCodeGen(ctx, ev, c => code"$c.unary_$$minus()") case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { - val originValue = ctx.freshName("origin") + val originValue = JavaCode.variable(ctx.freshName("origin"), dt) + val javaType = inline"${CodeGenerator.javaType(dt)}" // codegen would fail to compile if we just write (-($c)) // for example, we could not write --9223372036854775808L in code - s""" - ${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval); - ${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue)); + code""" + $javaType $originValue = ($javaType)($eval); + ${ev.value} = ($javaType)(-($originValue)); """}) - case _: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") + case _: CalendarIntervalType => defineCodeGen(ctx, ev, c => code"$c.negate()") } protected override def nullSafeEval(input: Any): Any = { @@ -78,7 +79,7 @@ case class UnaryPositive(child: Expression) override def dataType: DataType = child.dataType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = - defineCodeGen(ctx, ev, c => c) + defineCodeGen(ctx, ev, c => code"$c") protected override def nullSafeEval(input: Any): Any = input @@ -106,9 +107,10 @@ case class Abs(child: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case _: DecimalType => - defineCodeGen(ctx, ev, c => s"$c.abs()") + defineCodeGen(ctx, ev, c => code"$c.abs()") case dt: NumericType => - defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))") + val javaType = inline"${CodeGenerator.javaType(dt)}" + defineCodeGen(ctx, ev, c => code"($javaType)(java.lang.Math.abs($c))") } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) @@ -121,24 +123,25 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess /** Name of the function for this expression on a [[Decimal]] type. */ - def decimalMethod: String = + def decimalMethod: JavaCode = sys.error("BinaryArithmetics must override either decimalMethod or genCode") /** Name of the function for this expression on a [[CalendarInterval]] type. */ - def calendarIntervalMethod: String = + def calendarIntervalMethod: JavaCode = sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case _: DecimalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") + defineCodeGen(ctx, ev, (eval1, eval2) => code"$eval1.$decimalMethod($eval2)") case CalendarIntervalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)") + defineCodeGen(ctx, ev, (eval1, eval2) => code"$eval1.$calendarIntervalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => + val javaType = inline"${CodeGenerator.javaType(dataType)}" defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => code"($javaType)($eval1 $symbol $eval2)") case _ => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") + defineCodeGen(ctx, ev, (eval1, eval2) => code"$eval1 $symbol $eval2") } } @@ -157,11 +160,11 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval - override def symbol: String = "+" + override def symbol: JavaCode = inline"+" - override def decimalMethod: String = "$plus" + override def decimalMethod: JavaCode = inline"$$plus" - override def calendarIntervalMethod: String = "add" + override def calendarIntervalMethod: JavaCode = inline"add" private lazy val numeric = TypeUtils.getNumeric(dataType) @@ -185,11 +188,11 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti override def inputType: AbstractDataType = TypeCollection.NumericAndInterval - override def symbol: String = "-" + override def symbol: JavaCode = inline"-" - override def decimalMethod: String = "$minus" + override def decimalMethod: JavaCode = inline"$$minus" - override def calendarIntervalMethod: String = "subtract" + override def calendarIntervalMethod: JavaCode = inline"subtract" private lazy val numeric = TypeUtils.getNumeric(dataType) @@ -213,8 +216,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti override def inputType: AbstractDataType = NumericType - override def symbol: String = "*" - override def decimalMethod: String = "$times" + override def symbol: JavaCode = inline"*" + override def decimalMethod: JavaCode = inline"$$times" private lazy val numeric = TypeUtils.getNumeric(dataType) @@ -249,15 +252,15 @@ trait DivModLike extends BinaryArithmetic { val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val isZero = if (dataType.isInstanceOf[DecimalType]) { - s"${eval2.value}.isZero()" + code"${eval2.value}.isZero()" } else { - s"${eval2.value} == 0" + code"${eval2.value} == 0" } - val javaType = CodeGenerator.javaType(dataType) + val javaType = inline"${CodeGenerator.javaType(dataType)}" val operation = if (dataType.isInstanceOf[DecimalType]) { - s"${eval1.value}.$decimalMethod(${eval2.value})" + code"${eval1.value}.$decimalMethod(${eval2.value})" } else { - s"($javaType)(${eval1.value} $symbol ${eval2.value})" + code"($javaType)(${eval1.value} $symbol ${eval2.value})" } if (!left.nullable && !right.nullable) { ev.copy(code = code""" @@ -304,8 +307,8 @@ case class Divide(left: Expression, right: Expression) extends DivModLike { override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) - override def symbol: String = "/" - override def decimalMethod: String = "$div" + override def symbol: JavaCode = inline"/" + override def decimalMethod: JavaCode = inline"$$div" private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div @@ -327,8 +330,8 @@ case class Remainder(left: Expression, right: Expression) extends DivModLike { override def inputType: AbstractDataType = NumericType - override def symbol: String = "%" - override def decimalMethod: String = "remainder" + override def symbol: JavaCode = inline"%" + override def decimalMethod: JavaCode = inline"remainder" private lazy val mod: (Any, Any) => Any = dataType match { // special cases to make float/double primitive types faster @@ -362,7 +365,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { override def toString: String = s"pmod($left, $right)" - override def symbol: String = "pmod" + override def symbol: JavaCode = inline"pmod" protected def checkTypesInternal(t: DataType): TypeCheckResult = TypeUtils.checkForNumericExpr(t, "pmod") @@ -397,17 +400,17 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val isZero = if (dataType.isInstanceOf[DecimalType]) { - s"${eval2.value}.isZero()" + code"${eval2.value}.isZero()" } else { - s"${eval2.value} == 0" + code"${eval2.value} == 0" } - val remainder = ctx.freshName("remainder") - val javaType = CodeGenerator.javaType(dataType) + val remainder = JavaCode.variable(ctx.freshName("remainder"), dataType) + val javaType = inline"${CodeGenerator.javaType(dataType)}" val result = dataType match { case DecimalType.Fixed(_, _) => - val decimalAdd = "$plus" - s""" + val decimalAdd = inline"$$plus" + code""" $javaType $remainder = ${eval1.value}.remainder(${eval2.value}); if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { ${ev.value}=($remainder.$decimalAdd(${eval2.value})).remainder(${eval2.value}); @@ -417,7 +420,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { """ // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => - s""" + code""" $javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value}); if ($remainder < 0) { ${ev.value}=($javaType)(($remainder + ${eval2.value}) % ${eval2.value}); @@ -426,7 +429,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { } """ case _ => - s""" + code""" $javaType $remainder = ${eval1.value} % ${eval2.value}; if ($remainder < 0) { ${ev.value}=($remainder + ${eval2.value}) % ${eval2.value}; @@ -551,24 +554,25 @@ case class Least(children: Seq[Expression]) extends Expression { val evalChildren = children.map(_.genCode(ctx)) ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) val evals = evalChildren.map(eval => - s""" + code""" |${eval.code} |${ctx.reassignIfSmaller(dataType, ev, eval)} """.stripMargin ) - val resultType = CodeGenerator.javaType(dataType) + val resultType = inline"${CodeGenerator.javaType(dataType)}" val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "least", - extraArguments = Seq(resultType -> ev.value), - returnType = resultType, + extraArguments = Seq(ev.value), + returnType = resultType.code, makeSplitFunction = body => s""" |$body |return ${ev.value}; """.stripMargin, - foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + foldFunctions = + funcCalls => Blocks(funcCalls.map(funcCall => code"${ev.value} = $funcCall;"))) ev.copy(code = code""" |${ev.isNull} = true; @@ -626,24 +630,25 @@ case class Greatest(children: Seq[Expression]) extends Expression { val evalChildren = children.map(_.genCode(ctx)) ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) val evals = evalChildren.map(eval => - s""" + code""" |${eval.code} |${ctx.reassignIfGreater(dataType, ev, eval)} """.stripMargin ) - val resultType = CodeGenerator.javaType(dataType) + val resultType = inline"${CodeGenerator.javaType(dataType)}" val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "greatest", - extraArguments = Seq(resultType -> ev.value), - returnType = resultType, + extraArguments = Seq(ev.value), + returnType = resultType.code, makeSplitFunction = body => s""" |$body |return ${ev.value}; """.stripMargin, - foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + foldFunctions = + funcCalls => Blocks(funcCalls.map(funcCall => code"${ev.value} = $funcCall;"))) ev.copy(code = code""" |${ev.isNull} = true; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index cc24e397cc14..3a626177df3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ @@ -37,7 +38,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme override def inputType: AbstractDataType = IntegralType - override def symbol: String = "&" + override def symbol: JavaCode = inline"&" private lazy val and: (Any, Any) => Any = dataType match { case ByteType => @@ -69,7 +70,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet override def inputType: AbstractDataType = IntegralType - override def symbol: String = "|" + override def symbol: JavaCode = inline"|" private lazy val or: (Any, Any) => Any = dataType match { case ByteType => @@ -101,7 +102,7 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme override def inputType: AbstractDataType = IntegralType - override def symbol: String = "^" + override def symbol: JavaCode = inline"^" private lazy val xor: (Any, Any) => Any = dataType match { case ByteType => @@ -147,7 +148,8 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dataType)}) ~($c)") + val javaType = inline"${CodeGenerator.javaType(dataType)}" + defineCodeGen(ctx, ev, c => code"($javaType) ~($c)") } protected override def nullSafeEval(input: Any): Any = not(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 66315e590625..e77c38cb4be8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import java.io.ByteArrayInputStream +import java.lang.{Boolean => JBool, Byte => JByte, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong, Short => JShort} import java.util.{Map => JavaMap} import scala.collection.JavaConverters._ @@ -146,7 +147,7 @@ class CodegenContext { * `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling * `Expression.genCode`. */ - var INPUT_ROW = "i" + var INPUT_ROW: ExprValue = JavaCode.variable("i", classOf[InternalRow]) /** * Holding a list of generated columns as input of current operator, will be used by @@ -329,11 +330,13 @@ class CodegenContext { * data types like: UTF8String, ArrayData, MapData & InternalRow. */ def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { - val value = addMutableState(javaType(dataType), variableName) + val value = JavaCode.variable(addMutableState(javaType(dataType), variableName), + dataType) + val initExpr = JavaCode.expression(initCode, dataType) val code = dataType match { - case StringType => code"$value = $initCode.clone();" - case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();" - case _ => code"$value = $initCode;" + case StringType => code"$value = $initExpr.clone();" + case _: StructType | _: ArrayType | _: MapType => code"$value = $initExpr.copy();" + case _ => code"$value = $initExpr;" } ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) } @@ -370,11 +373,11 @@ class CodegenContext { def initMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. - val initCodes = mutableStateInitCode.distinct.map(_ + "\n") + val initCodes = mutableStateInitCode.distinct.map(initCode => inline"$initCode") // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. - splitExpressions(expressions = initCodes, funcName = "init", arguments = Nil) + splitExpressions(expressions = initCodes, funcName = "init", arguments = Nil).code } /** @@ -553,7 +556,7 @@ class CodegenContext { * The map from a variable name to it's next ID. */ private val freshNameIds = new mutable.HashMap[String, Int] - freshNameIds += INPUT_ROW -> 1 + freshNameIds += INPUT_ROW.toString -> 1 /** * A prefix used to generate fresh name. @@ -582,18 +585,18 @@ class CodegenContext { /** * Generates code for equal expression in Java. */ - def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match { - case BinaryType => s"java.util.Arrays.equals($c1, $c2)" + def genEqual(dataType: DataType, c1: ExprValue, c2: ExprValue): Block = dataType match { + case BinaryType => code"java.util.Arrays.equals($c1, $c2)" case FloatType => - s"((java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2)" + code"((java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2)" case DoubleType => - s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)" - case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" - case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)" - case array: ArrayType => genComp(array, c1, c2) + " == 0" - case struct: StructType => genComp(struct, c1, c2) + " == 0" + code"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)" + case dt: DataType if isPrimitiveType(dt) => code"$c1 == $c2" + case dt: DataType if dt.isInstanceOf[AtomicType] => code"$c1.equals($c2)" + case array: ArrayType => code"${genComp(array, c1, c2)} == 0" + case struct: StructType => code"${genComp(struct, c1, c2)} == 0" case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2) - case NullType => "false" + case NullType => code"false" case _ => throw new IllegalArgumentException( "cannot generate equality code for un-comparable type: " + dataType.simpleString) @@ -606,38 +609,41 @@ class CodegenContext { * @param c1 name of the variable of expression 1's output * @param c2 name of the variable of expression 2's output */ - def genComp(dataType: DataType, c1: String, c2: String): String = dataType match { + def genComp(dataType: DataType, c1: ExprValue, c2: ExprValue): Block = dataType match { // java boolean doesn't support > or < operator - case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))" - case DoubleType => s"org.apache.spark.util.Utils.nanSafeCompareDoubles($c1, $c2)" - case FloatType => s"org.apache.spark.util.Utils.nanSafeCompareFloats($c1, $c2)" + case BooleanType => code"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))" + case DoubleType => code"org.apache.spark.util.Utils.nanSafeCompareDoubles($c1, $c2)" + case FloatType => code"org.apache.spark.util.Utils.nanSafeCompareFloats($c1, $c2)" // use c1 - c2 may overflow - case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" - case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" - case NullType => "0" + case dt: DataType if isPrimitiveType(dt) => code"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" + case BinaryType => code"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" + case NullType => code"0" case array: ArrayType => val elementType = array.elementType - val elementA = freshName("elementA") - val isNullA = freshName("isNullA") - val elementB = freshName("elementB") - val isNullB = freshName("isNullB") - val compareFunc = freshName("compareArray") - val minLength = freshName("minLength") - val jt = javaType(elementType) - val funcCode: String = - s""" - public int $compareFunc(ArrayData a, ArrayData b) { + val elementA = JavaCode.variable(freshName("elementA"), elementType) + val isNullA = JavaCode.isNullVariable(freshName("isNullA")) + val elementB = JavaCode.variable(freshName("elementB"), elementType) + val isNullB = JavaCode.isNullVariable(freshName("isNullB")) + val compareFunc = inline"${freshName("compareArray")}" + val minLength = JavaCode.variable(freshName("minLength"), IntegerType) + val a = JavaCode.variable("a", classOf[ArrayData]) + val b = JavaCode.variable("b", classOf[ArrayData]) + val i = JavaCode.variable("i", IntegerType) + val jt = JavaCode.javaType(elementA) + val funcCode: Block = + code""" + public int $compareFunc(ArrayData $a, ArrayData $b) { // when comparing unsafe arrays, try equals first as it compares the binary directly // which is very fast. - if (a instanceof UnsafeArrayData && b instanceof UnsafeArrayData && a.equals(b)) { + if ($a instanceof UnsafeArrayData && $b instanceof UnsafeArrayData && $a.equals($b)) { return 0; } - int lengthA = a.numElements(); - int lengthB = b.numElements(); + int lengthA = $a.numElements(); + int lengthB = $b.numElements(); int $minLength = (lengthA > lengthB) ? lengthB : lengthA; - for (int i = 0; i < $minLength; i++) { - boolean $isNullA = a.isNullAt(i); - boolean $isNullB = b.isNullAt(i); + for (int $i = 0; $i < $minLength; $i++) { + boolean $isNullA = $a.isNullAt($i); + boolean $isNullB = $b.isNullAt($i); if ($isNullA && $isNullB) { // Nothing } else if ($isNullA) { @@ -645,8 +651,8 @@ class CodegenContext { } else if ($isNullB) { return 1; } else { - $jt $elementA = ${getValue("a", elementType, "i")}; - $jt $elementB = ${getValue("b", elementType, "i")}; + $jt $elementA = ${getValue(a, elementType, i)}; + $jt $elementB = ${getValue(b, elementType, i)}; int comp = ${genComp(elementType, elementA, elementB)}; if (comp != 0) { return comp; @@ -662,7 +668,8 @@ class CodegenContext { return 0; } """ - s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" + val funcCall = inline"${addNewFunction(compareFunc.code, funcCode.code)}" + code"$funcCall($c1, $c2)" case schema: StructType => val comparisons = GenerateOrdering.genComparisons(this, schema) val compareFunc = freshName("compareStruct") @@ -678,8 +685,9 @@ class CodegenContext { return 0; } """ - s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" - case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" + val funcCall = inline"${addNewFunction(compareFunc, funcCode)}" + code"$funcCall($c1, $c2)" + case other if other.isInstanceOf[AtomicType] => code"$c1.compare($c2)" case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => throw new IllegalArgumentException( @@ -693,10 +701,11 @@ class CodegenContext { * @param c1 name of the variable of expression 1's output * @param c2 name of the variable of expression 2's output */ - def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match { - case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2" - case _ => s"(${genComp(dataType, c1, c2)}) > 0" - } + def genGreater(dataType: DataType, c1: ExprValue, c2: ExprValue): Block = + javaType(dataType) match { + case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => code"$c1 > $c2" + case _ => code"(${genComp(dataType, c1, c2)}) > 0" + } /** * Generates code for updating `partialResult` if `item` is smaller than it. @@ -705,8 +714,8 @@ class CodegenContext { * @param partialResult `ExprCode` representing the partial result which has to be updated * @param item `ExprCode` representing the new expression to evaluate for the result */ - def reassignIfSmaller(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { - s""" + def reassignIfSmaller(dataType: DataType, partialResult: ExprCode, item: ExprCode): Block = { + code""" |if (!${item.isNull} && (${partialResult.isNull} || | ${genGreater(dataType, partialResult.value, item.value)})) { | ${partialResult.isNull} = false; @@ -722,8 +731,8 @@ class CodegenContext { * @param partialResult `ExprCode` representing the partial result which has to be updated * @param item `ExprCode` representing the new expression to evaluate for the result */ - def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { - s""" + def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): Block = { + code""" |if (!${item.isNull} && (${partialResult.isNull} || | ${genGreater(dataType, item.value, partialResult.value)})) { | ${partialResult.isNull} = false; @@ -741,14 +750,14 @@ class CodegenContext { * @param additionalErrorMessage string to include in the error message */ def createUnsafeArray( - arrayName: String, - numElements: String, + arrayName: ExprValue, + numElements: ExprValue, elementType: DataType, - additionalErrorMessage: String): String = { - val arraySize = freshName("size") - val arrayBytes = freshName("arrayBytes") + additionalErrorMessage: ExprValue): Block = { + val arraySize = JavaCode.variable(freshName("size"), LongType) + val arrayBytes = JavaCode.variable(freshName("arrayBytes"), BinaryType) - s""" + code""" |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( | $numElements, | ${elementType.defaultSize}); @@ -776,14 +785,14 @@ class CodegenContext { * @param fallbackCode a piece of code executed when the array size limit is exceeded */ def createUnsafeArrayWithFallback( - arrayName: String, - numElements: String, + arrayName: ExprValue, + numElements: ExprValue, elementSize: Int, - bodyCode: String => String, - fallbackCode: String): String = { - val arraySize = freshName("size") - val arrayBytes = freshName("arrayBytes") - s""" + bodyCode: ExprValue => Block, + fallbackCode: Block): Block = { + val arraySize = JavaCode.variable(freshName("size"), LongType) + val arrayBytes = JavaCode.variable(freshName("arrayBytes"), BinaryType) + code""" |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( | $numElements, | $elementSize); @@ -807,15 +816,15 @@ class CodegenContext { * @param isNull the code to check if the input is null. * @param execute the code that should only be executed when the input is not null. */ - def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = { + def nullSafeExec(nullable: Boolean, isNull: ExprValue)(execute: Block): Block = { if (nullable) { - s""" + code""" if (!$isNull) { $execute } """ } else { - "\n" + execute + code"\n" + execute } } @@ -839,20 +848,21 @@ class CodegenContext { * @param foldFunctions folds the split function calls. */ def splitExpressionsWithCurrentInputs( - expressions: Seq[String], + expressions: Seq[JavaCode], funcName: String = "apply", - extraArguments: Seq[(String, String)] = Nil, + extraArguments: Seq[ExprValue] = Nil, returnType: String = "void", makeSplitFunction: String => String = identity, - foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = { + foldFunctions: Seq[Block] => Block = + _.foldLeft(code"")((blocks: Block, block: Block) => code"$blocks$block;\n")): Block = { // TODO: support whole stage codegen if (INPUT_ROW == null || currentVars != null) { - expressions.mkString("\n") + Blocks(expressions.map(expr => code"$expr")) } else { splitExpressions( expressions, funcName, - ("InternalRow", INPUT_ROW) +: extraArguments, + (INPUT_ROW) +: extraArguments, returnType, makeSplitFunction, foldFunctions) @@ -867,18 +877,19 @@ class CodegenContext { * * @param expressions the codes to evaluate expressions. * @param funcName the split function name base. - * @param arguments the list of (type, name) of the arguments of the split function. + * @param arguments the list of the arguments of the split function. * @param returnType the return type of the split function. * @param makeSplitFunction makes split function body, e.g. add preparation or cleanup. * @param foldFunctions folds the split function calls. */ def splitExpressions( - expressions: Seq[String], + expressions: Seq[JavaCode], funcName: String, - arguments: Seq[(String, String)], + arguments: Seq[ExprValue], returnType: String = "void", makeSplitFunction: String => String = identity, - foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = { + foldFunctions: Seq[Block] => Block = + _.foldLeft(code"")((blocks: Block, block: Block) => code"$blocks$block;\n")): Block = { val blocks = buildCodeBlocks(expressions) if (blocks.length == 1) { @@ -888,19 +899,25 @@ class CodegenContext { if (Utils.isTesting) { // Passing global variables to the split method is dangerous, as any mutating to it is // ignored and may lead to unexpected behavior. - arguments.foreach { case (_, name) => - assert(!mutableStateNames.contains(name), - s"split function argument $name cannot be a global variable.") + arguments.foreach { arg => + assert(!arg.isInstanceOf[GlobalValue], + s"split function argument $arg cannot be a global variable.") } } val func = freshName(funcName) - val argString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ") + val paramBlock = arguments.foldLeft[Block](EmptyBlock) { (args: Block, arg: ExprValue) => + if (args.length == 0) { + code"${JavaCode.javaType(arg)} $arg" + } else { + code"$args, ${JavaCode.javaType(arg)} $arg" + } + } val functions = blocks.zipWithIndex.map { case (body, i) => val name = s"${func}_$i" val code = s""" - |private $returnType $name($argString) { - | ${makeSplitFunction(body)} + |private $returnType $name($paramBlock) { + | ${makeSplitFunction(body.code)} |} """.stripMargin addNewFunctionInternal(name, code, inlineToOuterClass = false) @@ -908,8 +925,17 @@ class CodegenContext { val (outerClassFunctions, innerClassFunctions) = functions.partition(_.innerClassName.isEmpty) - val argsString = arguments.map(_._2).mkString(", ") - val outerClassFunctionCalls = outerClassFunctions.map(f => s"${f.functionName}($argsString)") + val argsBlock = arguments.foldLeft[Block](EmptyBlock) { (args: Block, arg: ExprValue) => + if (args.length == 0) { + code"$arg" + } else { + code"$args, $arg" + } + } + val outerClassFunctionCalls = outerClassFunctions.map { f => + val functionName = inline"${f.functionName}" + code"$functionName($argsBlock)" + } val innerClassFunctionCalls = generateInnerClassesFunctionCalls( innerClassFunctions, @@ -929,9 +955,9 @@ class CodegenContext { * * @param expressions the codes to evaluate expressions. */ - private def buildCodeBlocks(expressions: Seq[String]): Seq[String] = { - val blocks = new ArrayBuffer[String]() - val blockBuilder = new StringBuilder() + private def buildCodeBlocks(expressions: Seq[JavaCode]): Seq[Block] = { + val blocks = new ArrayBuffer[Block]() + var blockBuilder: Block = EmptyBlock var length = 0 for (code <- expressions) { // We can't know how many bytecode will be generated, so use the length of source code @@ -939,14 +965,14 @@ class CodegenContext { // also not be too small, or it will have many function calls (for wide table), see the // results in BenchmarkWideTable. if (length > 1024) { - blocks += blockBuilder.toString() - blockBuilder.clear() + blocks += blockBuilder + blockBuilder = EmptyBlock length = 0 } - blockBuilder.append(code) - length += CodeFormatter.stripExtraNewLinesAndComments(code).length + blockBuilder = blockBuilder + code"$code" + length += CodeFormatter.stripExtraNewLinesAndComments(code.code).length } - blocks += blockBuilder.toString() + blocks += blockBuilder } /** @@ -970,10 +996,10 @@ class CodegenContext { private def generateInnerClassesFunctionCalls( functions: Seq[NewFunctionSpec], funcName: String, - arguments: Seq[(String, String)], + arguments: Seq[ExprValue], returnType: String, makeSplitFunction: String => String, - foldFunctions: Seq[String] => String): Iterable[String] = { + foldFunctions: Seq[Block] => Block): Iterable[Block] = { val innerClassToFunctions = mutable.LinkedHashMap.empty[(String, String), Seq[String]] functions.foreach(f => { val key = (f.innerClassName.get, f.innerClassInstance.get) @@ -981,11 +1007,27 @@ class CodegenContext { innerClassToFunctions.put(key, value) }) - val argDefinitionString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ") - val argInvocationString = arguments.map(_._2).mkString(", ") + val argDefinitionString = + arguments.foldLeft[Block](EmptyBlock) { (args: Block, arg: ExprValue) => + val typeName = inline"${JavaCode.javaType(arg)}" + if (args.length == 0) { + code"$typeName $arg" + } else { + code"$args, $typeName $arg" + } + } + val argInvocationString = + arguments.foldLeft[Block](EmptyBlock) { (args: Block, arg: ExprValue) => + if (args.length == 0) { + code"$arg" + } else { + code"$args, $arg" + } + } innerClassToFunctions.flatMap { case ((innerClassName, innerClassInstance), innerClassFunctions) => + val innerClassIdentifier = inline"$innerClassInstance" // for performance reasons, the functions are prepended, instead of appended, // thus here they are in reversed order val orderedFunctions = innerClassFunctions.reverse @@ -1002,16 +1044,23 @@ class CodegenContext { // ... // } // } - val body = foldFunctions(orderedFunctions.map(name => s"$name($argInvocationString)")) + val body = foldFunctions(orderedFunctions.map { name => + val funcName = inline"$name" + code"$funcName($argInvocationString)" + }) val code = s""" |private $returnType $funcName($argDefinitionString) { - | ${makeSplitFunction(body)} + | ${makeSplitFunction(body.code)} |} """.stripMargin + val funcNameIdentifier = inline"$funcName" addNewFunctionToClass(funcName, code, innerClassName) - Seq(s"$innerClassInstance.$funcName($argInvocationString)") + Seq(code"$innerClassIdentifier.$funcNameIdentifier($argInvocationString)") } else { - orderedFunctions.map(f => s"$innerClassInstance.$f($argInvocationString)") + orderedFunctions.map { f => + val funcName = inline"$f" + code"$innerClassIdentifier.$funcName($argInvocationString)" + } } } } @@ -1161,7 +1210,7 @@ class CodegenContext { s"// $text" } placeHolderToComments += (name -> comment) - code"/*$name*/" + inline"/*$name*/" } else { EmptyBlock } @@ -1327,7 +1376,7 @@ object CodeGenerator extends Logging { val msg = s"failed to compile: $e" logError(msg, e) val maxLines = SQLConf.get.loggingMaxLinesForCodegen - logInfo(s"\n${CodeFormatter.format(code, maxLines)}") + println(s"\n${CodeFormatter.format(code, maxLines)}") throw new CompileException(msg, e.getLocation) } @@ -1427,10 +1476,11 @@ object CodeGenerator extends Logging { /** * Returns the specialized code to access a value from `inputRow` at `ordinal`. */ - def getValue(input: String, dataType: DataType, ordinal: String): String = { + def getValue(input: ExprValue, dataType: DataType, ordinal: String): ExprValue = { val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" + val expression = dataType match { + case _ if isPrimitiveType(jt) => + s"$input.get${primitiveTypeName(jt)}($ordinal)" case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" case StringType => s"$input.getUTF8String($ordinal)" case BinaryType => s"$input.getBinary($ordinal)" @@ -1439,25 +1489,28 @@ object CodeGenerator extends Logging { case _: ArrayType => s"$input.getArray($ordinal)" case _: MapType => s"$input.getMap($ordinal)" case NullType => "null" - case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) + case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal).toString case _ => s"($jt)$input.get($ordinal, null)" } + JavaCode.expression(expression, dataType) } /** * Returns the code to update a column in Row for a given DataType. */ - def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { + def setColumn(row: ExprValue, dataType: DataType, ordinal: Int, value: ExprValue): Block = { val jt = javaType(dataType) dataType match { - case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" - case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" + case _ if isPrimitiveType(jt) => + val typeName = inline"${primitiveTypeName(jt)}" + code"$row.set$typeName($ordinal, $value)" + case t: DecimalType => code"$row.setDecimal($ordinal, $value, ${t.precision})" case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. case StringType | _: StructType | _: ArrayType | _: MapType => - s"$row.update($ordinal, $value.copy())" - case _ => s"$row.update($ordinal, $value)" + code"$row.update($ordinal, $value.copy())" + case _ => code"$row.update($ordinal, $value)" } } @@ -1467,24 +1520,24 @@ object CodeGenerator extends Logging { * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise */ def updateColumn( - row: String, + row: ExprValue, dataType: DataType, ordinal: Int, ev: ExprCode, nullable: Boolean, - isVectorized: Boolean = false): String = { + isVectorized: Boolean = false): Block = { if (nullable) { // Can't call setNullAt on DecimalType, because we need to keep the offset if (!isVectorized && dataType.isInstanceOf[DecimalType]) { - s""" + code""" |if (!${ev.isNull}) { | ${setColumn(row, dataType, ordinal, ev.value)}; |} else { - | ${setColumn(row, dataType, ordinal, "null")}; + | ${setColumn(row, dataType, ordinal, JavaCode.literal("null", dataType))}; |} """.stripMargin } else { - s""" + code""" |if (!${ev.isNull}) { | ${setColumn(row, dataType, ordinal, ev.value)}; |} else { @@ -1493,20 +1546,25 @@ object CodeGenerator extends Logging { """.stripMargin } } else { - s"""${setColumn(row, dataType, ordinal, ev.value)};""" + code"""${setColumn(row, dataType, ordinal, ev.value)};""" } } /** * Returns the specialized code to set a given value in a column vector for a given `DataType`. */ - def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = { + def setValue( + vector: ExprValue, + rowId: ExprValue, + dataType: DataType, + value: ExprValue): Block = { val jt = javaType(dataType) dataType match { case _ if isPrimitiveType(jt) => - s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" - case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" - case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" + val typeName = inline"${primitiveTypeName(jt)}" + code"$vector.put$typeName($rowId, $value);" + case t: DecimalType => code"$vector.putDecimal($rowId, $value, ${t.precision});" + case t: StringType => code"$vector.putByteArray($rowId, $value.getBytes());" case _ => throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") } @@ -1517,13 +1575,13 @@ object CodeGenerator extends Logging { * that could potentially be nullable. */ def updateColumn( - vector: String, - rowId: String, + vector: ExprValue, + rowId: ExprValue, dataType: DataType, ev: ExprCode, - nullable: Boolean): String = { + nullable: Boolean): Block = { if (nullable) { - s""" + code""" |if (!${ev.isNull}) { | ${setValue(vector, rowId, dataType, ev.value)} |} else { @@ -1531,18 +1589,18 @@ object CodeGenerator extends Logging { |} """.stripMargin } else { - s"""${setValue(vector, rowId, dataType, ev.value)};""" + code"""${setValue(vector, rowId, dataType, ev.value)};""" } } /** * Returns the specialized code to access a value from a column vector for a given `DataType`. */ - def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { + def getValueFromVector(vector: ExprValue, dataType: DataType, rowId: ExprValue): ExprValue = { if (dataType.isInstanceOf[StructType]) { // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an // `ordinal` parameter. - s"$vector.getStruct($rowId)" + JavaCode.expression(s"$vector.getStruct($rowId)", dataType) } else { getValue(vector, dataType, rowId) } @@ -1623,18 +1681,21 @@ object CodeGenerator extends Logging { * @param jt the string name of the Java type * @param typedNull if true, for null literals, return a typed (with a cast) version */ - def defaultValue(jt: String, typedNull: Boolean): String = jt match { - case JAVA_BOOLEAN => "false" - case JAVA_BYTE => "(byte)-1" - case JAVA_SHORT => "(short)-1" - case JAVA_INT => "-1" - case JAVA_LONG => "-1L" - case JAVA_FLOAT => "-1.0f" - case JAVA_DOUBLE => "-1.0" - case _ => if (typedNull) s"(($jt)null)" else "null" + def defaultValue(jt: String, typedNull: Boolean): JavaCode = { + val value = jt match { + case JAVA_BOOLEAN => "false" + case JAVA_BYTE => "(byte)-1" + case JAVA_SHORT => "(short)-1" + case JAVA_INT => "-1" + case JAVA_LONG => "-1L" + case JAVA_FLOAT => "-1.0f" + case JAVA_DOUBLE => "-1.0" + case _ => if (typedNull) s"(($jt)null)" else "null" + } + inline"$value" } - def defaultValue(dt: DataType, typedNull: Boolean = false): String = + def defaultValue(dt: DataType, typedNull: Boolean = false): JavaCode = defaultValue(javaType(dt), typedNull) /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 3f4704d287cb..8aee6668128f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -27,7 +27,11 @@ trait CodegenFallback extends Expression { protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // LeafNode does not need `input` - val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW + val input = if (this.isInstanceOf[LeafExpression]) { + JavaCode.literal("null", dataType) + } else { + ctx.INPUT_ROW + } val idx = ctx.references.length ctx.references += this var childIndex = idx @@ -43,9 +47,10 @@ trait CodegenFallback extends Expression { """.stripMargin) case _ => } - val objectTerm = ctx.freshName("obj") + val objectTerm = JavaCode.variable(ctx.freshName("obj"), classOf[Object]) val placeHolder = ctx.registerComment(this.toString) - val javaType = CodeGenerator.javaType(this.dataType) + val javaType = inline"${CodeGenerator.javaType(this.dataType)}" + val boxedType = inline"${CodeGenerator.boxedType(this.dataType)}" if (nullable) { ev.copy(code = code""" $placeHolder @@ -53,13 +58,13 @@ trait CodegenFallback extends Expression { boolean ${ev.isNull} = $objectTerm == null; $javaType ${ev.value} = ${CodeGenerator.defaultValue(this.dataType)}; if (!${ev.isNull}) { - ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; + ${ev.value} = ($boxedType) $objectTerm; }""") } else { ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); - $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; + $javaType ${ev.value} = ($boxedType) $objectTerm; """, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 33d14329ec95..6e807fdf130a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.BooleanType // MutableProjection is not accessible in Java abstract class BaseMutableProjection extends MutableProjection @@ -59,26 +62,28 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) - val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map { + val projectionCodes: Seq[(Block, Block)] = validExpr.zip(exprVals).map { case ((e, i), ev) => val value = JavaCode.global( ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value"), e.dataType) val (code, isNull) = if (e.nullable) { - val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull") - (s""" + val isNull = JavaCode.global(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull"), + BooleanType) + (code""" |${ev.code} |$isNull = ${ev.isNull}; |$value = ${ev.value}; """.stripMargin, JavaCode.isNullGlobal(isNull)) } else { - (s""" + (code""" |${ev.code} |$value = ${ev.value}; """.stripMargin, FalseLiteral) } + val mutableRow = JavaCode.variable("mutableRow", classOf[InternalRow]) val update = CodeGenerator.updateColumn( - "mutableRow", + mutableRow, e.dataType, i, ExprCode(isNull, value), @@ -89,8 +94,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP // Evaluate all the subexpressions. val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) - val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2)) + val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)).code + val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2)).code val codeBody = s""" public java.lang.Object generate(Object[] references) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 9a51be6ed5ae..500d785a2669 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -25,7 +25,8 @@ import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.util.Utils /** @@ -74,30 +75,33 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = { val oldInputRow = ctx.INPUT_ROW val oldCurrentVars = ctx.currentVars - val inputRow = "i" + val inputRow = JavaCode.variable("i", classOf[InternalRow]) ctx.INPUT_ROW = inputRow // to use INPUT_ROW we must make sure currentVars is null ctx.currentVars = null + val varA = JavaCode.variable("a", classOf[InternalRow]) + val varB = JavaCode.variable("b", classOf[InternalRow]) val comparisons = ordering.map { order => val eval = order.child.genCode(ctx) val asc = order.isAscending - val isNullA = ctx.freshName("isNullA") - val primitiveA = ctx.freshName("primitiveA") - val isNullB = ctx.freshName("isNullB") - val primitiveB = ctx.freshName("primitiveB") - s""" - ${ctx.INPUT_ROW} = a; + val isNullA = JavaCode.isNullVariable(ctx.freshName("isNullA")) + val primitiveType = inline"${CodeGenerator.javaType(order.child.dataType)}" + val primitiveA = JavaCode.variable(ctx.freshName("primitiveA"), order.child.dataType) + val isNullB = JavaCode.isNullVariable(ctx.freshName("isNullB")) + val primitiveB = JavaCode.variable(ctx.freshName("primitiveB"), order.child.dataType) + code""" + ${ctx.INPUT_ROW} = $varA; boolean $isNullA; - ${CodeGenerator.javaType(order.child.dataType)} $primitiveA; + $primitiveType $primitiveA; { ${eval.code} $isNullA = ${eval.isNull}; $primitiveA = ${eval.value}; } - ${ctx.INPUT_ROW} = b; + ${ctx.INPUT_ROW} = $varB; boolean $isNullB; - ${CodeGenerator.javaType(order.child.dataType)} $primitiveB; + $primitiveType $primitiveB; { ${eval.code} $isNullB = ${eval.isNull}; @@ -108,19 +112,19 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } else if ($isNullA) { return ${ order.nullOrdering match { - case NullsFirst => "-1" - case NullsLast => "1" + case NullsFirst => code"-1" + case NullsLast => code"1" }}; } else if ($isNullB) { return ${ order.nullOrdering match { - case NullsFirst => "1" - case NullsLast => "-1" + case NullsFirst => code"1" + case NullsLast => code"-1" }}; } else { int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)}; if (comp != 0) { - return ${if (asc) "comp" else "-comp"}; + return ${if (asc) code"comp" else code"-comp"}; } } """ @@ -129,7 +133,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR val code = ctx.splitExpressions( expressions = comparisons, funcName = "compare", - arguments = Seq(("InternalRow", "a"), ("InternalRow", "b")), + arguments = Seq(varA, varB), returnType = "int", makeSplitFunction = { body => s""" @@ -139,15 +143,16 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR """ }, foldFunctions = { funCalls => - funCalls.zipWithIndex.map { case (funCall, i) => - val comp = ctx.freshName("comp") - s""" + val blocks = funCalls.zipWithIndex.map { case (funCall, i) => + val comp = JavaCode.variable(ctx.freshName("comp"), IntegerType) + code""" int $comp = $funCall; if ($comp != 0) { return $comp; } """ - }.mkString + } + Blocks(blocks) }) ctx.currentVars = oldCurrentVars ctx.INPUT_ROW = oldInputRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 39778661d1c4..7cf8618fbb43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -45,21 +45,21 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] private def createCodeForStruct( ctx: CodegenContext, - input: String, + input: ExprValue, schema: StructType): ExprCode = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. - val tmpInput = ctx.freshName("tmpInput") - val output = ctx.freshName("safeRow") - val values = ctx.freshName("values") + val tmpInput = JavaCode.variable(ctx.freshName("tmpInput"), classOf[InternalRow]) + val output = JavaCode.variable(ctx.freshName("safeRow"), classOf[InternalRow]) + val values = JavaCode.variable(ctx.freshName("values"), classOf[Array[Object]]) - val rowClass = classOf[GenericInternalRow].getName + val rowClass = inline"${classOf[GenericInternalRow].getName}" val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => val converter = convertToSafe( ctx, JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt), dt) - s""" + code""" if (!$tmpInput.isNullAt($i)) { ${converter.code} $values[$i] = ${converter.value}; @@ -69,7 +69,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val allFields = ctx.splitExpressions( expressions = fieldWriters, funcName = "writeFields", - arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values) + arguments = Seq(tmpInput, values) ) val code = code""" @@ -84,15 +84,15 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] private def createCodeForArray( ctx: CodegenContext, - input: String, + input: ExprValue, elementType: DataType): ExprCode = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. - val tmpInput = ctx.freshName("tmpInput") - val output = ctx.freshName("safeArray") - val values = ctx.freshName("values") - val numElements = ctx.freshName("numElements") - val index = ctx.freshName("index") - val arrayClass = classOf[GenericArrayData].getName + val tmpInput = JavaCode.variable(ctx.freshName("tmpInput"), classOf[ArrayData]) + val output = JavaCode.variable(ctx.freshName("safeArray"), classOf[ArrayData]) + val values = JavaCode.variable(ctx.freshName("values"), classOf[Array[Object]]) + val numElements = JavaCode.variable(ctx.freshName("numElements"), IntegerType) + val index = JavaCode.variable(ctx.freshName("index"), IntegerType) + val arrayClass = inline"${classOf[GenericArrayData].getName}" val elementConverter = convertToSafe( ctx, @@ -116,15 +116,17 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] private def createCodeForMap( ctx: CodegenContext, - input: String, + input: ExprValue, keyType: DataType, valueType: DataType): ExprCode = { - val tmpInput = ctx.freshName("tmpInput") - val output = ctx.freshName("safeMap") - val mapClass = classOf[ArrayBasedMapData].getName - - val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType) - val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType) + val tmpInput = JavaCode.variable(ctx.freshName("tmpInput"), classOf[MapData]) + val output = JavaCode.variable(ctx.freshName("safeMap"), classOf[MapData]) + val mapClass = inline"${classOf[ArrayBasedMapData].getName}" + val keyArray = JavaCode.expression(s"$tmpInput.keyArray()", classOf[ArrayData]) + val valueArray = JavaCode.expression(s"$tmpInput.valueArray()", classOf[ArrayData]) + + val keyConverter = createCodeForArray(ctx, keyArray, keyType) + val valueConverter = createCodeForArray(ctx, valueArray, valueType) val code = code""" final MapData $tmpInput = $input; ${keyConverter.code} @@ -149,22 +151,23 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] protected def create(expressions: Seq[Expression]): Projection = { val ctx = newCodeGenContext() + val mutableRow = JavaCode.variable("mutableRow", classOf[InternalRow]) val expressionCodes = expressions.zipWithIndex.map { - case (NoOp, _) => "" + case (NoOp, _) => EmptyBlock case (e, i) => val evaluationCode = e.genCode(ctx) val converter = convertToSafe(ctx, evaluationCode.value, e.dataType) evaluationCode.code + - s""" + code""" if (${evaluationCode.isNull}) { mutableRow.setNullAt($i); } else { ${converter.code} - ${CodeGenerator.setColumn("mutableRow", e.dataType, i, converter.value)}; + ${CodeGenerator.setColumn(mutableRow, e.dataType, i, converter.value)}; } """ } - val allExpressions = ctx.splitExpressionsWithCurrentInputs(expressionCodes) + val allExpressions = ctx.splitExpressionsWithCurrentInputs(expressionCodes).code val codeBody = s""" public java.lang.Object generate(Object[] references) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 8f2a5a0dce94..81fd8a27a454 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ /** @@ -46,12 +48,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // TODO: if the nullability of field is correct, we can use it to save null check. private def writeStructToBuffer( ctx: CodegenContext, - input: String, - index: String, + input: ExprValue, + index: ExprValue, fieldTypes: Seq[DataType], - rowWriter: String): String = { + rowWriter: ExprValue): Block = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. - val tmpInput = ctx.freshName("tmpInput") + val tmpInput = JavaCode.variable(ctx.freshName("tmpInput"), classOf[InternalRow]) val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => ExprCode( JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"), @@ -59,10 +61,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val rowWriterClass = classOf[UnsafeRowWriter].getName - val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", - v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") - val previousCursor = ctx.freshName("previousCursor") - s""" + val structRowWriter = JavaCode.global(ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});"), + classOf[UnsafeRowWriter]) + val previousCursor = JavaCode.variable(ctx.freshName("previousCursor"), IntegerType) + code""" |final InternalRow $tmpInput = $input; |if ($tmpInput instanceof UnsafeRow) { | $rowWriter.write($index, (UnsafeRow) $tmpInput); @@ -78,11 +81,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private def writeExpressionsToBuffer( ctx: CodegenContext, - row: String, + row: ExprValue, inputs: Seq[ExprCode], inputTypes: Seq[DataType], - rowWriter: String, - isTopLevel: Boolean = false): String = { + rowWriter: ExprValue, + isTopLevel: Boolean = false): Block = { val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, // which means its fixed-size region always in the same position, so we don't need to call @@ -90,38 +93,39 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro if (inputs.map(_.isNull).forall(_ == "false")) { // If all fields are not nullable, which means the null bits never changes, then we don't // need to clear it out every time. - "" + EmptyBlock } else { - s"$rowWriter.zeroOutNullBytes();" + code"$rowWriter.zeroOutNullBytes();" } } else { - s"$rowWriter.resetRowWriter();" + code"$rowWriter.resetRowWriter();" } val writeFields = inputs.zip(inputTypes).zipWithIndex.map { case ((input, dataType), index) => val dt = UserDefinedType.sqlType(dataType) + val indexValue = JavaCode.literal(index.toString, IntegerType) val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => // Can't call setNullAt() for DecimalType with precision larger than 18. - s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" - case _ => s"$rowWriter.setNullAt($index);" + code"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" + case _ => code"$rowWriter.setNullAt($index);" } - val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) + val writeField = writeElement(ctx, input.value, indexValue, dt, rowWriter) if (input.isNull == FalseLiteral) { - s""" + code""" |${input.code} - |${writeField.trim} + |${writeField} """.stripMargin } else { - s""" + code""" |${input.code} |if (${input.isNull}) { - | ${setNull.trim} + | ${setNull} |} else { - | ${writeField.trim} + | ${writeField} |} """.stripMargin } @@ -129,15 +133,15 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val writeFieldsCode = if (isTopLevel && (row == null || ctx.currentVars != null)) { // TODO: support whole stage codegen - writeFields.mkString("\n") + Blocks(writeFields) } else { assert(row != null, "the input row name cannot be null when generating code to write it.") ctx.splitExpressions( expressions = writeFields, funcName = "writeFields", - arguments = Seq("InternalRow" -> row)) + arguments = Seq(row)) } - s""" + code""" |$resetWriter |$writeFieldsCode """.stripMargin @@ -146,13 +150,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // TODO: if the nullability of array element is correct, we can use it to save null check. private def writeArrayToBuffer( ctx: CodegenContext, - input: String, + input: ExprValue, elementType: DataType, - rowWriter: String): String = { + rowWriter: ExprValue): Block = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. - val tmpInput = ctx.freshName("tmpInput") - val numElements = ctx.freshName("numElements") - val index = ctx.freshName("index") + val tmpInput = JavaCode.variable(ctx.freshName("tmpInput"), classOf[ArrayData]) + val numElements = JavaCode.variable(ctx.freshName("numElements"), IntegerType) + val index = JavaCode.variable(ctx.freshName("index"), IntegerType) val et = UserDefinedType.sqlType(elementType) @@ -165,12 +169,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val arrayWriterClass = classOf[UnsafeArrayWriter].getName - val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", - v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") + val arrayWriter = JavaCode.global(ctx.addMutableState(arrayWriterClass, "arrayWriter", + v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);"), + classOf[UnsafeArrayWriter]) val element = CodeGenerator.getValue(tmpInput, et, index) - s""" + code""" |final ArrayData $tmpInput = $input; |if ($tmpInput instanceof UnsafeArrayData) { | $rowWriter.write((UnsafeArrayData) $tmpInput); @@ -192,18 +197,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // TODO: if the nullability of value element is correct, we can use it to save null check. private def writeMapToBuffer( ctx: CodegenContext, - input: String, - index: String, + input: ExprValue, + index: ExprValue, keyType: DataType, valueType: DataType, - rowWriter: String): String = { + rowWriter: ExprValue): Block = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. - val tmpInput = ctx.freshName("tmpInput") - val tmpCursor = ctx.freshName("tmpCursor") - val previousCursor = ctx.freshName("previousCursor") + val tmpInput = JavaCode.variable(ctx.freshName("tmpInput"), classOf[MapData]) + val tmpCursor = JavaCode.variable(ctx.freshName("tmpCursor"), IntegerType) + val previousCursor = JavaCode.variable(ctx.freshName("previousCursor"), IntegerType) + val keyArray = JavaCode.expression(s"$tmpInput.keyArray()", classOf[ArrayData]) + val valueArray = JavaCode.expression(s"$tmpInput.valueArray()", classOf[ArrayData]) // Writes out unsafe map according to the format described in `UnsafeMapData`. - s""" + code""" |final MapData $tmpInput = $input; |if ($tmpInput instanceof UnsafeMapData) { | $rowWriter.write($index, (UnsafeMapData) $tmpInput); @@ -219,7 +226,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | // Remember the current cursor so that we can write numBytes of key array later. | final int $tmpCursor = $rowWriter.cursor(); | - | ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} + | ${writeArrayToBuffer(ctx, keyArray, keyType, rowWriter)} | | // Write the numBytes of key array into the first 8 bytes. | Platform.putLong( @@ -227,7 +234,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | $tmpCursor - 8, | $rowWriter.cursor() - $tmpCursor); | - | ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} + | ${writeArrayToBuffer(ctx, valueArray, valueType, rowWriter)} | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); |} """.stripMargin @@ -235,16 +242,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private def writeElement( ctx: CodegenContext, - input: String, - index: String, + input: ExprValue, + index: ExprValue, dt: DataType, - writer: String): String = dt match { + writer: ExprValue): Block = dt match { case t: StructType => writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer) case ArrayType(et, _) => - val previousCursor = ctx.freshName("previousCursor") - s""" + val previousCursor = JavaCode.variable(ctx.freshName("previousCursor"), IntegerType) + code""" |// Remember the current cursor so that we can calculate how many bytes are |// written later. |final int $previousCursor = $writer.cursor(); @@ -256,11 +263,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro writeMapToBuffer(ctx, input, index, kt, vt, writer) case DecimalType.Fixed(precision, scale) => - s"$writer.write($index, $input, $precision, $scale);" + code"$writer.write($index, $input, $precision, $scale);" - case NullType => "" + case NullType => EmptyBlock - case _ => s"$writer.write($index, $input);" + case _ => code"$writer.write($index, $input);" } def createCode( @@ -277,11 +284,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val rowWriterClass = classOf[UnsafeRowWriter].getName - val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", - v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});") + val rowWriter = JavaCode.global(ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});"), + classOf[UnsafeRowWriter]) // Evaluate all the subexpression. - val evalSubexpr = ctx.subexprFunctions.mkString("\n") + val evalSubexpr = Blocks(ctx.subexprFunctions.map(func => inline"$func")) val writeExpressions = writeExpressionsToBuffer( ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index febf7b0c96c2..452c79cb1c71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -21,7 +21,8 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.unsafe.Platform abstract class UnsafeRowJoiner { @@ -56,8 +57,8 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = { val ctx = new CodegenContext val offset = Platform.BYTE_ARRAY_OFFSET - val getLong = "Platform.getLong" - val putLong = "Platform.putLong" + val getLong = inline"Platform.getLong" + val putLong = inline"Platform.putLong" val bitset1Words = (schema1.size + 63) / 64 val bitset2Words = (schema2.size + 63) / 64 @@ -68,38 +69,46 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // The only reduction comes from merging the bitset portion of the two rows, saving 1 word. val sizeReduction = (bitset1Words + bitset2Words - outputBitsetWords) * 8 + val obj1 = JavaCode.variable("obj1", classOf[Object]) + val obj2 = JavaCode.variable("obj2", classOf[Object]) + val offset1 = JavaCode.variable("offset1", LongType) + val offset2 = JavaCode.variable("offset2", LongType) + // --------------------- copy bitset from row 1 and row 2 --------------------------- // val copyBitset = Seq.tabulate(outputBitsetWords) { i => val bits = if (bitset1Remainder > 0 && bitset2Words != 0) { if (i < bitset1Words - 1) { - s"$getLong(obj1, offset1 + ${i * 8})" + code"$getLong($obj1, $offset1 + ${i * 8})" } else if (i == bitset1Words - 1) { // combine last work of bitset1 and first word of bitset2 - s"$getLong(obj1, offset1 + ${i * 8}) | ($getLong(obj2, offset2) << $bitset1Remainder)" + val block = code"$getLong($obj1, $offset1 + ${i * 8}) |" + code"$block ($getLong($obj2, $offset2) << $bitset1Remainder)" } else if (i - bitset1Words < bitset2Words - 1) { // combine next two words of bitset2 - s"($getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder))" + - s" | ($getLong(obj2, offset2 + ${(i - bitset1Words + 1) * 8}) << $bitset1Remainder)" + val block1 = code"($getLong($obj2, $offset2 + ${(i - bitset1Words) * 8}) >>>" + val block2 = code"(64 - $bitset1Remainder))" + val block3 = code"| ($getLong($obj2, $offset2 + ${(i - bitset1Words + 1) * 8})" + code"$block1 $block2 $block3 << $bitset1Remainder)" } else { // last word of bitset2 - s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder)" + code"$getLong($obj2, $offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder)" } } else { // they are aligned by word if (i < bitset1Words) { - s"$getLong(obj1, offset1 + ${i * 8})" + code"$getLong($obj1, $offset1 + ${i * 8})" } else { - s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})" + code"$getLong($obj2, $offset2 + ${(i - bitset1Words) * 8})" } } - s"$putLong(buf, ${offset + i * 8}, $bits);\n" + code"$putLong(buf, ${offset + i * 8}, $bits);\n" } val copyBitsets = ctx.splitExpressions( expressions = copyBitset, funcName = "copyBitsetFunc", - arguments = ("java.lang.Object", "obj1") :: ("long", "offset1") :: - ("java.lang.Object", "obj2") :: ("long", "offset2") :: Nil) + arguments = (obj1) :: (offset1) :: + (obj2) :: (offset2) :: Nil) // --------------------- copy fixed length portion from row 1 ----------------------- // var cursor = offset + outputBitsetWords * 8 @@ -124,13 +133,14 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // --------------------- copy variable length portion from row 1 ----------------------- // val numBytesBitsetAndFixedRow1 = (bitset1Words + schema1.size) * 8 + val numBytesVariableRow1 = JavaCode.variable("numBytesVariableRow1", LongType) val copyVariableLengthRow1 = s""" |// Copy variable length data for row1 - |long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1; + |long $numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1; |Platform.copyMemory( | obj1, offset1 + ${(bitset1Words + schema1.size) * 8}, | buf, $cursor, - | numBytesVariableRow1); + | $numBytesVariableRow1); """.stripMargin // --------------------- copy variable length portion from row 2 ----------------------- // @@ -148,7 +158,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val updateOffset = (schema1 ++ schema2).zipWithIndex.map { case (field, i) => // Skip fixed length data types, and only generate code for variable length data if (UnsafeRow.isFixedLength(field.dataType)) { - "" + EmptyBlock } else { // Number of bytes to increase for the offset. Note that since in UnsafeRow we store the // offset in the upper 32 bit of the words, we can just shift the offset to the left by @@ -157,9 +167,10 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // shift added to it. val shift = if (i < schema1.size) { - s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L" + code"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L" } else { - s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)" + code"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + " + + code"$numBytesVariableRow1)" } val cursor = offset + outputBitsetWords * 8 + i * 8 // UnsafeRow is a little underspecified, so in what follows we'll treat UnsafeRowWriter's @@ -197,7 +208,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // // Thus it is safe to perform `existingOffset != 0` checks here in the place of // more expensive null-bit checks. - s""" + code""" |existingOffset = $getLong(buf, $cursor); |if (existingOffset != 0) { | $putLong(buf, $cursor, existingOffset + ($shift << 32)); @@ -209,7 +220,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val updateOffsets = ctx.splitExpressions( expressions = updateOffset, funcName = "copyBitsetFunc", - arguments = ("long", "numBytesVariableRow1") :: Nil, + arguments = (numBytesVariableRow1) :: Nil, makeSplitFunction = (s: String) => "long existingOffset;\n" + s) // ------------------------ Finally, put everything together --------------------------- // diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 250ce48d059e..a3aadd4c990b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import java.lang.{Boolean => JBool} +import java.lang.{Boolean => JBool, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong} import scala.collection.mutable.ArrayBuffer import scala.language.{existentials, implicitConversions} -import org.apache.spark.sql.types.{BooleanType, DataType} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Trait representing an opaque fragments of java code. @@ -36,6 +37,8 @@ trait JavaCode { * Utility functions for creating [[JavaCode]] fragments. */ object JavaCode { + import Block._ + /** * Create a java literal. */ @@ -51,7 +54,7 @@ object JavaCode { */ def defaultLiteral(dataType: DataType): LiteralValue = { new LiteralValue( - CodeGenerator.defaultValue(dataType, typedNull = true), + CodeGenerator.defaultValue(dataType, typedNull = true).code, CodeGenerator.javaClass(dataType)) } @@ -88,6 +91,7 @@ object JavaCode { GlobalValue(name, javaClass) } + /** * Create a global isNull variable. */ @@ -113,6 +117,41 @@ object JavaCode { def isNullExpression(code: String): SimpleExprValue = { expression(code, BooleanType) } + + /** + * We only allow values of four basic primitive types to be implicitly converted + * to `LiteralValue`. This conversion is convenient for interpolation in `Block`. + * We explicitly disallow string interpolation so we won't mistakently interpolate + * section of code as string and loss any references to `Block` and `ExprValue`. + */ + implicit def intToLiteral(i: Int): LiteralValue = + new LiteralValue(i.toString, JInt.TYPE) + implicit def longToLiteral(l: Long): LiteralValue = + new LiteralValue(l.toString, JLong.TYPE) + implicit def floatToLiteral(f: Float): LiteralValue = + new LiteralValue(f.toString, JFloat.TYPE) + implicit def doubleToLiteral(d: Double): LiteralValue = + new LiteralValue(d.toString, JDouble.TYPE) + + + val javaClassDataTypeMapping = Map[Class[_], DataType]( + classOf[Boolean] -> BooleanType, + classOf[Byte] -> ByteType, + classOf[Short] -> ShortType, + classOf[Int] -> IntegerType, + classOf[Long] -> LongType, + classOf[Float] -> FloatType, + classOf[Double] -> DoubleType, + classOf[UTF8String] -> StringType + ) + + /** + * Return an inline block of Java Type for given `ExprValue`. + */ + def javaType(expr: ExprValue): Block = { + val dt = javaClassDataTypeMapping.get(expr.javaType).getOrElse(ObjectType(expr.javaType)) + inline"${CodeGenerator.javaType(dt)}" + } } /** @@ -123,8 +162,7 @@ trait Block extends JavaCode { // The expressions to be evaluated inside this block. def exprValues: Set[ExprValue] - // Returns java code string for this code block. - override def toString: String = _marginChar match { + protected def doStripMargin(code: String): String = _marginChar match { case Some(c) => code.stripMargin(c).trim case _ => code.trim } @@ -134,14 +172,14 @@ trait Block extends JavaCode { def nonEmpty: Boolean = toString.nonEmpty // The leading prefix that should be stripped from each line. - // By default we strip blanks or control characters followed by '|' from the line. - var _marginChar: Option[Char] = Some('|') + var _marginChar: Option[Char] = None def stripMargin(c: Char): this.type = { _marginChar = Some(c) this } + // Strip blanks or control characters followed by '|' from the line when materialized the code. def stripMargin: this.type = { _marginChar = Some('|') this @@ -155,23 +193,35 @@ object Block { val CODE_BLOCK_BUFFER_LENGTH: Int = 512 + /** + * A custom string interpolator which inlines all types of input arguments into a string without + * tracking any reference of `JavaCode` instances. + */ + implicit class InlineHelper(val sc: StringContext) extends AnyVal { + def inline(args: Any*): Block = { + val inlineString = sc.raw(args: _*) + InlineBlock(inlineString) + } + } + implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) + /** + * A custom string interpolator which allows only `JavaCode` instances to be interpolated into + * a Java code `Block`. `Block`s will be interpolated. `ExprValue`s will be interpolated too and + * also referneced in `exprValues` property of returned `Block`. + */ implicit class BlockHelper(val sc: StringContext) extends AnyVal { - def code(args: Any*): Block = { + def code(args: JavaCode*): Block = { sc.checkLengths(args) - if (sc.parts.length == 0) { + val processedArgs = args.map { + case l: LiteralValue => l.value + case other => other + } + val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, processedArgs) + if (codeParts.length == 0) { EmptyBlock } else { - args.foreach { - case _: ExprValue => - case _: Int | _: Long | _: Float | _: Double | _: String => - case _: Block => - case other => throw new IllegalArgumentException( - s"Can not interpolate ${other.getClass.getName} into code block.") - } - - val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args) CodeBlock(codeParts, blockInputs) } } @@ -186,7 +236,7 @@ object Block { val inputs = args.iterator val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) - buf.append(strings.next) + buf.append(StringContext.treatEscapes(strings.next)) while (strings.hasNext) { val input = inputs.next input match { @@ -197,16 +247,26 @@ object Block { case _ => buf.append(input) } - buf.append(strings.next) - } - if (buf.nonEmpty) { - codeParts += buf.toString + buf.append(StringContext.treatEscapes(strings.next)) } + codeParts += buf.toString (codeParts.toSeq, blockInputs.toSeq) } } +case class InlineBlock(block: String) extends Block { + override val code: String = block + override val exprValues: Set[ExprValue] = Set.empty + + override def + (other: Block): Block = other match { + case c: CodeBlock => Blocks(Seq(this, c)) + case i: InlineBlock => InlineBlock(block + i.block) + case b: Blocks => Blocks(Seq(this) ++ b.blocks) + case EmptyBlock => this + } +} + /** * A block of java code. Including a sequence of code parts and some inputs to this block. * The actual java code is generated by embedding the inputs into the code parts. @@ -225,14 +285,15 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) buf.append(StringContext.treatEscapes(strings.next)) while (strings.hasNext) { - buf.append(inputs.next) + buf.append(inputs.next.code) buf.append(StringContext.treatEscapes(strings.next)) } - buf.toString + doStripMargin(buf.toString) } override def + (other: Block): Block = other match { case c: CodeBlock => Blocks(Seq(this, c)) + case i: InlineBlock => Blocks(Seq(this, i)) case b: Blocks => Blocks(Seq(this) ++ b.blocks) case EmptyBlock => this } @@ -240,10 +301,11 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends case class Blocks(blocks: Seq[Block]) extends Block { override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet - override lazy val code: String = blocks.map(_.toString).mkString("\n") + override lazy val code: String = doStripMargin(blocks.map(_.toString).mkString("\n")) override def + (other: Block): Block = other match { case c: CodeBlock => Blocks(blocks :+ c) + case i: InlineBlock => Blocks(blocks :+ i) case b: Blocks => Blocks(blocks ++ b.blocks) case EmptyBlock => this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 03b3b21a1661..b73f98679d36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -95,7 +95,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType ev.copy(code = code""" boolean ${ev.isNull} = false; ${childGen.code} - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : + ${inline"${CodeGenerator.javaType(dataType)}"} ${ev.value} = ${childGen.isNull} ? -1 : (${childGen.value}).numElements();""", isNull = FalseLiteral) } } @@ -122,7 +122,7 @@ case class MapKeys(child: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).keyArray();") + nullSafeCodeGen(ctx, ev, c => code"${ev.value} = ($c).keyArray();") } override def prettyName: String = "map_keys" @@ -150,7 +150,7 @@ case class MapValues(child: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).valueArray();") + nullSafeCodeGen(ctx, ev, c => code"${ev.value} = ($c).valueArray();") } override def prettyName: String = "map_values" @@ -201,17 +201,17 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { - val numElements = ctx.freshName("numElements") - val keys = ctx.freshName("keys") - val values = ctx.freshName("values") + val numElements = JavaCode.variable(ctx.freshName("numElements"), IntegerType) + val keys = JavaCode.variable(ctx.freshName("keys"), classOf[ArrayData]) + val values = JavaCode.variable(ctx.freshName("values"), classOf[ArrayData]) val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType) val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) val code = if (isKeyPrimitive && isValuePrimitive) { - genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) + genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) } else { - genCodeForAnyElements(ctx, keys, values, ev.value, numElements) + genCodeForAnyElements(ctx, keys, values, ev.value, numElements) } - s""" + code""" |final int $numElements = $c.numElements(); |final ArrayData $keys = $c.keyArray(); |final ArrayData $values = $c.valueArray(); @@ -220,33 +220,36 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp }) } - private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z") + private def getKey(varName: ExprValue, loopCounter: ExprValue) = + CodeGenerator.getValue(varName, childDataType.keyType, loopCounter) - private def getValue(varName: String) = { - CodeGenerator.getValue(varName, childDataType.valueType, "z") + private def getValue(varName: ExprValue, loopCounter: ExprValue) = { + CodeGenerator.getValue(varName, childDataType.valueType, loopCounter) } private def genCodeForPrimitiveElements( ctx: CodegenContext, - keys: String, - values: String, - arrayData: String, - numElements: String): String = { - val unsafeRow = ctx.freshName("unsafeRow") - val unsafeArrayData = ctx.freshName("unsafeArrayData") - val structsOffset = ctx.freshName("structsOffset") - val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" + keys: ExprValue, + values: ExprValue, + arrayData: ExprValue, + numElements: ExprValue): Block = { + val unsafeRow = JavaCode.variable(ctx.freshName("unsafeRow"), classOf[UnsafeRow]) + val unsafeArrayData = JavaCode.variable(ctx.freshName("unsafeArrayData"), + classOf[UnsafeArrayData]) + val structsOffset = JavaCode.variable(ctx.freshName("structsOffset"), IntegerType) + val calculateHeader = inline"UnsafeArrayData.calculateHeaderPortionInBytes" val baseOffset = Platform.BYTE_ARRAY_OFFSET val wordSize = UnsafeRow.WORD_SIZE val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 - val structSizeAsLong = structSize + "L" - val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) - val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + val structSizeAsLong = JavaCode.literal(structSize + "L", LongType) + val keyTypeName = inline"${CodeGenerator.primitiveTypeName(childDataType.keyType)}" + val valueTypeName = inline"${CodeGenerator.primitiveTypeName(childDataType.keyType)}" + val z = JavaCode.variable("z", IntegerType) - val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" + val valueAssignment = code"$unsafeRow.set$valueTypeName(1, ${getValue(values, z)});" val valueAssignmentChecked = if (childDataType.valueContainsNull) { - s""" + code""" |if ($values.isNullAt(z)) { | $unsafeRow.setNullAt(1); |} else { @@ -257,15 +260,15 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp valueAssignment } - val assignmentLoop = (byteArray: String) => - s""" + val assignmentLoop = (byteArray: ExprValue) => + code""" |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; |UnsafeRow $unsafeRow = new UnsafeRow(2); - |for (int z = 0; z < $numElements; z++) { - | long offset = $structsOffset + z * $structSizeAsLong; - | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong); + |for (int $z = 0; $z < $numElements; $z++) { + | long offset = $structsOffset + $z * $structSizeAsLong; + | $unsafeArrayData.setLong($z, (offset << 32) + $structSizeAsLong); | $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize); - | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); + | $unsafeRow.set$keyTypeName(0, ${getKey(keys, z)}); | $valueAssignmentChecked |} |$arrayData = $unsafeArrayData; @@ -281,25 +284,26 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp private def genCodeForAnyElements( ctx: CodegenContext, - keys: String, - values: String, - arrayData: String, - numElements: String): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val rowClass = classOf[GenericInternalRow].getName - val data = ctx.freshName("internalRowArray") + keys: ExprValue, + values: ExprValue, + arrayData: ExprValue, + numElements: ExprValue): Block = { + val genericArrayClass = inline"${classOf[GenericArrayData].getName}" + val rowClass = inline"${classOf[GenericInternalRow].getName}" + val data = JavaCode.variable(ctx.freshName("internalRowArray"), classOf[Array[Object]]) + val z = JavaCode.variable("z", IntegerType) val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) { - s"$values.isNullAt(z) ? null : (Object)${getValue(values)}" + code"$values.isNullAt(z) ? null : (Object)${getValue(values, z)}" } else { - getValue(values) + getValue(values, z) } - s""" + code""" |final Object[] $data = new Object[$numElements]; |for (int z = 0; z < $numElements; z++) { - | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck}); + | $data[z] = new $rowClass(new Object[]{${getKey(keys, z)}, $getValueWithCheck}); |} |$arrayData = new $genericArrayClass($data); """.stripMargin @@ -373,37 +377,40 @@ trait ArraySortLike extends ExpectsInputTypes { new GenericArrayData(data.asInstanceOf[Array[Any]]) } - def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = { - val arrayData = classOf[ArrayData].getName - val genericArrayData = classOf[GenericArrayData].getName - val unsafeArrayData = classOf[UnsafeArrayData].getName - val array = ctx.freshName("array") - val c = ctx.freshName("c") + def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: ExprValue, order: ExprValue): Block = { + val arrayData = inline"${classOf[ArrayData].getName}" + val genericArrayData = inline"${classOf[GenericArrayData].getName}" + val unsafeArrayData = inline"${classOf[UnsafeArrayData].getName}" + val array = JavaCode.variable(ctx.freshName("array"), classOf[Array[Object]]) + val c = JavaCode.variable(ctx.freshName("c"), IntegerType) if (elementType == NullType) { - s"${ev.value} = $base.copy();" + code"${ev.value} = $base.copy();" } else { - val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType) - val sortOrder = ctx.freshName("sortOrder") - val o1 = ctx.freshName("o1") - val o2 = ctx.freshName("o2") - val jt = CodeGenerator.javaType(elementType) + val elementTypeTerm = JavaCode.global( + ctx.addReferenceObj("elementTypeTerm", elementType), elementType) + val sortOrder = JavaCode.variable(ctx.freshName("sortOrder"), IntegerType) + val o1 = JavaCode.variable(ctx.freshName("o1"), classOf[Object]) + val o2 = JavaCode.variable(ctx.freshName("o2"), classOf[Object]) + val jt = inline"${CodeGenerator.javaType(elementType)}" val comp = if (CodeGenerator.isPrimitiveType(elementType)) { - val bt = CodeGenerator.boxedType(elementType) - val v1 = ctx.freshName("v1") - val v2 = ctx.freshName("v2") - s""" + val bt = inline"${CodeGenerator.boxedType(elementType)}" + val v1 = JavaCode.variable(ctx.freshName("v1"), elementType) + val v2 = JavaCode.variable(ctx.freshName("v2"), elementType) + code""" |$jt $v1 = (($bt) $o1).${jt}Value(); |$jt $v2 = (($bt) $o2).${jt}Value(); |int $c = ${ctx.genComp(elementType, v1, v2)}; """.stripMargin } else { - s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};" + val cmp = ctx.genComp(elementType, JavaCode.expression(s"(($jt) $o1)", elementType), + JavaCode.expression(s"(($jt) $o2)", elementType)) + code"int $c = $cmp;" } val nonNullPrimitiveAscendingSort = if (CodeGenerator.isPrimitiveType(elementType) && !containsNull) { - val javaType = CodeGenerator.javaType(elementType) - val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) - s""" + val javaType = inline"${CodeGenerator.javaType(elementType)}" + val primitiveTypeName = inline"${CodeGenerator.primitiveTypeName(elementType)}" + code""" |if ($order) { | $javaType[] $array = $base.to${primitiveTypeName}Array(); | java.util.Arrays.sort($array); @@ -411,9 +418,9 @@ trait ArraySortLike extends ExpectsInputTypes { |} else """.stripMargin } else { - "" + EmptyBlock } - s""" + code""" |$nonNullPrimitiveAscendingSort |{ | Object[] $array = $base.toObjectArray($elementTypeTerm); @@ -550,7 +557,7 @@ case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLi } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true")) + nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, TrueLiteral)) } override def prettyName: String = "array_sort" @@ -592,27 +599,32 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI }) } - private def stringCodeGen(ev: ExprCode, childName: String): String = { - s"${ev.value} = ($childName).reverse();" + private def stringCodeGen(ev: ExprCode, childName: ExprValue): Block = { + code"${ev.value} = ($childName).reverse();" } - private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { - val length = ctx.freshName("length") - val javaElementType = CodeGenerator.javaType(elementType) + private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: ExprValue): Block = { + val length = JavaCode.variable(ctx.freshName("length"), IntegerType) + val javaElementType = inline"${CodeGenerator.javaType(elementType)}" val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) val initialization = if (isPrimitiveType) { - s"$childName.copy()" + code"$childName.copy()" } else { - s"new ${classOf[GenericArrayData].getName()}(new Object[$length])" + val arrayDataType = inline"${classOf[GenericArrayData].getName()}" + code"new $arrayDataType(new Object[$length])" } - val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length + val numberOfIterations = if (isPrimitiveType) { + JavaCode.expression(s"$length / 2", IntegerType) + } else { + length + } val swapAssigments = if (isPrimitiveType) { - val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType) + val setFunc = inline"${"set" + CodeGenerator.primitiveTypeName(elementType)}" val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index) - s"""|boolean isNullAtK = ${ev.value}.isNullAt(k); + code"""|boolean isNullAtK = ${ev.value}.isNullAt(k); |boolean isNullAtL = ${ev.value}.isNullAt(l); |if(!isNullAtK) { | $javaElementType el = ${getCall("k")}; @@ -627,10 +639,10 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI | ${ev.value}.setNullAt(l); |}""".stripMargin } else { - s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});" + code"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});" } - s""" + code""" |final int $length = $childName.numElements(); |${ev.value} = $initialization; |for(int k = 0; k < $numberOfIterations; k++) { @@ -703,9 +715,9 @@ case class ArrayContains(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (arr, value) => { - val i = ctx.freshName("i") + val i = JavaCode.variable(ctx.freshName("i"), IntegerType) val getValue = CodeGenerator.getValue(arr, right.dataType, i) - s""" + code""" for (int $i = 0; $i < $arr.numElements(); $i ++) { if ($arr.isNullAt($i)) { ${ev.isNull} = true; @@ -832,14 +844,14 @@ case class ArraysOverlap(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (a1, a2) => { - val smaller = ctx.freshName("smallerArray") - val bigger = ctx.freshName("biggerArray") + val smaller = JavaCode.variable(ctx.freshName("smallerArray"), classOf[ArrayData]) + val bigger = JavaCode.variable(ctx.freshName("biggerArray"), classOf[ArrayData]) val comparisonCode = if (elementTypeSupportEquals) { fastCodegen(ctx, ev, smaller, bigger) } else { bruteForceCodegen(ctx, ev, smaller, bigger) } - s""" + code""" |ArrayData $smaller; |ArrayData $bigger; |if ($a1.numElements() > $a2.numElements()) { @@ -861,27 +873,28 @@ case class ArraysOverlap(left: Expression, right: Expression) * in a set and then performs a lookup on it for each element of the bigger one. * It works only for data types which implements properly the equals method. */ - private def fastCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { - val i = ctx.freshName("i") + private def fastCodegen(ctx: CodegenContext, ev: ExprCode, smaller: ExprValue, + bigger: ExprValue): Block = { + val i = JavaCode.variable(ctx.freshName("i"), IntegerType) val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i) val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) - val javaElementClass = CodeGenerator.boxedType(elementType) - val javaSet = classOf[java.util.HashSet[_]].getName - val set = ctx.freshName("set") + val javaElementClass = inline"${CodeGenerator.boxedType(elementType)}" + val javaSet = inline"${classOf[java.util.HashSet[_]].getName}" + val set = JavaCode.variable(ctx.freshName("set"), classOf[java.util.HashSet[_]]) val addToSetFromSmallerCode = nullSafeElementCodegen( - smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;") + smaller, i, code"$set.add($getFromSmaller);", code"${ev.isNull} = true;") val elementIsInSetCode = nullSafeElementCodegen( bigger, i, - s""" + code""" |if ($set.contains($getFromBigger)) { | ${ev.isNull} = false; | ${ev.value} = true; | break; |} """.stripMargin, - s"${ev.isNull} = true;") - s""" + code"${ev.isNull} = true;") + code""" |$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>(); |for (int $i = 0; $i < $smaller.numElements(); $i ++) { | $addToSetFromSmallerCode @@ -895,31 +908,32 @@ case class ArraysOverlap(left: Expression, right: Expression) /** * Code generation for a slower evaluation which performs a nested loop and supports all the data types. */ - private def bruteForceCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { - val i = ctx.freshName("i") - val j = ctx.freshName("j") + private def bruteForceCodegen(ctx: CodegenContext, ev: ExprCode, smaller: ExprValue, + bigger: ExprValue): Block = { + val i = JavaCode.variable(ctx.freshName("i"), IntegerType) + val j = JavaCode.variable(ctx.freshName("j"), IntegerType) val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j) val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) val compareValues = nullSafeElementCodegen( smaller, j, - s""" + code""" |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) { | ${ev.isNull} = false; | ${ev.value} = true; |} """.stripMargin, - s"${ev.isNull} = true;") + code"${ev.isNull} = true;") val isInSmaller = nullSafeElementCodegen( bigger, i, - s""" + code""" |for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) { | $compareValues |} """.stripMargin, - s"${ev.isNull} = true;") - s""" + code"${ev.isNull} = true;") + code""" |for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) { | $isInSmaller |} @@ -927,12 +941,12 @@ case class ArraysOverlap(left: Expression, right: Expression) } def nullSafeElementCodegen( - arrayVar: String, - index: String, - code: String, - isNullCode: String): String = { + arrayVar: ExprValue, + index: ExprValue, + code: Block, + isNullCode: Block): Block = { if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) { - s""" + code""" |if ($arrayVar.isNullAt($index)) { | $isNullCode |} else { @@ -999,14 +1013,16 @@ case class Slice(x: Expression, start: Expression, length: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (x, start, length) => { - val startIdx = ctx.freshName("startIdx") - val resLength = ctx.freshName("resLength") + val startIdx = JavaCode.variable(ctx.freshName("startIdx"), IntegerType) + val resLength = JavaCode.variable(ctx.freshName("resLength"), IntegerType) + val intType = inline"${CodeGenerator.JAVA_INT}" val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) - s""" - |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; - |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + val nameOfFunction = inline"${prettyName}" + code""" + |$intType $startIdx = $defaultIntValue; + |$intType $resLength = $defaultIntValue; |if ($start == 0) { - | throw new RuntimeException("Unexpected value for start in function $prettyName: " + | throw new RuntimeException("Unexpected value for start in function $nameOfFunction: " | + "SQL array indices start at 1."); |} else if ($start < 0) { | $startIdx = $start + $x.numElements(); @@ -1015,7 +1031,7 @@ case class Slice(x: Expression, start: Expression, length: Expression) | $startIdx = $start - 1; |} |if ($length < 0) { - | throw new RuntimeException("Unexpected value for length in function $prettyName: " + | throw new RuntimeException("Unexpected value for length in function $nameOfFunction: " | + "length must be greater than or equal to 0."); |} else if ($length > $x.numElements() - $startIdx) { | $resLength = $x.numElements() - $startIdx; @@ -1030,15 +1046,15 @@ case class Slice(x: Expression, start: Expression, length: Expression) def genCodeForResult( ctx: CodegenContext, ev: ExprCode, - inputArray: String, - startIdx: String, - resLength: String): String = { - val values = ctx.freshName("values") - val i = ctx.freshName("i") + inputArray: ExprValue, + startIdx: ExprValue, + resLength: ExprValue): Block = { + val values = JavaCode.variable(ctx.freshName("values"), classOf[Array[Object]]) + val i = JavaCode.variable(ctx.freshName("i"), IntegerType) val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") if (!CodeGenerator.isPrimitiveType(elementType)) { - val arrayClass = classOf[GenericArrayData].getName - s""" + val arrayClass = inline"${classOf[GenericArrayData].getName}" + code""" |Object[] $values; |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { | $values = new Object[0]; @@ -1051,12 +1067,14 @@ case class Slice(x: Expression, start: Expression, length: Expression) |${ev.value} = new $arrayClass($values); """.stripMargin } else { - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - s""" + val primitiveValueTypeName = inline"${CodeGenerator.primitiveTypeName(elementType)}" + val createUnsafeArray = ctx.createUnsafeArray(values, resLength, elementType, + JavaCode.literal(s" $prettyName failed.", StringType)) + code""" |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { | $resLength = 0; |} - |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")} + |$createUnsafeArray |for (int $i = 0; $i < $resLength; $i ++) { | if ($inputArray.isNullAt($i + $startIdx)) { | $values.setNullAt($i); @@ -1153,8 +1171,8 @@ case class ArrayJoin( val code = nullReplacement match { case Some(replacement) => val replacementGen = replacement.genCode(ctx) - val nullHandling = (buffer: String, delimiter: String, firstItem: String) => { - s""" + val nullHandling = (buffer: ExprValue, delimiter: ExprValue, firstItem: ExprValue) => { + code""" |if (!$firstItem) { | $buffer.append($delimiter); |} @@ -1169,12 +1187,12 @@ case class ArrayJoin( } else { genCodeForArrayAndDelimiter(ctx, ev, nullHandling) } - s""" + code""" |${replacementGen.code} |$execCode """.stripMargin case None => genCodeForArrayAndDelimiter(ctx, ev, - (_: String, _: String, _: String) => "// nulls are ignored") + (_: ExprValue, _: ExprValue, _: ExprValue) => code"// nulls are ignored") } if (nullable) { ev.copy( @@ -1195,15 +1213,15 @@ case class ArrayJoin( private def genCodeForArrayAndDelimiter( ctx: CodegenContext, ev: ExprCode, - nullEval: (String, String, String) => String): String = { + nullEval: (ExprValue, ExprValue, ExprValue) => Block): Block = { val arrayGen = array.genCode(ctx) val delimiterGen = delimiter.genCode(ctx) - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName - val i = ctx.freshName("i") - val firstItem = ctx.freshName("firstItem") + val buffer = JavaCode.variable(ctx.freshName("buffer"), classOf[UTF8StringBuilder]) + val bufferClass = inline"${classOf[UTF8StringBuilder].getName}" + val i = JavaCode.variable(ctx.freshName("i"), IntegerType) + val firstItem = JavaCode.variable(ctx.freshName("firstItem"), BooleanType) val resultCode = - s""" + code""" |$bufferClass $buffer = new $bufferClass(); |boolean $firstItem = true; |for (int $i = 0; $i < ${arrayGen.value}.numElements(); $i ++) { @@ -1222,13 +1240,13 @@ case class ArrayJoin( if (array.nullable || delimiter.nullable) { arrayGen.code + ctx.nullSafeExec(array.nullable, arrayGen.isNull) { delimiterGen.code + ctx.nullSafeExec(delimiter.nullable, delimiterGen.isNull) { - s""" + code""" |${ev.isNull} = false; |$resultCode""".stripMargin } } } else { - s""" + code""" |${arrayGen.code} |${delimiterGen.code} |$resultCode""".stripMargin @@ -1268,8 +1286,8 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) - val javaType = CodeGenerator.javaType(dataType) - val i = ctx.freshName("i") + val javaType = inline"${CodeGenerator.javaType(dataType)}" + val i = JavaCode.variable(ctx.freshName("i"), IntegerType) val item = ExprCode(EmptyBlock, isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) @@ -1333,8 +1351,8 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) - val javaType = CodeGenerator.javaType(dataType) - val i = ctx.freshName("i") + val javaType = inline"${CodeGenerator.javaType(dataType)}" + val i = JavaCode.variable(ctx.freshName("i"), IntegerType) val item = ExprCode(EmptyBlock, isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) @@ -1419,10 +1437,10 @@ case class ArrayPosition(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (arr, value) => { - val pos = ctx.freshName("arrayPosition") - val i = ctx.freshName("i") + val pos = JavaCode.variable(ctx.freshName("arrayPosition"), IntegerType) + val i = JavaCode.variable(ctx.freshName("i"), IntegerType) val getValue = CodeGenerator.getValue(arr, right.dataType, i) - s""" + code""" |int $pos = 0; |for (int $i = 0; $i < $arr.numElements(); $i ++) { | if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) { @@ -1517,17 +1535,17 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti left.dataType match { case _: ArrayType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - val index = ctx.freshName("elementAtIndex") + val index = JavaCode.variable(ctx.freshName("elementAtIndex"), IntegerType) val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { - s""" + code""" |if ($eval1.isNullAt($index)) { | ${ev.isNull} = true; |} else """.stripMargin } else { - "" + EmptyBlock } - s""" + code""" |int $index = (int) $eval2; |if ($eval1.numElements() < Math.abs($index)) { | ${ev.isNull} = true; @@ -1627,10 +1645,15 @@ case class Concat(children: Seq[Expression]) extends Expression { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evals = children.map(_.genCode(ctx)) - val args = ctx.freshName("args") + val argsType = dataType match { + case BinaryType => classOf[Array[Array[java.lang.Byte]]] + case StringType => classOf[Array[UTF8String]] + case ArrayType(elementType, _) => classOf[Array[ArrayData]] + } + val args = JavaCode.variable(ctx.freshName("args"), argsType) val inputs = evals.zipWithIndex.map { case (eval, index) => - s""" + code""" ${eval.code} if (!${eval.isNull}) { $args[$index] = ${eval.value}; @@ -1640,21 +1663,24 @@ case class Concat(children: Seq[Expression]) extends Expression { val (concatenator, initCode) = dataType match { case BinaryType => - (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + (inline"${classOf[ByteArray].getName}", + code"byte[][] $args = new byte[${evals.length}][];") case StringType => - ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + (code"UTF8String", code"UTF8String[] $args = new UTF8String[${evals.length}];") case ArrayType(elementType, _) => val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { genCodeForPrimitiveArrays(ctx, elementType) } else { genCodeForNonPrimitiveArrays(ctx, elementType) } - (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") + (arrayConcatClass, + code"ArrayData[] $args = new ArrayData[${evals.length}];") } val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcat", - extraArguments = (s"$javaType[]", args) :: Nil) + extraArguments = args :: Nil) + val javaType = inline"${CodeGenerator.javaType(dataType)}" ev.copy(code""" $initCode $codes @@ -1663,9 +1689,9 @@ case class Concat(children: Seq[Expression]) extends Expression { """) } - private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { - val numElements = ctx.freshName("numElements") - val code = s""" + private def genCodeForNumberOfElements(ctx: CodegenContext) : (Block, ExprValue) = { + val numElements = JavaCode.variable(ctx.freshName("numElements"), LongType) + val code = code""" |long $numElements = 0L; |for (int z = 0; z < ${children.length}; z++) { | $numElements += args[z].numElements(); @@ -1679,41 +1705,45 @@ case class Concat(children: Seq[Expression]) extends Expression { (code, numElements) } - private def nullArgumentProtection() : String = { + private def nullArgumentProtection(): Block = { if (nullable) { - s""" + code""" |for (int z = 0; z < ${children.length}; z++) { | if (args[z] == null) return null; |} """.stripMargin } else { - "" + EmptyBlock } } - private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val counter = ctx.freshName("counter") - val arrayData = ctx.freshName("arrayData") + private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): Block = { + val arrVariable = JavaCode.variable("args[y]", dataType) + val z = JavaCode.variable("z", IntegerType) + val counter = JavaCode.variable(ctx.freshName("counter"), IntegerType) + val arrayData = JavaCode.variable(ctx.freshName("arrayData"), classOf[UnsafeArrayData]) val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val primitiveValueTypeName = inline"${CodeGenerator.primitiveTypeName(elementType)}" + val errorMessage = JavaCode.literal(s" $prettyName failed.", StringType) + val javaTypeStr = inline"${javaType}" - s""" + code""" |new Object() { - | public ArrayData concat($javaType[] args) { + | public ArrayData concat($javaTypeStr[] args) { | ${nullArgumentProtection()} | $numElemCode - | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, errorMessage)} | int $counter = 0; | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | if (args[y].isNullAt(z)) { + | for (int $z = 0; $z < $arrVariable.numElements(); $z++) { + | if ($arrVariable.isNullAt($z)) { | $arrayData.setNullAt($counter); | } else { | $arrayData.set$primitiveValueTypeName( | $counter, - | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} + | ${CodeGenerator.getValue(arrVariable, elementType, z)} | ); | } | $counter++; @@ -1721,32 +1751,34 @@ case class Concat(children: Seq[Expression]) extends Expression { | } | return $arrayData; | } - |}""".stripMargin.stripPrefix("\n") + |}""".stripMargin } - private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayData = ctx.freshName("arrayObjects") - val counter = ctx.freshName("counter") + private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): Block = { + val genericArrayClass = inline"${classOf[GenericArrayData].getName}" + val arrayData = JavaCode.variable(ctx.freshName("arrayObjects"), classOf[Array[Object]]) + val counter = JavaCode.variable(ctx.freshName("counter"), IntegerType) + val arg = JavaCode.variable("args[y]", dataType) + val z = JavaCode.variable("z", IntegerType) val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - s""" + code""" |new Object() { - | public ArrayData concat($javaType[] args) { + | public ArrayData concat(${inline"$javaType"}[] args) { | ${nullArgumentProtection()} | $numElemCode | Object[] $arrayData = new Object[(int)$numElemName]; | int $counter = 0; | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; + | for (int $z = 0; $z < args[y].numElements(); $z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(arg, elementType, z)}; | $counter++; | } | } | return new $genericArrayClass($arrayData); | } - |}""".stripMargin.stripPrefix("\n") + |}""".stripMargin } override def toString: String = s"concat(${children.mkString(", ")})" @@ -1823,9 +1855,9 @@ case class Flatten(child: Expression) extends UnaryExpression { private def nullElementsProtection( ev: ExprCode, - childVariableName: String, - coreLogic: String): String = { - s""" + childVariableName: ExprValue, + coreLogic: Block): Block = { + code""" |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { | ${ev.isNull} |= $childVariableName.isNullAt(z); |} @@ -1837,9 +1869,9 @@ case class Flatten(child: Expression) extends UnaryExpression { private def genCodeForNumberOfElements( ctx: CodegenContext, - childVariableName: String) : (String, String) = { - val variableName = ctx.freshName("numElements") - val code = s""" + childVariableName: ExprValue) : (Block, ExprValue) = { + val variableName = JavaCode.variable(ctx.freshName("numElements"), LongType) + val code = code""" |long $variableName = 0; |for (int z = 0; z < $childVariableName.numElements(); z++) { | $variableName += $childVariableName.getArray(z).numElements(); @@ -1854,28 +1886,33 @@ case class Flatten(child: Expression) extends UnaryExpression { private def genCodeForFlattenOfPrimitiveElements( ctx: CodegenContext, - childVariableName: String, - arrayDataName: String): String = { - val counter = ctx.freshName("counter") - val tempArrayDataName = ctx.freshName("tempArrayData") + childVariableName: ExprValue, + arrayDataName: ExprValue): Block = { + val arrVariable = JavaCode.variable("arr", classOf[ArrayData]) + val loopCounter = JavaCode.variable("l", IntegerType) + val counter = JavaCode.variable(ctx.freshName("counter"), IntegerType) + val tempArrayDataName = JavaCode.variable(ctx.freshName("tempArrayData"), + classOf[UnsafeArrayData]) val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val primitiveValueTypeName = inline"${CodeGenerator.primitiveTypeName(elementType)}" - s""" + val createUnsafeArray = ctx.createUnsafeArray(tempArrayDataName, numElemName, + elementType, JavaCode.literal(s" $prettyName failed.", StringType)) + code""" |$numElemCode - |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")} + |$createUnsafeArray |int $counter = 0; |for (int k = 0; k < $childVariableName.numElements(); k++) { - | ArrayData arr = $childVariableName.getArray(k); - | for (int l = 0; l < arr.numElements(); l++) { - | if (arr.isNullAt(l)) { + | ArrayData $arrVariable = $childVariableName.getArray(k); + | for (int $loopCounter = 0; $loopCounter < $arrVariable.numElements(); $loopCounter++) { + | if ($arrVariable.isNullAt(l)) { | $tempArrayDataName.setNullAt($counter); | } else { | $tempArrayDataName.set$primitiveValueTypeName( | $counter, - | ${CodeGenerator.getValue("arr", elementType, "l")} + | ${CodeGenerator.getValue(arrVariable, elementType, loopCounter)} | ); | } | $counter++; @@ -1887,21 +1924,23 @@ case class Flatten(child: Expression) extends UnaryExpression { private def genCodeForFlattenOfNonPrimitiveElements( ctx: CodegenContext, - childVariableName: String, - arrayDataName: String): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayName = ctx.freshName("arrayObject") - val counter = ctx.freshName("counter") + childVariableName: ExprValue, + arrayDataName: ExprValue): Block = { + val genericArrayClass = inline"${classOf[GenericArrayData].getName}" + val arrVariable = JavaCode.variable("arr", classOf[ArrayData]) + val arrayName = JavaCode.variable(ctx.freshName("arrayObject"), classOf[Array[Object]]) + val counter = JavaCode.variable(ctx.freshName("counter"), IntegerType) + val loopCounter = JavaCode.variable("l", IntegerType) val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) - s""" + code""" |$numElemCode |Object[] $arrayName = new Object[(int)$numElemName]; |int $counter = 0; |for (int k = 0; k < $childVariableName.numElements(); k++) { - | ArrayData arr = $childVariableName.getArray(k); - | for (int l = 0; l < arr.numElements(); l++) { - | $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")}; + | ArrayData $arrVariable = $childVariableName.getArray(k); + | for (int $loopCounter = 0; $loopCounter < $arrVariable.numElements(); $loopCounter++) { + | $arrayName[$counter] = ${CodeGenerator.getValue(arrVariable, elementType, loopCounter)}; | $counter++; | } |} @@ -1969,7 +2008,7 @@ case class ArrayRepeat(left: Expression, right: Expression) |boolean ${ev.isNull} = false; |${leftGen.code} |${rightGen.code} - |${CodeGenerator.javaType(dataType)} ${ev.value} = + |${inline"${CodeGenerator.javaType(dataType)}"} ${ev.value} = | ${CodeGenerator.defaultValue(dataType)}; |$resultCode """.stripMargin) @@ -1977,10 +2016,10 @@ case class ArrayRepeat(left: Expression, right: Expression) private def nullElementsProtection( ev: ExprCode, - rightIsNull: String, - coreLogic: String): String = { + rightIsNull: ExprValue, + coreLogic: Block): Block = { if (nullable) { - s""" + code""" |if ($rightIsNull) { | ${ev.isNull} = true; |} else { @@ -1992,10 +2031,11 @@ case class ArrayRepeat(left: Expression, right: Expression) } } - private def genCodeForNumberOfElements(ctx: CodegenContext, count: String): (String, String) = { - val numElements = ctx.freshName("numElements") + private def genCodeForNumberOfElements(ctx: CodegenContext, + count: ExprValue): (ExprValue, Block) = { + val numElements = JavaCode.variable(ctx.freshName("numElements"), IntegerType) val numElementsCode = - s""" + code""" |int $numElements = 0; |if ($count > 0) { | $numElements = $count; @@ -2012,18 +2052,20 @@ case class ArrayRepeat(left: Expression, right: Expression) private def genCodeForPrimitiveElement( ctx: CodegenContext, elementType: DataType, - element: String, - count: String, - leftIsNull: String, - arrayDataName: String): String = { - val tempArrayDataName = ctx.freshName("tempArrayData") - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - val errorMessage = s" $prettyName failed." + element: ExprValue, + count: ExprValue, + leftIsNull: ExprValue, + arrayDataName: ExprValue): Block = { + val tempArrayDataName = JavaCode.variable(ctx.freshName("tempArrayData"), + classOf[UnsafeArrayData]) + val primitiveValueTypeName = inline"${CodeGenerator.primitiveTypeName(elementType)}" + val errorMessage = JavaCode.literal(s" $prettyName failed.", StringType) val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) - - s""" + val createUnsafeArray = ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, + errorMessage) + code""" |$numElemCode - |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)} + |$createUnsafeArray |if (!$leftIsNull) { | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { | $tempArrayDataName.set$primitiveValueTypeName(k, $element); @@ -2039,15 +2081,15 @@ case class ArrayRepeat(left: Expression, right: Expression) private def genCodeForNonPrimitiveElement( ctx: CodegenContext, - element: String, - count: String, - leftIsNull: String, - arrayDataName: String): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayName = ctx.freshName("arrayObject") + element: ExprValue, + count: ExprValue, + leftIsNull: ExprValue, + arrayDataName: ExprValue): Block = { + val genericArrayClass = inline"${classOf[GenericArrayData].getName}" + val arrayName = JavaCode.variable(ctx.freshName("arrayObject"), classOf[Array[Object]]) val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) - s""" + code""" |$numElemCode |Object[] $arrayName = new Object[(int)$numElemName]; |if (!$leftIsNull) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index a9867aaeb0cf..f17430b85524 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods @@ -87,21 +87,21 @@ private [sql] object GenArrayData { ctx: CodegenContext, elementType: DataType, elementsCode: Seq[ExprCode], - isMapKey: Boolean): (String, String, String, String) = { - val arrayDataName = ctx.freshName("arrayData") + isMapKey: Boolean): (Block, Block, Block, ExprValue) = { val numElements = elementsCode.length if (!CodeGenerator.isPrimitiveType(elementType)) { - val arrayName = ctx.freshName("arrayObject") - val genericArrayClass = classOf[GenericArrayData].getName + val arrayDataName = JavaCode.variable(ctx.freshName("arrayData"), classOf[ArrayData]) + val arrayName = JavaCode.variable(ctx.freshName("arrayObject"), classOf[Array[Object]]) + val genericArrayClass = inline"${classOf[GenericArrayData].getName}" val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { - s"$arrayName[$i] = null;" + code"$arrayName[$i] = null;" } else { - "throw new RuntimeException(\"Cannot use null as map key!\");" + code"""throw new RuntimeException("Cannot use null as map key!");""" } - eval.code + s""" + eval.code + code""" if (${eval.isNull}) { $isNullAssignment } else { @@ -112,27 +112,28 @@ private [sql] object GenArrayData { val assignmentString = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, funcName = "apply", - extraArguments = ("Object[]", arrayName) :: Nil) + extraArguments = (arrayName) :: Nil) - (s"Object[] $arrayName = new Object[$numElements];", + (code"Object[] $arrayName = new Object[$numElements];", assignmentString, - s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);", + code"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);", arrayDataName) } else { - val arrayName = ctx.freshName("array") + val arrayDataName = JavaCode.variable(ctx.freshName("arrayData"), classOf[UnsafeArrayData]) + val arrayName = JavaCode.variable(ctx.freshName("array"), BinaryType) val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val primitiveValueTypeName = inline"${CodeGenerator.primitiveTypeName(elementType)}" val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { - s"$arrayDataName.setNullAt($i);" + code"$arrayDataName.setNullAt($i);" } else { - "throw new RuntimeException(\"Cannot use null as map key!\");" + code"""throw new RuntimeException("Cannot use null as map key!");""" } - eval.code + s""" + eval.code + code""" if (${eval.isNull}) { $isNullAssignment } else { @@ -143,16 +144,16 @@ private [sql] object GenArrayData { val assignmentString = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, funcName = "apply", - extraArguments = ("UnsafeArrayData", arrayDataName) :: Nil) + extraArguments = (arrayDataName) :: Nil) - (s""" + (code""" byte[] $arrayName = new byte[$unsafeArraySizeInBytes]; UnsafeArrayData $arrayDataName = new UnsafeArrayData(); Platform.putLong($arrayName, $baseOffset, $numElements); $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes); """, assignmentString, - "", + EmptyBlock, arrayDataName) } } @@ -211,7 +212,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val mapClass = classOf[ArrayBasedMapData].getName + val mapClass = inline"${classOf[ArrayBasedMapData].getName}" val MapType(keyDt, valueDt, _) = dataType val evalKeys = keys.map(e => e.genCode(ctx)) val evalValues = values.map(e => e.genCode(ctx)) @@ -355,11 +356,11 @@ trait CreateNamedStructLike extends Expression { case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val rowClass = classOf[GenericInternalRow].getName - val values = ctx.freshName("values") + val rowClass = inline"${classOf[GenericInternalRow].getName}" + val values = JavaCode.variable(ctx.freshName("values"), classOf[Array[Object]]) val valCodes = valExprs.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) - s""" + code""" |${eval.code} |if (${eval.isNull}) { | $values[$i] = null; @@ -371,7 +372,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc val valuesCode = ctx.splitExpressionsWithCurrentInputs( expressions = valCodes, funcName = "createNamedStruct", - extraArguments = "Object[]" -> values :: Nil) + extraArguments = values :: Nil) ev.copy(code = code""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 99671d5b863c..085b4f725e32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, EmptyBlock, ExprCode, JavaCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ @@ -125,7 +126,7 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { if (nullable) { - s""" + code""" if ($eval.isNullAt($ordinal)) { ${ev.isNull} = true; } else { @@ -133,7 +134,7 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] } """ } else { - s""" + code""" ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)}; """ } @@ -180,23 +181,23 @@ case class GetArrayStructFields( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val arrayClass = classOf[GenericArrayData].getName + val arrayClass = inline"${classOf[GenericArrayData].getName}" nullSafeCodeGen(ctx, ev, eval => { - val n = ctx.freshName("n") - val values = ctx.freshName("values") - val j = ctx.freshName("j") - val row = ctx.freshName("row") + val n = JavaCode.variable(ctx.freshName("n"), IntegerType) + val values = JavaCode.variable(ctx.freshName("values"), classOf[Array[Object]]) + val j = JavaCode.variable(ctx.freshName("j"), IntegerType) + val row = JavaCode.variable(ctx.freshName("row"), classOf[InternalRow]) val nullSafeEval = if (field.nullable) { - s""" + code""" if ($row.isNullAt($ordinal)) { $values[$j] = null; } else """ } else { - "" + EmptyBlock } - s""" + code""" final int $n = $eval.numElements(); final Object[] $values = new Object[$n]; for (int $j = 0; $j < $n; $j++) { @@ -249,13 +250,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - val index = ctx.freshName("index") + val index = JavaCode.variable(ctx.freshName("index"), IntegerType) val nullCheck = if (child.dataType.asInstanceOf[ArrayType].containsNull) { - s" || $eval1.isNullAt($index)" + code" || $eval1.isNullAt($index)" } else { - "" + EmptyBlock } - s""" + code""" final int $index = (int) $eval2; if ($index >= $eval1.numElements() || $index < 0$nullCheck) { ${ev.isNull} = true; @@ -297,21 +298,21 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy } def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = { - val index = ctx.freshName("index") - val length = ctx.freshName("length") - val keys = ctx.freshName("keys") - val found = ctx.freshName("found") - val key = ctx.freshName("key") - val values = ctx.freshName("values") + val index = JavaCode.variable(ctx.freshName("index"), IntegerType) + val length = JavaCode.variable(ctx.freshName("length"), IntegerType) + val keys = JavaCode.variable(ctx.freshName("keys"), classOf[ArrayData]) + val found = JavaCode.variable(ctx.freshName("found"), BooleanType) val keyType = mapType.keyType + val key = JavaCode.variable(ctx.freshName("key"), keyType) + val values = JavaCode.variable(ctx.freshName("values"), classOf[ArrayData]) val nullCheck = if (mapType.valueContainsNull) { - s" || $values.isNullAt($index)" + code" || $values.isNullAt($index)" } else { - "" + EmptyBlock } - val keyJavaType = CodeGenerator.javaType(keyType) + val keyJavaType = inline"${CodeGenerator.javaType(keyType)}" nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s""" + code""" final int $length = $eval1.numElements(); final ArrayData $keys = $eval1.keyArray(); final ArrayData $values = $eval1.valueArray(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 77ac6c088022..40941b72cf0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -65,12 +65,12 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val condEval = predicate.genCode(ctx) val trueEval = trueValue.genCode(ctx) val falseEval = falseValue.genCode(ctx) - + val javaType = inline"${CodeGenerator.javaType(dataType)}" val code = code""" |${condEval.code} |boolean ${ev.isNull} = false; - |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${condEval.isNull} && ${condEval.value}) { | ${trueEval.code} | ${ev.isNull} = ${trueEval.isNull}; @@ -191,7 +191,7 @@ case class CaseWhen( val HAS_NULL = 1 // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. - val resultState = ctx.freshName("caseWhenResultState") + val resultState = JavaCode.variable(ctx.freshName("caseWhenResultState"), ByteType) ev.value = JavaCode.global( ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value), dataType) @@ -204,7 +204,7 @@ case class CaseWhen( val cases = branches.map { case (condExpr, valueExpr) => val cond = condExpr.genCode(ctx) val res = valueExpr.genCode(ctx) - s""" + code""" |${cond.code} |if (!${cond.isNull} && ${cond.value}) { | ${res.code} @@ -217,7 +217,7 @@ case class CaseWhen( val elseCode = elseValue.map { elseExpr => val res = elseExpr.genCode(ctx) - s""" + code""" |${res.code} |$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); |${ev.value} = ${res.value}; @@ -256,18 +256,20 @@ case class CaseWhen( |} while (false); |return $resultState; """.stripMargin, - foldFunctions = _.map { funcCall => - s""" - |$resultState = $funcCall; - |if ($resultState != $NOT_MATCHED) { - | continue; - |} - """.stripMargin - }.mkString) + foldFunctions = funcCalls => + Blocks(funcCalls.map { funcCall => + code""" + |$resultState = $funcCall; + |if ($resultState != $NOT_MATCHED) { + | continue; + |} + """ + }) + ) ev.copy(code = code""" - |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; + |${inline"${CodeGenerator.JAVA_BYTE}"} $resultState = $NOT_MATCHED; |do { | $codes |} while (false); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index e8d85f72f7a7..aeb8199977c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -164,7 +164,7 @@ case class DateAdd(startDate: Expression, days: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (sd, d) => { - s"""${ev.value} = $sd + $d;""" + code"""${ev.value} = $sd + $d;""" }) } @@ -197,7 +197,7 @@ case class DateSub(startDate: Expression, days: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (sd, d) => { - s"""${ev.value} = $sd - $d;""" + code"""${ev.value} = $sd - $d;""" }) } @@ -229,9 +229,9 @@ case class Hour(child: Expression, timeZoneId: Option[String] = None) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceObj("timeZone", timeZone) - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getHours($c, $tz)") + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), classOf[TimeZone]) + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" + defineCodeGen(ctx, ev, c => code"$dtu.getHours($c, $tz)") } } @@ -260,9 +260,9 @@ case class Minute(child: Expression, timeZoneId: Option[String] = None) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceObj("timeZone", timeZone) - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c, $tz)") + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), classOf[TimeZone]) + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" + defineCodeGen(ctx, ev, c => code"$dtu.getMinutes($c, $tz)") } } @@ -291,9 +291,9 @@ case class Second(child: Expression, timeZoneId: Option[String] = None) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceObj("timeZone", timeZone) - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c, $tz)") + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), classOf[TimeZone]) + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" + defineCodeGen(ctx, ev, c => code"$dtu.getSeconds($c, $tz)") } } @@ -316,8 +316,8 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getDayInYear($c)") + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" + defineCodeGen(ctx, ev, c => code"$dtu.getDayInYear($c)") } } @@ -340,8 +340,8 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getYear($c)") + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" + defineCodeGen(ctx, ev, c => code"$dtu.getYear($c)") } } @@ -364,8 +364,8 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getQuarter($c)") + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" + defineCodeGen(ctx, ev, c => code"$dtu.getQuarter($c)") } } @@ -388,8 +388,8 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getMonth($c)") + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" + defineCodeGen(ctx, ev, c => code"$dtu.getMonth($c)") } } @@ -412,8 +412,8 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getDayOfMonth($c)") + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" + defineCodeGen(ctx, ev, c => code"$dtu.getDayOfMonth($c)") } } @@ -436,12 +436,12 @@ case class DayOfWeek(child: Expression) extends DayWeek { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { - val cal = classOf[Calendar].getName + val cal = inline"${classOf[Calendar].getName}" val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val c = "calDayOfWeek" - ctx.addImmutableStateIfNotExists(cal, c, + val c = JavaCode.global("calDayOfWeek", classOf[Calendar]) + ctx.addImmutableStateIfNotExists(cal.code, c, v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") - s""" + code""" $c.setTimeInMillis($time * 1000L * 3600L * 24L); ${ev.value} = $c.get($cal.DAY_OF_WEEK); """ @@ -468,12 +468,12 @@ case class WeekDay(child: Expression) extends DayWeek { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { - val cal = classOf[Calendar].getName + val cal = inline"${classOf[Calendar].getName}" val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val c = "calWeekDay" - ctx.addImmutableStateIfNotExists(cal, c, + val c = JavaCode.global("calWeekDay", classOf[Calendar]) + ctx.addImmutableStateIfNotExists(cal.code, c, v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") - s""" + code""" $c.setTimeInMillis($time * 1000L * 3600L * 24L); ${ev.value} = ($c.get($cal.DAY_OF_WEEK) + 5) % 7; """ @@ -522,16 +522,16 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { - val cal = classOf[Calendar].getName - val c = "calWeekOfYear" - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addImmutableStateIfNotExists(cal, c, v => + val cal = inline"${classOf[Calendar].getName}" + val c = JavaCode.global("calWeekOfYear", classOf[Calendar]) + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" + ctx.addImmutableStateIfNotExists(cal.code, c, v => s""" |$v = $cal.getInstance($dtu.getTimeZone("UTC")); |$v.setFirstDayOfWeek($cal.MONDAY); |$v.setMinimalDaysInFirstWeek(4); """.stripMargin) - s""" + code""" |$c.setTimeInMillis($time * 1000L * 3600L * 24L); |${ev.value} = $c.get($cal.WEEK_OF_YEAR); """.stripMargin @@ -567,10 +567,10 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val tz = ctx.addReferenceObj("timeZone", timeZone) + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), classOf[TimeZone]) defineCodeGen(ctx, ev, (timestamp, format) => { - s"""UTF8String.fromString($dtu.newDateFormat($format.toString(), $tz) + code"""UTF8String.fromString($dtu.newDateFormat($format.toString(), $tz) .format(new java.util.Date($timestamp / 1000)))""" }) } @@ -709,14 +709,15 @@ abstract class UnixTime } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = CodeGenerator.javaType(dataType) + val javaType = inline"${CodeGenerator.javaType(dataType)}" left.dataType match { case StringType if right.foldable => val df = classOf[DateFormat].getName if (formatter == null) { ExprCode.forNullValue(dataType) } else { - val formatterName = ctx.addReferenceObj("formatter", formatter, df) + val formatterName = JavaCode.global(ctx.addReferenceObj("formatter", formatter, df), + classOf[DateFormat]) val eval1 = left.genCode(ctx) ev.copy(code = code""" ${eval1.code} @@ -731,10 +732,10 @@ abstract class UnixTime }""") } case StringType => - val tz = ctx.addReferenceObj("timeZone", timeZone) - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), classOf[TimeZone]) + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" nullSafeCodeGen(ctx, ev, (string, format) => { - s""" + code""" try { ${ev.value} = $dtu.newDateFormat($format.toString(), $tz) .parse($string.toString()).getTime() / 1000L; @@ -755,8 +756,8 @@ abstract class UnixTime ${ev.value} = ${eval1.value} / 1000000L; }""") case DateType => - val tz = ctx.addReferenceObj("timeZone", timeZone) - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), classOf[TimeZone]) + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" val eval1 = left.genCode(ctx) ev.copy(code = code""" ${eval1.code} @@ -851,12 +852,14 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ if (formatter == null) { ExprCode.forNullValue(StringType) } else { - val formatterName = ctx.addReferenceObj("formatter", formatter, df) + val formatterName = JavaCode.global(ctx.addReferenceObj("formatter", formatter, df), + classOf[DateFormat]) + val javaType = inline"${CodeGenerator.javaType(dataType)}" val t = left.genCode(ctx) ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { ${ev.value} = UTF8String.fromString($formatterName.format( @@ -867,10 +870,10 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ }""") } } else { - val tz = ctx.addReferenceObj("timeZone", timeZone) - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), classOf[TimeZone]) + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" nullSafeCodeGen(ctx, ev, (seconds, f) => { - s""" + code""" try { ${ev.value} = UTF8String.fromString($dtu.newDateFormat($f.toString(), $tz).format( new java.util.Date($seconds * 1000L))); @@ -905,8 +908,8 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, sd => s"$dtu.getLastDayOfMonth($sd)") + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" + defineCodeGen(ctx, ev, sd => code"$dtu.getLastDayOfMonth($sd)") } override def prettyName: String = "last_day" @@ -952,22 +955,22 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (sd, dowS) => { - val dateTimeUtilClass = DateTimeUtils.getClass.getName.stripSuffix("$") - val dayOfWeekTerm = ctx.freshName("dayOfWeek") + val dateTimeUtilClass = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" + val dayOfWeekTerm = JavaCode.variable(ctx.freshName("dayOfWeek"), IntegerType) if (dayOfWeek.foldable) { val input = dayOfWeek.eval().asInstanceOf[UTF8String] if ((input eq null) || DateTimeUtils.getDayOfWeekFromString(input) == -1) { - s""" + code""" |${ev.isNull} = true; """.stripMargin } else { val dayOfWeekValue = DateTimeUtils.getDayOfWeekFromString(input) - s""" + code""" |${ev.value} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekValue); """.stripMargin } } else { - s""" + code""" |int $dayOfWeekTerm = $dateTimeUtilClass.getDayOfWeekFromString($dowS); |if ($dayOfWeekTerm == -1) { | ${ev.isNull} = true; @@ -1009,10 +1012,10 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceObj("timeZone", timeZone) - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), classOf[TimeZone]) + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" defineCodeGen(ctx, ev, (sd, i) => { - s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds, $tz)""" + code"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds, $tz)""" }) } } @@ -1039,14 +1042,16 @@ case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Optio } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val tz = ctx.addReferenceObj("timeZone", timeZone) - val longOpt = ctx.freshName("longOpt") + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), classOf[TimeZone]) + val longOpt = JavaCode.variable(ctx.freshName("longOpt"), classOf[Option[Long]]) val eval = child.genCode(ctx) + val intType = inline"${CodeGenerator.JAVA_BOOLEAN}" + val longType = inline"${CodeGenerator.JAVA_LONG}" val code = code""" |${eval.code} - |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; - |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; + |$intType ${ev.isNull} = true; + |$longType ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; |if (!${eval.isNull}) { | scala.Option $longOpt = $dtu.stringToTimestamp(${eval.value}, $tz, true); | if ($longOpt.isDefined()) { @@ -1087,7 +1092,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { @@ -1097,11 +1102,11 @@ case class FromUTCTimestamp(left: Expression, right: Expression) """.stripMargin) } else { val tzClass = classOf[TimeZone].getName - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" val escapedTz = StringEscapeUtils.escapeJava(tz.toString) - val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getTimeZone("$escapedTz");""") - val utcTerm = "tzUTC" + val tzTerm = JavaCode.global(ctx.addMutableState(tzClass, "tz", + v => s"""$v = $dtu.getTimeZone("$escapedTz");"""), classOf[TimeZone]) + val utcTerm = JavaCode.global("tzUTC", classOf[TimeZone]) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) @@ -1116,7 +1121,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) } } else { defineCodeGen(ctx, ev, (timestamp, format) => { - s"""$dtu.fromUTCTime($timestamp, $format.toString())""" + code"""$dtu.fromUTCTime($timestamp, $format.toString())""" }) } } @@ -1149,10 +1154,10 @@ case class TimeSub(start: Expression, interval: Expression, timeZoneId: Option[S } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceObj("timeZone", timeZone) - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), classOf[TimeZone]) + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" defineCodeGen(ctx, ev, (sd, i) => { - s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds, $tz)""" + code"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds, $tz)""" }) } } @@ -1185,9 +1190,9 @@ case class AddMonths(startDate: Expression, numMonths: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" defineCodeGen(ctx, ev, (sd, m) => { - s"""$dtu.dateAddMonths($sd, $m)""" + code"""$dtu.dateAddMonths($sd, $m)""" }) } @@ -1246,10 +1251,10 @@ case class MonthsBetween( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceObj("timeZone", timeZone) - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), classOf[TimeZone]) + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" defineCodeGen(ctx, ev, (d1, d2, roundOff) => { - s"""$dtu.monthsBetween($d1, $d2, $roundOff, $tz)""" + code"""$dtu.monthsBetween($d1, $d2, $roundOff, $tz)""" }) } @@ -1284,7 +1289,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { @@ -1294,11 +1299,11 @@ case class ToUTCTimestamp(left: Expression, right: Expression) """.stripMargin) } else { val tzClass = classOf[TimeZone].getName - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" val escapedTz = StringEscapeUtils.escapeJava(tz.toString) - val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getTimeZone("$escapedTz");""") - val utcTerm = "tzUTC" + val tzTerm = JavaCode.global(ctx.addMutableState(tzClass, "tz", + v => s"""$v = $dtu.getTimeZone("$escapedTz");"""), classOf[TimeZone]) + val utcTerm = JavaCode.global("tzUTC", classOf[TimeZone]) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) @@ -1313,7 +1318,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) } } else { defineCodeGen(ctx, ev, (timestamp, format) => { - s"""$dtu.toUTCTime($timestamp, $format.toString())""" + code"""$dtu.toUTCTime($timestamp, $format.toString())""" }) } } @@ -1438,11 +1443,11 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { ev: ExprCode, maxLevel: Int, orderReversed: Boolean = false)( - truncFunc: (String, String) => String) + truncFunc: (ExprValue, ExprValue) => Block) : ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val dtu = inline"${DateTimeUtils.getClass.getName.stripSuffix("$")}" - val javaType = CodeGenerator.javaType(dataType) + val javaType = inline"${CodeGenerator.javaType(dataType)}" if (format.foldable) { if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { ev.copy(code = code""" @@ -1450,7 +1455,7 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { val t = instant.genCode(ctx) - val truncFuncStr = truncFunc(t.value, truncLevel.toString) + val truncFuncStr = truncFunc(t.value, JavaCode.variable(truncLevel.toString, IntegerType)) ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; @@ -1461,14 +1466,14 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { } } else { nullSafeCodeGen(ctx, ev, (left, right) => { - val form = ctx.freshName("form") + val form = JavaCode.variable(ctx.freshName("form"), IntegerType) val (dateVal, fmt) = if (orderReversed) { (right, left) } else { (left, right) } val truncFuncStr = truncFunc(dateVal, form) - s""" + code""" int $form = $dtu.parseTruncLevel($fmt); if ($form == -1 || $form > $maxLevel) { ${ev.isNull} = true; @@ -1516,8 +1521,9 @@ case class TruncDate(date: Expression, format: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (date: String, fmt: String) => - s"truncDate($date, $fmt);" + codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { + (date: ExprValue, fmt: ExprValue) => + code"truncDate($date, $fmt);" } } } @@ -1568,10 +1574,10 @@ case class TruncTimestamp( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceObj("timeZone", timeZone) + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), classOf[TimeZone]) codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_SECOND, true) { - (date: String, fmt: String) => - s"truncTimestamp($date, $fmt, $tz);" + (date: ExprValue, fmt: ExprValue) => + code"truncTimestamp($date, $fmt, $tz);" } } } @@ -1603,6 +1609,6 @@ case class DateDiff(endDate: Expression, startDate: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (end, start) => s"$end - $start") + defineCodeGen(ctx, ev, (end, start) => code"$end - $start") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 04de83343be7..a0fd98d88cfa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode, JavaCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -35,7 +36,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { input.asInstanceOf[Decimal].toUnscaledLong override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") + defineCodeGen(ctx, ev, c => code"$c.toUnscaledLong()") } } @@ -55,7 +56,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { - s""" + code""" ${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale); ${ev.isNull} = ${ev.value} == null; """ @@ -92,8 +93,8 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { - val tmp = ctx.freshName("tmp") - s""" + val tmp = JavaCode.variable(ctx.freshName("tmp"), classOf[Decimal]) + code""" | Decimal $tmp = $eval.clone(); | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) { | ${ev.value} = $tmp; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index b7c52f1d7b40..6402956916e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -200,8 +200,8 @@ case class Stack(children: Seq[Expression]) extends Generator { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Rows - we write these into an array. - val rowData = ctx.addMutableState("InternalRow[]", "rows", - v => s"$v = new InternalRow[$numRows];") + val rowData = JavaCode.global(ctx.addMutableState("InternalRow[]", "rows", + v => s"$v = new InternalRow[$numRows];"), classOf[Array[InternalRow]]) val values = children.tail val dataTypes = values.take(numFields).map(_.dataType) val code = ctx.splitExpressionsWithCurrentInputs(Seq.tabulate(numRows) { row => @@ -210,11 +210,11 @@ case class Stack(children: Seq[Expression]) extends Generator { if (index < values.length) values(index) else Literal(null, dataTypes(col)) } val eval = CreateStruct(fields).genCode(ctx) - s"${eval.code}\n$rowData[$row] = ${eval.value};" + code"${eval.code}\n$rowData[$row] = ${eval.value};" }) // Create the collection. - val wrapperClass = classOf[mutable.WrappedArray[_]].getName + val wrapperClass = inline"${classOf[mutable.WrappedArray[_]].getName}" ev.copy(code = code""" |$code diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index cec00b66f873..57e5545af8b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -62,7 +62,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => - s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") + code"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") } } @@ -119,9 +119,9 @@ case class Sha2(left: Expression, right: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val digestUtils = "org.apache.commons.codec.digest.DigestUtils" + val digestUtils = inline"org.apache.commons.codec.digest.DigestUtils" nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s""" + code""" if ($eval2 == 224) { try { java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224"); @@ -169,7 +169,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => - s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.sha1Hex($c))" + code"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.sha1Hex($c))" ) } } @@ -198,10 +198,10 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val CRC32 = "java.util.zip.CRC32" - val checksum = ctx.freshName("checksum") + val CRC32 = inline"java.util.zip.CRC32" + val checksum = JavaCode.variable(ctx.freshName("checksum"), classOf[java.util.zip.CRC32]) nullSafeCodeGen(ctx, ev, value => { - s""" + code""" $CRC32 $checksum = new $CRC32(); $checksum.update($value, 0, $value.length); ${ev.value} = $checksum.getValue(); @@ -284,98 +284,101 @@ abstract class HashExpression[E] extends Expression { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", - extraArguments = Seq(hashResultType -> ev.value), + extraArguments = Seq(ev.value), returnType = hashResultType, makeSplitFunction = body => s""" |$body |return ${ev.value}; """.stripMargin, - foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + foldFunctions = funcCalls => { + Blocks(funcCalls.map(funcCall => code"${ev.value} = $funcCall;")) + }) + val seedExpr = JavaCode.literal(s"$seed", dataType) ev.copy(code = code""" - |$hashResultType ${ev.value} = $seed; + |${inline"$hashResultType"} ${ev.value} = $seedExpr; |$codes """.stripMargin) } protected def nullSafeElementHash( - input: String, - index: String, + input: ExprValue, + index: ExprValue, nullable: Boolean, elementType: DataType, - result: String, - ctx: CodegenContext): String = { - val element = ctx.freshName("element") + result: ExprValue, + ctx: CodegenContext): Block = { + val element = JavaCode.variable(ctx.freshName("element"), elementType) - val jt = CodeGenerator.javaType(elementType) - ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") { - s""" + val jt = inline"${CodeGenerator.javaType(elementType)}" + ctx.nullSafeExec(nullable, JavaCode.expression(s"$input.isNullAt($index)", BooleanType)) { + code""" final $jt $element = ${CodeGenerator.getValue(input, elementType, index)}; ${computeHash(element, elementType, result, ctx)} """ } } - protected def genHashInt(i: String, result: String): String = - s"$result = $hasherClassName.hashInt($i, $result);" + protected def genHashInt(i: ExprValue, result: ExprValue): Block = + code"$result = $hasherClassName.hashInt($i, $result);" - protected def genHashLong(l: String, result: String): String = - s"$result = $hasherClassName.hashLong($l, $result);" + protected def genHashLong(l: ExprValue, result: ExprValue): Block = + code"$result = $hasherClassName.hashLong($l, $result);" - protected def genHashBytes(b: String, result: String): String = { - val offset = "Platform.BYTE_ARRAY_OFFSET" - s"$result = $hasherClassName.hashUnsafeBytes($b, $offset, $b.length, $result);" + protected def genHashBytes(b: ExprValue, result: ExprValue): Block = { + val offset = JavaCode.global("Platform.BYTE_ARRAY_OFFSET", LongType) + code"$result = $hasherClassName.hashUnsafeBytes($b, $offset, $b.length, $result);" } - protected def genHashBoolean(input: String, result: String): String = - genHashInt(s"$input ? 1 : 0", result) + protected def genHashBoolean(input: ExprValue, result: ExprValue): Block = + genHashInt(JavaCode.expression(s"$input ? 1 : 0", IntegerType), result) - protected def genHashFloat(input: String, result: String): String = - genHashInt(s"Float.floatToIntBits($input)", result) + protected def genHashFloat(input: ExprValue, result: ExprValue): Block = + genHashInt(JavaCode.expression(s"Float.floatToIntBits($input)", IntegerType), result) - protected def genHashDouble(input: String, result: String): String = - genHashLong(s"Double.doubleToLongBits($input)", result) + protected def genHashDouble(input: ExprValue, result: ExprValue): Block = + genHashLong(JavaCode.expression(s"Double.doubleToLongBits($input)", LongType), result) protected def genHashDecimal( ctx: CodegenContext, d: DecimalType, - input: String, - result: String): String = { + input: ExprValue, + result: ExprValue): Block = { if (d.precision <= Decimal.MAX_LONG_DIGITS) { - genHashLong(s"$input.toUnscaledLong()", result) + genHashLong(JavaCode.expression(s"$input.toUnscaledLong()", LongType), result) } else { - val bytes = ctx.freshName("bytes") - s""" + val bytes = JavaCode.variable(ctx.freshName("bytes"), BinaryType) + code""" |final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); |${genHashBytes(bytes, result)} """.stripMargin } } - protected def genHashTimestamp(t: String, result: String): String = genHashLong(t, result) + protected def genHashTimestamp(t: ExprValue, result: ExprValue): Block = genHashLong(t, result) - protected def genHashCalendarInterval(input: String, result: String): String = { - val microsecondsHash = s"$hasherClassName.hashLong($input.microseconds, $result)" - s"$result = $hasherClassName.hashInt($input.months, $microsecondsHash);" + protected def genHashCalendarInterval(input: ExprValue, result: ExprValue): Block = { + val microsecondsHash = code"$hasherClassName.hashLong($input.microseconds, $result)" + code"$result = $hasherClassName.hashInt($input.months, $microsecondsHash);" } - protected def genHashString(input: String, result: String): String = { - s"$result = $hasherClassName.hashUTF8String($input, $result);" + protected def genHashString(input: ExprValue, result: ExprValue): Block = { + code"$result = $hasherClassName.hashUTF8String($input, $result);" } protected def genHashForMap( ctx: CodegenContext, - input: String, - result: String, + input: ExprValue, + result: ExprValue, keyType: DataType, valueType: DataType, - valueContainsNull: Boolean): String = { - val index = ctx.freshName("index") - val keys = ctx.freshName("keys") - val values = ctx.freshName("values") - s""" + valueContainsNull: Boolean): Block = { + val index = JavaCode.variable(ctx.freshName("index"), IntegerType) + val keys = JavaCode.variable(ctx.freshName("keys"), classOf[ArrayData]) + val values = JavaCode.variable(ctx.freshName("values"), classOf[ArrayData]) + code""" final ArrayData $keys = $input.keyArray(); final ArrayData $values = $input.valueArray(); for (int $index = 0; $index < $input.numElements(); $index++) { @@ -387,12 +390,12 @@ abstract class HashExpression[E] extends Expression { protected def genHashForArray( ctx: CodegenContext, - input: String, - result: String, + input: ExprValue, + result: ExprValue, elementType: DataType, - containsNull: Boolean): String = { - val index = ctx.freshName("index") - s""" + containsNull: Boolean): Block = { + val index = JavaCode.variable(ctx.freshName("index"), IntegerType) + code""" for (int $index = 0; $index < $input.numElements(); $index++) { ${nullSafeElementHash(input, index, containsNull, elementType, result, ctx)} } @@ -401,33 +404,36 @@ abstract class HashExpression[E] extends Expression { protected def genHashForStruct( ctx: CodegenContext, - input: String, - result: String, - fields: Array[StructField]): String = { + input: ExprValue, + result: ExprValue, + fields: Array[StructField]): Block = { val fieldsHash = fields.zipWithIndex.map { case (field, index) => - nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) + val indexExpr = JavaCode.literal(index.toString, IntegerType) + nullSafeElementHash(input, indexExpr, field.nullable, field.dataType, result, ctx) } val hashResultType = CodeGenerator.javaType(dataType) ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, hashResultType -> result), + arguments = Seq(input, result), returnType = hashResultType, makeSplitFunction = body => s""" |$body |return $result; """.stripMargin, - foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) + foldFunctions = funcCalls => { + Blocks(funcCalls.map(funcCall => code"$result = $funcCall;")) + }) } @tailrec private def computeHashWithTailRec( - input: String, + input: ExprValue, dataType: DataType, - result: String, - ctx: CodegenContext): String = dataType match { - case NullType => "" + result: ExprValue, + ctx: CodegenContext): Block = dataType match { + case NullType => EmptyBlock case BooleanType => genHashBoolean(input, result) case ByteType | ShortType | IntegerType | DateType => genHashInt(input, result) case LongType => genHashLong(input, result) @@ -446,12 +452,12 @@ abstract class HashExpression[E] extends Expression { } protected def computeHash( - input: String, + input: ExprValue, dataType: DataType, - result: String, - ctx: CodegenContext): String = computeHashWithTailRec(input, dataType, result, ctx) + result: ExprValue, + ctx: CodegenContext): Block = computeHashWithTailRec(input, dataType, result, ctx) - protected def hasherClassName: String + protected def hasherClassName: JavaCode } /** @@ -562,7 +568,8 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpress override def prettyName: String = "hash" - override protected def hasherClassName: String = classOf[Murmur3_x86_32].getName + override protected def hasherClassName: JavaCode = + inline"${classOf[Murmur3_x86_32].getName}" override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = { Murmur3HashFunction.hash(value, dataType, seed).toInt @@ -599,7 +606,7 @@ case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpressio override def prettyName: String = "xxHash" - override protected def hasherClassName: String = classOf[XXH64].getName + override protected def hasherClassName: JavaCode = inline"${classOf[XXH64].getName}" override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long = { XxHash64Function.hash(value, dataType, seed) @@ -637,7 +644,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override def prettyName: String = "hive-hash" - override protected def hasherClassName: String = classOf[HiveHasher].getName + override protected def hasherClassName: JavaCode = + inline"${classOf[HiveHasher].getName}" override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = { HiveHashFunction.hash(value, dataType, this.seed).toInt @@ -646,13 +654,13 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = FalseLiteral - val childHash = ctx.freshName("childHash") + val childHash = JavaCode.variable(ctx.freshName("childHash"), IntegerType) val childrenHash = children.map { child => val childGen = child.genCode(ctx) val codeToComputeHash = ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, childHash, ctx) } - s""" + code""" |${childGen.code} |$childHash = 0; |$codeToComputeHash @@ -663,7 +671,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", - extraArguments = Seq(CodeGenerator.JAVA_INT -> ev.value), + extraArguments = Seq(ev.value), returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" @@ -671,13 +679,15 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { |$body |return ${ev.value}; """.stripMargin, - foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + foldFunctions = funcCalls => { + Blocks(funcCalls.map(funcCall => code"${ev.value} = $funcCall;")) + }) ev.copy(code = code""" - |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; - |${CodeGenerator.JAVA_INT} $childHash = 0; + |${inline"${CodeGenerator.JAVA_INT}"} ${ev.value} = $seed; + |${inline"${CodeGenerator.JAVA_INT}"} $childHash = 0; |$codes """.stripMargin) } @@ -693,50 +703,54 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { hash } - override protected def genHashInt(i: String, result: String): String = - s"$result = $hasherClassName.hashInt($i);" + override protected def genHashInt(i: ExprValue, result: ExprValue): Block = + code"$result = $hasherClassName.hashInt($i);" - override protected def genHashLong(l: String, result: String): String = - s"$result = $hasherClassName.hashLong($l);" + override protected def genHashLong(l: ExprValue, result: ExprValue): Block = + code"$result = $hasherClassName.hashLong($l);" - override protected def genHashBytes(b: String, result: String): String = - s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);" + override protected def genHashBytes(b: ExprValue, result: ExprValue): Block = + code"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);" override protected def genHashDecimal( ctx: CodegenContext, d: DecimalType, - input: String, - result: String): String = { - s""" - $result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal( + input: ExprValue, + result: ExprValue): Block = { + val hiveHashFunction = inline"${HiveHashFunction.getClass.getName.stripSuffix("$")}" + code""" + $result = $hiveHashFunction.normalizeDecimal( $input.toJavaBigDecimal()).hashCode();""" } - override protected def genHashCalendarInterval(input: String, result: String): String = { - s""" + override protected def genHashCalendarInterval(input: ExprValue, result: ExprValue): Block = { + val hiveHashFunction = inline"${HiveHashFunction.getClass.getName.stripSuffix("$")}" + code""" $result = (int) - ${HiveHashFunction.getClass.getName.stripSuffix("$")}.hashCalendarInterval($input); + $hiveHashFunction.hashCalendarInterval($input); """ } - override protected def genHashTimestamp(input: String, result: String): String = - s""" - $result = (int) ${HiveHashFunction.getClass.getName.stripSuffix("$")}.hashTimestamp($input); + override protected def genHashTimestamp(input: ExprValue, result: ExprValue): Block = { + val hiveHashFunction = inline"${HiveHashFunction.getClass.getName.stripSuffix("$")}" + code""" + $result = (int) $hiveHashFunction.hashTimestamp($input); """ + } - override protected def genHashString(input: String, result: String): String = { - s"$result = $hasherClassName.hashUTF8String($input);" + override protected def genHashString(input: ExprValue, result: ExprValue): Block = { + code"$result = $hasherClassName.hashUTF8String($input);" } override protected def genHashForArray( ctx: CodegenContext, - input: String, - result: String, + input: ExprValue, + result: ExprValue, elementType: DataType, - containsNull: Boolean): String = { - val index = ctx.freshName("index") - val childResult = ctx.freshName("childResult") - s""" + containsNull: Boolean): Block = { + val index = JavaCode.variable(ctx.freshName("index"), IntegerType) + val childResult = JavaCode.variable(ctx.freshName("childResult"), IntegerType) + code""" int $childResult = 0; for (int $index = 0; $index < $input.numElements(); $index++) { $childResult = 0; @@ -748,17 +762,17 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def genHashForMap( ctx: CodegenContext, - input: String, - result: String, + input: ExprValue, + result: ExprValue, keyType: DataType, valueType: DataType, - valueContainsNull: Boolean): String = { - val index = ctx.freshName("index") - val keys = ctx.freshName("keys") - val values = ctx.freshName("values") - val keyResult = ctx.freshName("keyResult") - val valueResult = ctx.freshName("valueResult") - s""" + valueContainsNull: Boolean): Block = { + val index = JavaCode.variable(ctx.freshName("index"), IntegerType) + val keys = JavaCode.variable(ctx.freshName("keys"), classOf[ArrayData]) + val values = JavaCode.variable(ctx.freshName("values"), classOf[ArrayData]) + val keyResult = JavaCode.variable(ctx.freshName("keyResult"), IntegerType) + val valueResult = JavaCode.variable(ctx.freshName("valueResult"), IntegerType) + code""" final ArrayData $keys = $input.keyArray(); final ArrayData $values = $input.valueArray(); int $keyResult = 0; @@ -775,32 +789,35 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def genHashForStruct( ctx: CodegenContext, - input: String, - result: String, - fields: Array[StructField]): String = { - val childResult = ctx.freshName("childResult") + input: ExprValue, + result: ExprValue, + fields: Array[StructField]): Block = { + val childResult = JavaCode.variable(ctx.freshName("childResult"), IntegerType) val fieldsHash = fields.zipWithIndex.map { case (field, index) => + val indexExpr = JavaCode.literal(index.toString, IntegerType) val computeFieldHash = nullSafeElementHash( - input, index.toString, field.nullable, field.dataType, childResult, ctx) - s""" + input, indexExpr, field.nullable, field.dataType, childResult, ctx) + code""" |$childResult = 0; |$computeFieldHash |$result = (31 * $result) + $childResult; """.stripMargin } - s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( + code"${inline"${CodeGenerator.JAVA_INT}"} $childResult = 0;\n" + ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result), + arguments = Seq(input, result), returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" - |${CodeGenerator.JAVA_INT} $childResult = 0; + |${inline"${CodeGenerator.JAVA_INT}"} $childResult = 0; |$body |return $result; """.stripMargin, - foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) + foldFunctions = funcCalls => { + Blocks(funcCalls.map(funcCall => code"$result = $funcCall;")) + }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 3b0141ad52cc..e6620991ae34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,9 +42,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - val typeDef = s"final ${CodeGenerator.javaType(dataType)}" - ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();", + val className = inline"${InputFileBlockHolder.getClass.getName.stripSuffix("$")}" + val typeDef = inline"${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"final $typeDef ${ev.value} = $className.getInputFilePath();", isNull = FalseLiteral) } } @@ -66,9 +66,10 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - val typeDef = s"final ${CodeGenerator.javaType(dataType)}" - ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral) + val className = inline"${InputFileBlockHolder.getClass.getName.stripSuffix("$")}" + val typeDef = inline"${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"final $typeDef ${ev.value} = $className.getStartOffset();", + isNull = FalseLiteral) } } @@ -89,8 +90,9 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - val typeDef = s"final ${CodeGenerator.javaType(dataType)}" - ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral) + val className = inline"${InputFileBlockHolder.getClass.getName.stripSuffix("$")}" + val typeDef = inline"${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"final $typeDef ${ev.value} = $className.getLength();", + isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index c2e1720259b5..c86a49571b6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -70,10 +70,10 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) } // name of function in java.lang.Math - def funcName: String = name.toLowerCase(Locale.ROOT) + def funcName: JavaCode = inline"${name.toLowerCase(Locale.ROOT)}" override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") + defineCodeGen(ctx, ev, c => code"java.lang.Math.${funcName}($c)") } } @@ -92,7 +92,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => - s""" + code""" if ($c <= $yAsymptote) { ${ev.isNull} = true; } else { @@ -126,8 +126,9 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val nameInCode = inline"${name.toLowerCase(Locale.ROOT)}" defineCodeGen(ctx, ev, (c1, c2) => - s"java.lang.Math.${name.toLowerCase(Locale.ROOT)}($c1, $c2)") + code"java.lang.Math.$nameInCode($c1, $c2)") } } @@ -246,11 +247,11 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { - case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") + case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => code"$c") case DecimalType.Fixed(_, _) => - defineCodeGen(ctx, ev, c => s"$c.ceil()") - case LongType => defineCodeGen(ctx, ev, c => s"$c") - case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") + defineCodeGen(ctx, ev, c => code"$c.ceil()") + case LongType => defineCodeGen(ctx, ev, c => code"$c") + case _ => defineCodeGen(ctx, ev, c => code"(long)(java.lang.Math.${funcName}($c))") } } } @@ -319,9 +320,9 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val numconv = NumberConverter.getClass.getName.stripSuffix("$") + val numconv = inline"${NumberConverter.getClass.getName.stripSuffix("$")}" nullSafeCodeGen(ctx, ev, (num, from, to) => - s""" + code""" ${ev.value} = $numconv.convert($num.getBytes(), $from, $to); if (${ev.value} == null) { ${ev.isNull} = true; @@ -377,11 +378,11 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { - case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") + case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => code"$c") case DecimalType.Fixed(_, _) => - defineCodeGen(ctx, ev, c => s"$c.floor()") - case LongType => defineCodeGen(ctx, ev, c => s"$c") - case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") + defineCodeGen(ctx, ev, c => code"$c.floor()") + case LongType => defineCodeGen(ctx, ev, c => code"$c") + case _ => defineCodeGen(ctx, ev, c => code"(long)(java.lang.Math.${funcName}($c))") } } } @@ -444,7 +445,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { - s""" + code""" if ($eval > 20 || $eval < 0) { ${ev.isNull} = true; } else { @@ -476,7 +477,7 @@ case class Log2(child: Expression) extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => - s""" + code""" if ($c <= $yAsymptote) { ${ev.isNull} = true; } else { @@ -517,7 +518,7 @@ case class Log1p(child: Expression) extends UnaryLogExpression(math.log1p, "LOG1 """) // scalastyle:on line.size.limit case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { - override def funcName: String = "rint" + override def funcName: JavaCode = inline"rint" } @ExpressionDescription( @@ -597,7 +598,7 @@ case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") case class Cot(child: Expression) extends UnaryMathExpression((x: Double) => 1 / math.tan(x), "COT") { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"${ev.value} = 1 / java.lang.Math.tan($c);") + defineCodeGen(ctx, ev, c => code"${ev.value} = 1 / java.lang.Math.tan($c);") } } @@ -629,7 +630,7 @@ case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH" 180.0 """) case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") { - override def funcName: String = "toDegrees" + override def funcName: JavaCode = inline"toDegrees" } @ExpressionDescription( @@ -644,7 +645,7 @@ case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegre 3.141592653589793 """) case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") { - override def funcName: String = "toRadians" + override def funcName: JavaCode = inline"toRadians" } // scalastyle:off line.size.limit @@ -671,7 +672,7 @@ case class Bin(child: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c) => - s"UTF8String.fromString(java.lang.Long.toBinaryString($c))") + code"UTF8String.fromString(java.lang.Long.toBinaryString($c))") } } @@ -775,10 +776,10 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (c) => { - val hex = Hex.getClass.getName.stripSuffix("$") - s"${ev.value} = " + (child.dataType match { - case StringType => s"""$hex.hex($c.getBytes());""" - case _ => s"""$hex.hex($c);""" + val hex = inline"${Hex.getClass.getName.stripSuffix("$")}" + code"${ev.value} = " + (child.dataType match { + case StringType => code"""$hex.hex($c.getBytes());""" + case _ => code"""$hex.hex($c);""" }) }) } @@ -807,8 +808,8 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (c) => { - val hex = Hex.getClass.getName.stripSuffix("$") - s""" + val hex = inline"${Hex.getClass.getName.stripSuffix("$")}" + code""" ${ev.value} = $hex.unhex($c.getBytes()); ${ev.isNull} = ${ev.value} == null; """ @@ -848,7 +849,7 @@ case class Atan2(left: Expression, right: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + defineCodeGen(ctx, ev, (c1, c2) => code"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") } } @@ -862,7 +863,7 @@ case class Atan2(left: Expression, right: Expression) case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + defineCodeGen(ctx, ev, (c1, c2) => code"java.lang.Math.pow($c1, $c2)") } } @@ -896,7 +897,7 @@ case class ShiftLeft(left: Expression, right: Expression) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (left, right) => s"$left << $right") + defineCodeGen(ctx, ev, (left, right) => code"$left << $right") } } @@ -930,7 +931,7 @@ case class ShiftRight(left: Expression, right: Expression) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (left, right) => s"$left >> $right") + defineCodeGen(ctx, ev, (left, right) => code"$left >> $right") } } @@ -964,7 +965,7 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right") + defineCodeGen(ctx, ev, (left, right) => code"$left >>> $right") } } @@ -1014,7 +1015,7 @@ case class Logarithm(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { if (left.isInstanceOf[EulerNumber]) { nullSafeCodeGen(ctx, ev, (c1, c2) => - s""" + code""" if ($c2 <= 0.0) { ${ev.isNull} = true; } else { @@ -1023,7 +1024,7 @@ case class Logarithm(left: Expression, right: Expression) """) } else { nullSafeCodeGen(ctx, ev, (c1, c2) => - s""" + code""" if ($c1 <= 0.0 || $c2 <= 0.0) { ${ev.isNull} = true; } else { @@ -1134,63 +1135,64 @@ abstract class RoundBase(child: Expression, scale: Expression, override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val ce = child.genCode(ctx) + val mode = inline"$modeStr" val evaluationCode = dataType match { case DecimalType.Fixed(_, s) => - s""" - ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr()); + code""" + ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$mode()); ${ev.isNull} = ${ev.value} == null;""" case ByteType => if (_scale < 0) { - s""" + code""" ${ev.value} = new java.math.BigDecimal(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.${modeStr}).byteValue();""" + setScale(${_scale}, java.math.BigDecimal.${mode}).byteValue();""" } else { - s"${ev.value} = ${ce.value};" + code"${ev.value} = ${ce.value};" } case ShortType => if (_scale < 0) { - s""" + code""" ${ev.value} = new java.math.BigDecimal(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.${modeStr}).shortValue();""" + setScale(${_scale}, java.math.BigDecimal.${mode}).shortValue();""" } else { - s"${ev.value} = ${ce.value};" + code"${ev.value} = ${ce.value};" } case IntegerType => if (_scale < 0) { - s""" + code""" ${ev.value} = new java.math.BigDecimal(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue();""" + setScale(${_scale}, java.math.BigDecimal.${mode}).intValue();""" } else { - s"${ev.value} = ${ce.value};" + code"${ev.value} = ${ce.value};" } case LongType => if (_scale < 0) { - s""" + code""" ${ev.value} = new java.math.BigDecimal(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.${modeStr}).longValue();""" + setScale(${_scale}, java.math.BigDecimal.${mode}).longValue();""" } else { - s"${ev.value} = ${ce.value};" + code"${ev.value} = ${ce.value};" } case FloatType => // if child eval to NaN or Infinity, just return it. - s""" + code""" if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})) { ${ev.value} = ${ce.value}; } else { ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.${modeStr}).floatValue(); + setScale(${_scale}, java.math.BigDecimal.${mode}).floatValue(); }""" case DoubleType => // if child eval to NaN or Infinity, just return it. - s""" + code""" if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})) { ${ev.value} = ${ce.value}; } else { ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.${modeStr}).doubleValue(); + setScale(${_scale}, java.math.BigDecimal.${mode}).doubleValue(); }""" } - val javaType = CodeGenerator.javaType(dataType) + val javaType = inline"${CodeGenerator.javaType(dataType)}" if (scaleV == null) { // if scale is null, no need to eval its child at all ev.copy(code = code""" boolean ${ev.isNull} = true; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 5d98dac46cf1..6668578c1f44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -43,9 +43,10 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { private val outputPrefix = s"Result of ${child.simpleString} is " override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val outputPrefixField = ctx.addReferenceObj("outputPrefix", outputPrefix) + val outputPrefixField = JavaCode.global(ctx.addReferenceObj("outputPrefix", outputPrefix), + classOf[String]) nullSafeCodeGen(ctx, ev, c => - s""" + code""" | System.err.println($outputPrefixField + $c); | ${ev.value} = $c; """.stripMargin) @@ -88,7 +89,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null or false. - val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + val errMsgField = JavaCode.global(ctx.addReferenceObj("errMsg", errMsg), classOf[String]) ExprCode(code = code"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); @@ -145,7 +146,8 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta randomGenerator.getNextUUIDUTF8String() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val randomGen = ctx.freshName("randomGen") + val randomGen = JavaCode.variable(ctx.freshName("randomGen"), + classOf[org.apache.spark.sql.catalyst.util.RandomUUIDGenerator]) ctx.addMutableState("org.apache.spark.sql.catalyst.util.RandomUUIDGenerator", randomGen, forceInline = true, useFreshName = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 2eeed3bbb2d9..a23ea403b1ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -78,7 +78,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => val eval = e.genCode(ctx) - s""" + code""" |${eval.code} |if (!${eval.isNull}) { | ${ev.isNull} = false; @@ -88,11 +88,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """.stripMargin } - val resultType = CodeGenerator.javaType(dataType) + val resultType = inline"${CodeGenerator.javaType(dataType)}" val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "coalesce", - returnType = resultType, + returnType = resultType.code, makeSplitFunction = func => s""" |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -101,14 +101,16 @@ case class Coalesce(children: Seq[Expression]) extends Expression { |} while (false); |return ${ev.value}; """.stripMargin, - foldFunctions = _.map { funcCall => - s""" - |${ev.value} = $funcCall; - |if (!${ev.isNull}) { - | continue; - |} - """.stripMargin - }.mkString) + foldFunctions = funcCalls => { + Blocks(funcCalls.map { funcCall => + code""" + |${ev.value} = $funcCall; + |if (!${ev.isNull}) { + | continue; + |} + """.stripMargin + }) + }) ev.copy(code = @@ -231,11 +233,12 @@ case class IsNaN(child: Expression) extends UnaryExpression override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) + val javaType = inline"${CodeGenerator.javaType(dataType)}" child.dataType match { case DoubleType | FloatType => ev.copy(code = code""" ${eval.code} - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) } } @@ -277,12 +280,13 @@ case class NaNvl(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val leftGen = left.genCode(ctx) val rightGen = right.genCode(ctx) + val javaType = inline"${CodeGenerator.javaType(dataType)}" left.dataType match { case DoubleType | FloatType => ev.copy(code = code""" ${leftGen.code} boolean ${ev.isNull} = false; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${leftGen.isNull}) { ${ev.isNull} = true; } else { @@ -389,13 +393,13 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val nonnull = ctx.freshName("nonnull") + val nonnull = JavaCode.variable(ctx.freshName("nonnull"), IntegerType) // all evals are meant to be inside a do { ... } while (false); loop val evals = children.map { e => val eval = e.genCode(ctx) e.dataType match { case DoubleType | FloatType => - s""" + code""" |if ($nonnull < $n) { | ${eval.code} | if (!${eval.isNull} && !Double.isNaN(${eval.value})) { @@ -406,7 +410,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate |} """.stripMargin case _ => - s""" + code""" |if ($nonnull < $n) { | ${eval.code} | if (!${eval.isNull}) { @@ -422,7 +426,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "atLeastNNonNulls", - extraArguments = (CodeGenerator.JAVA_INT, nonnull) :: Nil, + extraArguments = (nonnull) :: Nil, returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" @@ -431,22 +435,26 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate |} while (false); |return $nonnull; """.stripMargin, - foldFunctions = _.map { funcCall => - s""" - |$nonnull = $funcCall; - |if ($nonnull >= $n) { - | continue; - |} - """.stripMargin - }.mkString) + foldFunctions = funcCalls => { + Blocks(funcCalls.map { funcCall => + code""" + |$nonnull = $funcCall; + |if ($nonnull >= $n) { + | continue; + |} + """.stripMargin + }) + }) + val intType = inline"${CodeGenerator.JAVA_INT}" + val booleanType = inline"${CodeGenerator.JAVA_BOOLEAN}" ev.copy(code = code""" - |${CodeGenerator.JAVA_INT} $nonnull = 0; + |$intType $nonnull = 0; |do { | $codes |} while (false); - |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; + |$booleanType ${ev.value} = $nonnull >= $n; """.stripMargin, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 2bf4203d0fec..9d8caa3c4218 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -63,7 +63,7 @@ trait InvokeLike extends Expression with NonSQLExpression { * @param ctx a [[CodegenContext]] * @return (code to prepare arguments, argument string, result of argument null check) */ - def prepareArguments(ctx: CodegenContext): (String, String, ExprValue) = { + def prepareArguments(ctx: CodegenContext): (Block, Block, ExprValue) = { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") @@ -72,20 +72,22 @@ trait InvokeLike extends Expression with NonSQLExpression { FalseLiteral } val argValues = arguments.map { e => - val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue") + val argValue = JavaCode.global( + ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue"), + e.dataType) argValue } val argCodes = if (needNullCheck) { - val reset = s"$resultIsNull = false;" + val reset = code"$resultIsNull = false;" val argCodes = arguments.zipWithIndex.map { case (e, i) => val expr = e.genCode(ctx) val updateResultIsNull = if (e.nullable) { - s"$resultIsNull = ${expr.isNull};" + code"$resultIsNull = ${expr.isNull};" } else { - "" + EmptyBlock } - s""" + code""" if (!$resultIsNull) { ${expr.code} $updateResultIsNull @@ -97,15 +99,22 @@ trait InvokeLike extends Expression with NonSQLExpression { } else { arguments.zipWithIndex.map { case (e, i) => val expr = e.genCode(ctx) - s""" + code""" ${expr.code} ${argValues(i)} = ${expr.value}; """ } } val argCode = ctx.splitExpressionsWithCurrentInputs(argCodes) + val argValueBlock = argValues.foldLeft[Block](EmptyBlock) { (block, argValue) => + if (block.length == 0) { + code"$argValue" + } else { + code"$block, $argValue" + } + } - (argCode, argValues.mkString(", "), resultIsNull) + (argCode, argValueBlock, resultIsNull) } /** @@ -159,19 +168,19 @@ trait SerializerSupport { * Adds a immutable state to the generated class containing a reference to the serializer. * @return a string containing the name of the variable referencing the serializer */ - def addImmutableSerializerIfNeeded(ctx: CodegenContext): String = { + def addImmutableSerializerIfNeeded(ctx: CodegenContext): ExprValue = { val (serializerInstance, serializerInstanceClass) = { if (kryo) { - ("kryoSerializer", + (JavaCode.variable("kryoSerializer", classOf[KryoSerializerInstance]), classOf[KryoSerializerInstance].getName) } else { - ("javaSerializer", + (JavaCode.variable("javaSerializer", classOf[JavaSerializerInstance]), classOf[JavaSerializerInstance].getName) } } val newSerializerMethod = s"${classOf[SerializerSupport].getName}$$.MODULE$$.newSerializer" // Code to initialize the serializer - ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializerInstance, v => + ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializerInstance.code, v => s""" |$v = ($serializerInstanceClass) $newSerializerMethod($kryo); """.stripMargin) @@ -237,29 +246,32 @@ case class StaticInvoke( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = CodeGenerator.javaType(dataType) + val javaType = inline"${CodeGenerator.javaType(dataType)}" val (argCode, argString, resultIsNull) = prepareArguments(ctx) - val callFunc = s"$objectName.$functionName($argString)" + val callFunc = + code"${inline"$objectName"}.${inline"$functionName"}($argString)" val prepareIsNull = if (nullable) { - s"boolean ${ev.isNull} = $resultIsNull;" + code"boolean ${ev.isNull} = $resultIsNull;" } else { ev.isNull = FalseLiteral - "" + EmptyBlock } val evaluate = if (returnNullable) { - if (CodeGenerator.defaultValue(dataType) == "null") { - s""" + if (CodeGenerator.defaultValue(dataType).code == "null") { + code""" ${ev.value} = $callFunc; ${ev.isNull} = ${ev.value} == null; """ } else { - val boxedResult = ctx.freshName("boxedResult") - s""" - ${CodeGenerator.boxedType(dataType)} $boxedResult = $callFunc; + val boxedResult = JavaCode.variable(ctx.freshName("boxedResult"), + ScalaReflection.javaBoxedType(dataType)) + val boxedJavaType = CodeGenerator.boxedType(dataType) + code""" + ${inline"$boxedJavaType"} $boxedResult = $callFunc; ${ev.isNull} = $boxedResult == null; if (!${ev.isNull}) { ${ev.value} = $boxedResult; @@ -267,7 +279,7 @@ case class StaticInvoke( """ } } else { - s"${ev.value} = $callFunc;" + code"${ev.value} = $callFunc;" } val code = code""" @@ -342,7 +354,7 @@ case class Invoke( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = CodeGenerator.javaType(dataType) + val javaType = inline"${CodeGenerator.javaType(dataType)}" val obj = targetObject.genCode(ctx) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -350,8 +362,8 @@ case class Invoke( val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty - def getFuncResult(resultVal: String, funcCall: String): String = if (needTryCatch) { - s""" + def getFuncResult(resultVal: ExprValue, funcCall: Block): Block = if (needTryCatch) { + code""" try { $resultVal = $funcCall; } catch (Exception e) { @@ -359,29 +371,31 @@ case class Invoke( } """ } else { - s"$resultVal = $funcCall;" + code"$resultVal = $funcCall;" } + val functionName = inline"$encodedFunctionName" val evaluate = if (returnPrimitive) { - getFuncResult(ev.value, s"${obj.value}.$encodedFunctionName($argString)") + getFuncResult(ev.value, code"${obj.value}.$functionName($argString)") } else { - val funcResult = ctx.freshName("funcResult") + val funcResult = JavaCode.variable(ctx.freshName("funcResult"), classOf[Object]) + val boxedType = inline"${CodeGenerator.boxedType(javaType.code)}" // If the function can return null, we do an extra check to make sure our null bit is still // set correctly. val assignResult = if (!returnNullable) { - s"${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;" + code"${ev.value} = ($boxedType) $funcResult;" } else { - s""" + code""" if ($funcResult != null) { - ${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult; + ${ev.value} = ($boxedType) $funcResult; } else { ${ev.isNull} = true; } """ } - s""" + code""" Object $funcResult = null; - ${getFuncResult(funcResult, s"${obj.value}.$encodedFunctionName($argString)")} + ${getFuncResult(funcResult, code"${obj.value}.$functionName($argString)")} $assignResult """ } @@ -478,7 +492,7 @@ case class NewInstance( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = CodeGenerator.javaType(dataType) + val javaType = inline"${CodeGenerator.javaType(dataType)}" val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -487,14 +501,14 @@ case class NewInstance( ev.isNull = resultIsNull val constructorCall = outer.map { gen => - s"${gen.value}.new ${cls.getSimpleName}($argString)" + code"${gen.value}.new ${inline"${cls.getSimpleName}"}($argString)" }.getOrElse { - s"new $className($argString)" + code"new ${inline"$className"}($argString)" } val code = code""" $argCode - ${outer.map(_.code).getOrElse("")} + ${outer.map(_.code).getOrElse(EmptyBlock)} final $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $constructorCall; """ @@ -529,13 +543,13 @@ case class UnwrapOption( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = CodeGenerator.javaType(dataType) + val javaType = inline"${CodeGenerator.javaType(dataType)}" val inputObject = child.genCode(ctx) val code = inputObject.code + code""" final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : - (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); + (${inline"${CodeGenerator.boxedType(javaType.code)}"}) ${inputObject.value}.get(); """ ev.copy(code = code) } @@ -790,35 +804,38 @@ case class MapObjects private( ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val elementJavaType = CodeGenerator.javaType(loopVarDataType) - ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false) + val elementJavaType = inline"${CodeGenerator.javaType(loopVarDataType)}" + ctx.addMutableState(elementJavaType.code, loopValue, forceInline = true, useFreshName = false) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) - val dataLength = ctx.freshName("dataLength") - val convertedArray = ctx.freshName("convertedArray") - val loopIndex = ctx.freshName("loopIndex") - - val convertedType = CodeGenerator.boxedType(lambdaFunction.dataType) + val convertedType = inline"${CodeGenerator.boxedType(lambdaFunction.dataType)}" + val dataLength = JavaCode.variable(ctx.freshName("dataLength"), IntegerType) + val convertedArray = JavaCode.variable(ctx.freshName("convertedArray"), + ArrayType(lambdaFunction.dataType)) + val loopIndex = JavaCode.variable(ctx.freshName("loopIndex"), IntegerType) + val loopValueInCodegen = JavaCode.variable(loopValue, loopVarDataType) + val loopIsNullInCodegen = JavaCode.isNullVariable(loopIsNull) // Because of the way Java defines nested arrays, we have to handle the syntax specially. // Specifically, we have to insert the [$dataLength] in between the type and any extra nested // array declarations (i.e. new String[1][]). - val arrayConstructor = if (convertedType contains "[]") { - val rawType = convertedType.takeWhile(_ != '[') - val arrayPart = convertedType.reverse.takeWhile(c => c == '[' || c == ']').reverse - s"new $rawType[$dataLength]$arrayPart" + val arrayConstructor = if (convertedType.code contains "[]") { + val rawType = inline"${convertedType.code.takeWhile(_ != '[')}" + val arrayPart = + inline"${convertedType.code.reverse.takeWhile(c => c == '[' || c == ']').reverse}" + code"new $rawType[$dataLength]$arrayPart" } else { - s"new $convertedType[$dataLength]" + code"new $convertedType[$dataLength]" } // In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type // of input collection at runtime for this case. - val seq = ctx.freshName("seq") - val array = ctx.freshName("array") + val seq = JavaCode.variable(ctx.freshName("seq"), classOf[Seq[_]]) + val array = JavaCode.variable(ctx.freshName("array"), ArrayType(loopVarDataType)) val determineCollectionType = inputData.dataType match { case ObjectType(cls) if cls == classOf[Object] => - val seqClass = classOf[Seq[_]].getName - s""" + val seqClass = inline"${classOf[Seq[_]].getName}" + code""" $seqClass $seq = null; $elementJavaType[] $array = null; if (${genInputData.value}.getClass().isArray()) { @@ -827,7 +844,7 @@ case class MapObjects private( $seq = ($seqClass) ${genInputData.value}; } """ - case _ => "" + case _ => EmptyBlock } // `MapObjects` generates a while loop to traverse the elements of the input collection. We @@ -835,104 +852,112 @@ case class MapObjects private( // like `list.get(1)`. Here we use Iterator to traverse Seq and List. val (getLength, prepareLoop, getLoopVar) = inputDataType match { case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => - val it = ctx.freshName("it") + val it = JavaCode.variable(ctx.freshName("it"), classOf[scala.collection.Iterator[_]]) ( - s"${genInputData.value}.size()", - s"scala.collection.Iterator $it = ${genInputData.value}.toIterator();", - s"$it.next()" + code"${genInputData.value}.size()", + code"scala.collection.Iterator $it = ${genInputData.value}.toIterator();", + code"$it.next()" ) case ObjectType(cls) if cls.isArray => ( - s"${genInputData.value}.length", - "", - s"${genInputData.value}[$loopIndex]" + code"${genInputData.value}.length", + EmptyBlock, + code"${genInputData.value}[$loopIndex]" ) case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => - val it = ctx.freshName("it") + val it = JavaCode.variable(ctx.freshName("it"), classOf[java.util.Iterator[_]]) ( - s"${genInputData.value}.size()", - s"java.util.Iterator $it = ${genInputData.value}.iterator();", - s"$it.next()" + code"${genInputData.value}.size()", + code"java.util.Iterator $it = ${genInputData.value}.iterator();", + code"$it.next()" ) case ArrayType(et, _) => ( - s"${genInputData.value}.numElements()", - "", + code"${genInputData.value}.numElements()", + EmptyBlock, CodeGenerator.getValue(genInputData.value, et, loopIndex) ) case ObjectType(cls) if cls == classOf[Object] => - val it = ctx.freshName("it") + val it = JavaCode.variable(ctx.freshName("it"), classOf[scala.collection.Iterator[_]]) ( - s"$seq == null ? $array.length : $seq.size()", - s"scala.collection.Iterator $it = $seq == null ? null : $seq.toIterator();", - s"$it == null ? $array[$loopIndex] : $it.next()" + code"$seq == null ? $array.length : $seq.size()", + code"scala.collection.Iterator $it = $seq == null ? null : $seq.toIterator();", + code"$it == null ? $array[$loopIndex] : $it.next()" ) } // Make a copy of the data if it's unsafe-backed - def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = - s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value" - val genFunctionValue: String = lambdaFunction.dataType match { + def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: ExprValue) = + code"$value instanceof ${inline"${clazz.getSimpleName}"}? ${value}.copy() : $value" + val genFunctionValue: Block = lambdaFunction.dataType match { case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) - case _ => genFunction.value + case _ => code"${genFunction.value}" } val loopNullCheck = if (loopIsNull != "false") { ctx.addMutableState( CodeGenerator.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) inputDataType match { - case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" - case _ => s"$loopIsNull = $loopValue == null;" + case _: ArrayType => + code"$loopIsNullInCodegen = ${genInputData.value}.isNullAt($loopIndex);" + case _ => code"$loopIsNullInCodegen = $loopValueInCodegen == null;" } } else { - "" + EmptyBlock } - val (initCollection, addElement, getResult): (String, String => String, String) = + val (initCollection, addElement, getResult): (Block, Block => Block, Block) = customCollectionCls match { case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) || classOf[scala.collection.Set[_]].isAssignableFrom(cls) => // Scala sequence or set - val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" - val builder = ctx.freshName("collectionBuilder") + val getBuilder = code"${inline"${cls.getName}"}$$.MODULE$$.newBuilder()" + val builder = + JavaCode.variable(ctx.freshName("collectionBuilder"), classOf[Builder[_, _]]) + val builderType = inline"${classOf[Builder[_, _]].getName}" ( - s""" - ${classOf[Builder[_, _]].getName} $builder = $getBuilder; + code""" + $builderType $builder = $getBuilder; $builder.sizeHint($dataLength); """, - genValue => s"$builder.$$plus$$eq($genValue);", - s"(${cls.getName}) $builder.result();" + genValue => code"$builder.$$plus$$eq($genValue);", + code"(${inline"${cls.getName}"}) $builder.result();" ) case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => // Java list - val builder = ctx.freshName("collectionBuilder") + val builder = JavaCode.variable(ctx.freshName("collectionBuilder"), + classOf[java.util.List[_]]) ( if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] || cls == classOf[java.util.AbstractSequentialList[_]]) { - s"${cls.getName} $builder = new java.util.ArrayList($dataLength);" + val className = inline"${cls.getName}" + code"$className $builder = new java.util.ArrayList($dataLength);" } else { - val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("") - s"${cls.getName} $builder = new ${cls.getName}($param);" + val className = inline"${cls.getName}" + Try(cls.getConstructor(Integer.TYPE)).map { _ => + code"$className $builder = new $className($dataLength);" + }.getOrElse(code"$className $builder = new $className();") }, - genValue => s"$builder.add($genValue);", - s"$builder;" + genValue => code"$builder.add($genValue);", + code"$builder;" ) case None => // array ( - s""" + code""" $convertedType[] $convertedArray = null; $convertedArray = $arrayConstructor; """, - genValue => s"$convertedArray[$loopIndex] = $genValue;", - s"new ${classOf[GenericArrayData].getName}($convertedArray);" + genValue => code"$convertedArray[$loopIndex] = $genValue;", + code"new ${inline"${classOf[GenericArrayData].getName}"}($convertedArray);" ) } + val javaType = inline"${CodeGenerator.javaType(dataType)}" val code = genInputData.code + code""" - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { $determineCollectionType @@ -942,12 +967,12 @@ case class MapObjects private( int $loopIndex = 0; $prepareLoop while ($loopIndex < $dataLength) { - $loopValue = ($elementJavaType) ($getLoopVar); + $loopValueInCodegen = ($elementJavaType) ($getLoopVar); $loopNullCheck ${genFunction.code} if (${genFunction.isNull}) { - ${addElement("null")} + ${addElement(code"null")} } else { ${addElement(genFunctionValue)} } @@ -1076,34 +1101,40 @@ case class CatalystToExternalMap private( } val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] - val keyElementJavaType = CodeGenerator.javaType(mapType.keyType) - ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false) + val keyElementJavaType = inline"${CodeGenerator.javaType(mapType.keyType)}" + ctx.addMutableState(keyElementJavaType.code, keyLoopValue, forceInline = true, + useFreshName = false) val genKeyFunction = keyLambdaFunction.genCode(ctx) - val valueElementJavaType = CodeGenerator.javaType(mapType.valueType) - ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true, + val valueElementJavaType = inline"${CodeGenerator.javaType(mapType.valueType)}" + ctx.addMutableState(valueElementJavaType.code, valueLoopValue, forceInline = true, useFreshName = false) val genValueFunction = valueLambdaFunction.genCode(ctx) val genInputData = inputData.genCode(ctx) - val dataLength = ctx.freshName("dataLength") - val loopIndex = ctx.freshName("loopIndex") - val tupleLoopValue = ctx.freshName("tupleLoopValue") - val builderValue = ctx.freshName("builderValue") + val dataLength = JavaCode.variable(ctx.freshName("dataLength"), IntegerType) + val loopIndex = JavaCode.variable(ctx.freshName("loopIndex"), IntegerType) + val tupleLoopValue = JavaCode.variable(ctx.freshName("tupleLoopValue"), classOf[(_, _)]) + val builderValue = JavaCode.variable(ctx.freshName("builderValue"), classOf[Builder[_, _]]) - val getLength = s"${genInputData.value}.numElements()" + val keyLoopValueInCodegen = JavaCode.variable(keyLoopValue, mapType.keyType) + val valueLoopValueInCodegen = JavaCode.variable(valueLoopValue, mapType.valueType) + val valueLoopIsNullInCodegen = JavaCode.isNullVariable(valueLoopIsNull) - val keyArray = ctx.freshName("keyArray") - val valueArray = ctx.freshName("valueArray") + val getLength = code"${genInputData.value}.numElements()" + + val keyArray = JavaCode.variable(ctx.freshName("keyArray"), classOf[ArrayData]) + val valueArray = JavaCode.variable(ctx.freshName("valueArray"), classOf[ArrayData]) + val arrayDataType = inline"${classOf[ArrayData].getName}" val getKeyArray = - s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();" + code"$arrayDataType $keyArray = ${genInputData.value}.keyArray();" val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) val getValueArray = - s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();" + code"$arrayDataType $valueArray = ${genInputData.value}.valueArray();" val getValueLoopVar = CodeGenerator.getValue( valueArray, inputDataType(mapType.valueType), loopIndex) // Make a copy of the data if it's unsafe-backed - def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = - s"$value instanceof ${clazz.getSimpleName}? $value.copy() : $value" + def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: ExprValue) = + code"$value instanceof ${inline"${clazz.getSimpleName}"}? $value.copy() : $value" def genFunctionValue(lambdaFunction: Expression, genFunction: ExprCode) = lambdaFunction.dataType match { case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) @@ -1117,19 +1148,20 @@ case class CatalystToExternalMap private( val valueLoopNullCheck = if (valueLoopIsNull != "false") { ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, useFreshName = false) - s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" + code"$valueLoopIsNullInCodegen = $valueArray.isNullAt($loopIndex);" } else { - "" + EmptyBlock } - val builderClass = classOf[Builder[_, _]].getName - val constructBuilder = s""" - $builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder(); + val builderClass = inline"${classOf[Builder[_, _]].getName}" + val collClassName = inline"${collClass.getName}" + val constructBuilder = code""" + $builderClass $builderValue = $collClassName$$.MODULE$$.newBuilder(); $builderValue.sizeHint($dataLength); """ - val tupleClass = classOf[(_, _)].getName - val appendToBuilder = s""" + val tupleClass = inline"${classOf[(_, _)].getName}" + val appendToBuilder = code""" $tupleClass $tupleLoopValue; if (${genValueFunction.isNull}) { @@ -1140,10 +1172,12 @@ case class CatalystToExternalMap private( $builderValue.$$plus$$eq($tupleLoopValue); """ - val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();" + val getBuilderResult = + code"${ev.value} = (${inline"${(collClass.getName)}"}) $builderValue.result();" + val javaType = inline"${CodeGenerator.javaType(dataType)}" val code = genInputData.code + code""" - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { int $dataLength = $getLength; @@ -1153,8 +1187,8 @@ case class CatalystToExternalMap private( int $loopIndex = 0; while ($loopIndex < $dataLength) { - $keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar); - $valueLoopValue = ($valueElementJavaType) ($getValueLoopVar); + $keyLoopValueInCodegen = ($keyElementJavaType) ($getKeyLoopVar); + $valueLoopValueInCodegen = ($valueElementJavaType) ($getValueLoopVar); $valueLoopNullCheck ${genKeyFunction.code} @@ -1320,74 +1354,84 @@ case class ExternalMapToCatalyst private( val inputMap = child.genCode(ctx) val genKeyConverter = keyConverter.genCode(ctx) val genValueConverter = valueConverter.genCode(ctx) - val length = ctx.freshName("length") - val index = ctx.freshName("index") - val convertedKeys = ctx.freshName("convertedKeys") - val convertedValues = ctx.freshName("convertedValues") - val entry = ctx.freshName("entry") - val entries = ctx.freshName("entries") + val length = JavaCode.variable(ctx.freshName("length"), IntegerType) + val index = JavaCode.variable(ctx.freshName("index"), IntegerType) + val convertedKeys = JavaCode.variable(ctx.freshName("convertedKeys"), classOf[Array[Object]]) + val convertedValues = JavaCode.variable(ctx.freshName("convertedValues"), + classOf[Array[Object]]) val keyElementJavaType = CodeGenerator.javaType(keyType) val valueElementJavaType = CodeGenerator.javaType(valueType) ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false) ctx.addMutableState(valueElementJavaType, value, forceInline = true, useFreshName = false) - - val (defineEntries, defineKeyValue) = child.dataType match { + val keyBoxedType = inline"${CodeGenerator.boxedType(keyType)}" + val valueBoxedType = inline"${CodeGenerator.boxedType(valueType)}" + val keyInCodegen = JavaCode.variable(key, ScalaReflection.javaBoxedType(keyType)) + val valueInCodegen = JavaCode.variable(value, ScalaReflection.javaBoxedType(valueType)) + val keyIsNullVariable = JavaCode.isNullVariable(keyIsNull) + val valueIsNullVariable = JavaCode.isNullVariable(valueIsNull) + + val (entries, defineEntries, defineKeyValue) = child.dataType match { case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => - val javaIteratorCls = classOf[java.util.Iterator[_]].getName - val javaMapEntryCls = classOf[java.util.Map.Entry[_, _]].getName + val entries = JavaCode.variable(ctx.freshName("entries"), classOf[java.util.Iterator[_]]) + val entry = JavaCode.variable(ctx.freshName("entry"), classOf[java.util.Map.Entry[_, _]]) + val javaIteratorCls = inline"${classOf[java.util.Iterator[_]].getName}" + val javaMapEntryCls = inline"${classOf[java.util.Map.Entry[_, _]].getName}" val defineEntries = - s"final $javaIteratorCls $entries = ${inputMap.value}.entrySet().iterator();" + code"final $javaIteratorCls $entries = ${inputMap.value}.entrySet().iterator();" val defineKeyValue = - s""" + code""" final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); - $key = (${CodeGenerator.boxedType(keyType)}) $entry.getKey(); - $value = (${CodeGenerator.boxedType(valueType)}) $entry.getValue(); + $keyInCodegen = ($keyBoxedType) $entry.getKey(); + $valueInCodegen = ($valueBoxedType) $entry.getValue(); """ - defineEntries -> defineKeyValue + (entries, defineEntries, defineKeyValue) case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => - val scalaIteratorCls = classOf[Iterator[_]].getName - val scalaMapEntryCls = classOf[Tuple2[_, _]].getName + val entries = JavaCode.variable(ctx.freshName("entries"), classOf[Iterator[_]]) + val entry = JavaCode.variable(ctx.freshName("entry"), classOf[Tuple2[_, _]]) + val scalaIteratorCls = inline"${classOf[Iterator[_]].getName}" + val scalaMapEntryCls = inline"${classOf[Tuple2[_, _]].getName}" - val defineEntries = s"final $scalaIteratorCls $entries = ${inputMap.value}.iterator();" + val defineEntries = code"final $scalaIteratorCls $entries = ${inputMap.value}.iterator();" val defineKeyValue = - s""" + code""" final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); - $key = (${CodeGenerator.boxedType(keyType)}) $entry._1(); - $value = (${CodeGenerator.boxedType(valueType)}) $entry._2(); + $keyInCodegen = ($keyBoxedType) $entry._1(); + $valueInCodegen = ($valueBoxedType) $entry._2(); """ - defineEntries -> defineKeyValue + (entries, defineEntries, defineKeyValue) } val keyNullCheck = if (keyIsNull != "false") { ctx.addMutableState( CodeGenerator.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) - s"$keyIsNull = $key == null;" + code"$keyIsNullVariable = $keyInCodegen == null;" } else { - "" + EmptyBlock } val valueNullCheck = if (valueIsNull != "false") { ctx.addMutableState( CodeGenerator.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) - s"$valueIsNull = $value == null;" + code"$valueIsNullVariable = $valueInCodegen == null;" } else { - "" + EmptyBlock } - val arrayCls = classOf[GenericArrayData].getName - val mapCls = classOf[ArrayBasedMapData].getName - val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) - val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) + val arrayCls = inline"${classOf[GenericArrayData].getName}" + val mapCls = inline"${classOf[ArrayBasedMapData].getName}" + val convertedKeyType = inline"${CodeGenerator.boxedType(keyConverter.dataType)}" + val convertedValueType = inline"${CodeGenerator.boxedType(valueConverter.dataType)}" + val javaType = inline"${CodeGenerator.javaType(dataType)}" val code = inputMap.code + code""" - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); final Object[] $convertedKeys = new Object[$length]; @@ -1442,12 +1486,12 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val rowClass = classOf[GenericRowWithSchema].getName - val values = ctx.freshName("values") + val rowWithSchemaClass = inline"${classOf[GenericRowWithSchema].getName}" + val values = JavaCode.variable(ctx.freshName("values"), classOf[Array[Object]]) val childrenCodes = children.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) - s""" + code""" |${eval.code} |if (${eval.isNull}) { | $values[$i] = null; @@ -1460,14 +1504,14 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) val childrenCode = ctx.splitExpressionsWithCurrentInputs( expressions = childrenCodes, funcName = "createExternalRow", - extraArguments = "Object[]" -> values :: Nil) - val schemaField = ctx.addReferenceObj("schema", schema) - + extraArguments = values :: Nil) + val schemaField = JavaCode.global(ctx.addReferenceObj("schema", schema), classOf[StructType]) + val rowClass = inline"${classOf[Row].getName}" val code = code""" |Object[] $values = new Object[${children.size}]; |$childrenCode - |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); + |final $rowClass ${ev.value} = new $rowWithSchemaClass($values, $schemaField); """.stripMargin ev.copy(code = code, isNull = FalseLiteral) } @@ -1489,8 +1533,8 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) val serializer = addImmutableSerializerIfNeeded(ctx) // Code to serialize. val input = child.genCode(ctx) - val javaType = CodeGenerator.javaType(dataType) - val serialize = s"$serializer.serialize(${input.value}, null).array()" + val javaType = inline"${CodeGenerator.javaType(dataType)}" + val serialize = code"$serializer.serialize(${input.value}, null).array()" val code = input.code + code""" final $javaType ${ev.value} = @@ -1520,9 +1564,9 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B val serializer = addImmutableSerializerIfNeeded(ctx) // Code to deserialize. val input = child.genCode(ctx) - val javaType = CodeGenerator.javaType(dataType) + val javaType = inline"${CodeGenerator.javaType(dataType)}" val deserialize = - s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" + code"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" val code = input.code + code""" final $javaType ${ev.value} = @@ -1587,23 +1631,23 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) - val javaBeanInstance = ctx.freshName("javaBean") - val beanInstanceJavaType = CodeGenerator.javaType(beanInstance.dataType) + val javaBeanInstance = JavaCode.variable(ctx.freshName("javaBean"), beanInstance.dataType) + val beanInstanceJavaType = inline"${CodeGenerator.javaType(beanInstance.dataType)}" val initialize = setters.map { case (setterMethod, fieldValue) => val fieldGen = fieldValue.genCode(ctx) - s""" + code""" |${fieldGen.code} |if (!${fieldGen.isNull}) { - | $javaBeanInstance.$setterMethod(${fieldGen.value}); + | $javaBeanInstance.${inline"$setterMethod"}(${fieldGen.value}); |} """.stripMargin } val initializeCode = ctx.splitExpressionsWithCurrentInputs( expressions = initialize.toSeq, funcName = "initializeJavaBean", - extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil) + extraArguments = javaBeanInstance :: Nil) val code = instanceGen.code + code""" @@ -1652,7 +1696,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null. - val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + val errMsgField = JavaCode.global(ctx.addReferenceObj("errMsg", errMsg), classOf[String]) val code = childGen.code + code""" if (${childGen.isNull}) { @@ -1695,7 +1739,7 @@ case class GetExternalRowField( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the field is null. - val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + val errMsgField = JavaCode.global(ctx.addReferenceObj("errMsg", errMsg), classOf[String]) val row = child.genCode(ctx) val code = code""" ${row.code} @@ -1758,26 +1802,34 @@ case class ValidateExternalType(child: Expression, expected: DataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the type doesn't match. - val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + val errMsgField = JavaCode.global(ctx.addReferenceObj("errMsg", errMsg), classOf[String]) val input = child.genCode(ctx) val obj = input.value + val boxedType = inline"${CodeGenerator.boxedType(dataType)}" + val typeCheck = expected match { case _: DecimalType => Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) - .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") + .map(cls => code"$obj instanceof ${inline"${cls.getName}"}") + .reduceLeft { (blocks, block) => + code"$blocks || $block" + } case _: ArrayType => - s"$obj.getClass().isArray() || $obj instanceof ${classOf[Seq[_]].getName}" + val seqClassName = inline"${classOf[Seq[_]].getName}" + code"$obj.getClass().isArray() || $obj instanceof $seqClassName" case _ => - s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" + code"$obj instanceof $boxedType" } + val javaType = inline"${CodeGenerator.javaType(dataType)}" + val code = code""" ${input.code} - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { if ($typeCheck) { - ${ev.value} = (${CodeGenerator.boxedType(dataType)}) $obj; + ${ev.value} = ($boxedType) $obj; } else { throw new RuntimeException($obj.getClass().getName() + $errMsgField); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f54103c4fbfb..38684e23e40c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{Blocks, CodegenContext, CodeGenerator, EmptyBlock, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, JavaCode, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils @@ -132,7 +132,7 @@ case class Not(child: Expression) protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean] override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"!($c)") + defineCodeGen(ctx, ev, c => code"!($c)") } override def sql: String = s"(NOT ${child.sql})" @@ -244,7 +244,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaDataType = CodeGenerator.javaType(value.dataType) + val javaDataType = inline"${CodeGenerator.javaType(value.dataType)}" val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) // inTmpResult has 3 possible values: @@ -254,12 +254,12 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val NOT_MATCHED = 0 // 1 means one value in the list is matched val MATCHED = 1 - val tmpResult = ctx.freshName("inTmpResult") - val valueArg = ctx.freshName("valueArg") + val tmpResult = JavaCode.variable(ctx.freshName("inTmpResult"), ByteType) + val valueArg = JavaCode.variable(ctx.freshName("valueArg"), value.dataType) // All the blocks are meant to be inside a do { ... } while (false); loop. // The evaluation of variables can be stopped when we find a matching value. val listCode = listGen.map(x => - s""" + code""" |${x.code} |if (${x.isNull}) { | $tmpResult = $HAS_NULL; // ${ev.isNull} = true; @@ -272,7 +272,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = listCode, funcName = "valueIn", - extraArguments = (javaDataType, valueArg) :: (CodeGenerator.JAVA_BYTE, tmpResult) :: Nil, + extraArguments = (valueArg) :: (tmpResult) :: Nil, returnType = CodeGenerator.JAVA_BYTE, makeSplitFunction = body => s""" @@ -281,14 +281,16 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { |} while (false); |return $tmpResult; """.stripMargin, - foldFunctions = _.map { funcCall => - s""" - |$tmpResult = $funcCall; - |if ($tmpResult == $MATCHED) { - | continue; - |} - """.stripMargin - }.mkString("\n")) + foldFunctions = funcCalls => { + Blocks(funcCalls.map { funcCall => + code""" + |$tmpResult = $funcCall; + |if ($tmpResult == $MATCHED) { + | continue; + |} + """.stripMargin + }) + }) ev.copy(code = code""" @@ -347,18 +349,18 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val setTerm = ctx.addReferenceObj("set", set) + val setTerm = JavaCode.global(ctx.addReferenceObj("set", set), classOf[Set[Any]]) val childGen = child.genCode(ctx) val setIsNull = if (hasNull) { - s"${ev.isNull} = !${ev.value};" + code"${ev.isNull} = !${ev.value};" } else { - "" + EmptyBlock } ev.copy(code = code""" |${childGen.code} - |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; - |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; + |boolean ${ev.isNull} = ${childGen.isNull}; + |${inline"${CodeGenerator.javaType(dataType)}"} ${ev.value} = false; |if (!${ev.isNull}) { | ${ev.value} = $setTerm.contains(${childGen.value}); | $setIsNull @@ -379,7 +381,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with override def inputType: AbstractDataType = BooleanType - override def symbol: String = "&&" + override def symbol: JavaCode = inline"&&" override def sqlOperator: String = "AND" @@ -442,7 +444,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P override def inputType: AbstractDataType = BooleanType - override def symbol: String = "||" + override def symbol: JavaCode = inline"||" override def sqlOperator: String = "OR" @@ -519,9 +521,9 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { && left.dataType != FloatType && left.dataType != DoubleType) { // faster version - defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2") + defineCodeGen(ctx, ev, (c1, c2) => code"$c1 $symbol $c2") } else { - defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0") + defineCodeGen(ctx, ev, (c1, c2) => code"${ctx.genComp(left.dataType, c1, c2)} $symbol 0") } } @@ -567,7 +569,7 @@ object Equality { case class EqualTo(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { - override def symbol: String = "=" + override def symbol: JavaCode = inline"=" protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right) @@ -602,7 +604,7 @@ case class EqualTo(left: Expression, right: Expression) """) case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<=>" + override def symbol: JavaCode = inline"<=>" override def nullable: Boolean = false @@ -653,7 +655,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp case class LessThan(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { - override def symbol: String = "<" + override def symbol: JavaCode = inline"<" protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } @@ -683,7 +685,7 @@ case class LessThan(left: Expression, right: Expression) case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { - override def symbol: String = "<=" + override def symbol: JavaCode = inline"<=" protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } @@ -713,7 +715,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { - override def symbol: String = ">" + override def symbol: JavaCode = inline">" protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } @@ -743,7 +745,7 @@ case class GreaterThan(left: Expression, right: Expression) case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { - override def symbol: String = ">=" + override def symbol: JavaCode = inline">=" protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 926c2f00d430..5e3230f9635a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -80,11 +80,12 @@ case class Rand(child: Expression) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = classOf[XORShiftRandom].getName - val rngTerm = ctx.addMutableState(className, "rng") + val rngTerm = JavaCode.global(ctx.addMutableState(className, "rng"), classOf[XORShiftRandom]) ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") + val javaType = inline"${CodeGenerator.javaType(dataType)}" ev.copy(code = code""" - final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", + final $javaType ${ev.value} = $rngTerm.nextDouble();""", isNull = FalseLiteral) } @@ -118,11 +119,12 @@ case class Randn(child: Expression) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = classOf[XORShiftRandom].getName - val rngTerm = ctx.addMutableState(className, "rng") + val rngTerm = JavaCode.global(ctx.addMutableState(className, "rng"), classOf[XORShiftRandom]) ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") + val javaType = inline"${CodeGenerator.javaType(dataType)}" ev.copy(code = code""" - final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", + final $javaType ${ev.value} = $rngTerm.nextGaussian();""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 7b68bb771faf..f43e42474653 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -111,23 +111,25 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi override def toString: String = s"$left LIKE $right" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val patternClass = classOf[Pattern].getName - val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" + val patternClass = inline"${classOf[Pattern].getName}" + val escapeFunc = + inline"${StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex"}" + val javaType = inline"${CodeGenerator.javaType(dataType)}" if (right.foldable) { val rVal = right.eval() if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) - val pattern = ctx.addMutableState(patternClass, "patternLike", - v => s"""$v = $patternClass.compile("$regexStr");""") + val pattern = JavaCode.global(ctx.addMutableState(patternClass.code, "patternLike", + v => s"""$v = $patternClass.compile("$regexStr");"""), classOf[Pattern]) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); } @@ -135,14 +137,14 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } else { ev.copy(code = code""" boolean ${ev.isNull} = true; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) } } else { - val pattern = ctx.freshName("pattern") - val rightStr = ctx.freshName("rightStr") + val pattern = JavaCode.variable(ctx.freshName("pattern"), classOf[Pattern]) + val rightStr = JavaCode.variable(ctx.freshName("rightStr"), classOf[String]) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s""" + code""" String $rightStr = $eval2.toString(); $patternClass $pattern = $patternClass.compile($escapeFunc($rightStr)); ${ev.value} = $pattern.matcher($eval1.toString()).matches(); @@ -187,22 +189,23 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress override def toString: String = s"$left RLIKE $right" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val patternClass = classOf[Pattern].getName + val patternClass = inline"${classOf[Pattern].getName}" + val javaType = inline"${CodeGenerator.javaType(dataType)}" if (right.foldable) { val rVal = right.eval() if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) - val pattern = ctx.addMutableState(patternClass, "patternRLike", - v => s"""$v = $patternClass.compile("$regexStr");""") + val pattern = JavaCode.global(ctx.addMutableState(patternClass.code, "patternRLike", + v => s"""$v = $patternClass.compile("$regexStr");"""), classOf[Pattern]) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0); } @@ -210,14 +213,14 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } else { ev.copy(code = code""" boolean ${ev.isNull} = true; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) } } else { - val rightStr = ctx.freshName("rightStr") - val pattern = ctx.freshName("pattern") + val rightStr = JavaCode.variable(ctx.freshName("rightStr"), classOf[String]) + val pattern = JavaCode.variable(ctx.freshName("pattern"), classOf[Pattern]) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s""" + code""" String $rightStr = $eval2.toString(); $patternClass $pattern = $patternClass.compile($rightStr); ${ev.value} = $pattern.matcher($eval1.toString()).find(0); @@ -252,10 +255,10 @@ case class StringSplit(str: Expression, pattern: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val arrayClass = classOf[GenericArrayData].getName + val arrayClass = inline"${classOf[GenericArrayData].getName}" nullSafeCodeGen(ctx, ev, (str, pattern) => // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. - s"""${ev.value} = new $arrayClass($str.split($pattern, -1));""") + code"""${ev.value} = new $arrayClass($str.split($pattern, -1));""") } override def prettyName: String = "split" @@ -317,26 +320,33 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def prettyName: String = "regexp_replace" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val termResult = ctx.freshName("termResult") + val classNameStringBuffer = inline"${classOf[java.lang.StringBuffer].getCanonicalName}" + val classNameMatcher = inline"${classOf[java.util.regex.Matcher].getCanonicalName}" - val classNamePattern = classOf[Pattern].getCanonicalName - val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName + val termResult = JavaCode.variable(ctx.freshName("termResult"), + classOf[java.lang.StringBuffer]) - val matcher = ctx.freshName("matcher") + val classNamePattern = inline"${classOf[Pattern].getCanonicalName}" - val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") - val termPattern = ctx.addMutableState(classNamePattern, "pattern") - val termLastReplacement = ctx.addMutableState("String", "lastReplacement") - val termLastReplacementInUTF8 = ctx.addMutableState("UTF8String", "lastReplacementInUTF8") + val matcher = JavaCode.variable(ctx.freshName("matcher"), classOf[java.util.regex.Matcher]) + + val termLastRegex = JavaCode.global( + ctx.addMutableState("UTF8String", "lastRegex"), classOf[UTF8String]) + val termPattern = JavaCode.global( + ctx.addMutableState(classNamePattern.code, "pattern"), classOf[Pattern]) + val termLastReplacement = JavaCode.global( + ctx.addMutableState("String", "lastReplacement"), classOf[String]) + val termLastReplacementInUTF8 = JavaCode.global( + ctx.addMutableState("UTF8String", "lastReplacementInUTF8"), classOf[UTF8String]) val setEvNotNull = if (nullable) { - s"${ev.isNull} = false;" + code"${ev.isNull} = false;" } else { - "" + EmptyBlock } nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { - s""" + code""" if (!$regexp.equals($termLastRegex)) { // regex value changed $termLastRegex = $regexp.clone(); @@ -348,7 +358,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio $termLastReplacement = $termLastReplacementInUTF8.toString(); } $classNameStringBuffer $termResult = new $classNameStringBuffer(); - java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString()); + $classNameMatcher $matcher = $termPattern.matcher($subject.toString()); while ($matcher.find()) { $matcher.appendReplacement($termResult, $termLastReplacement); @@ -409,21 +419,24 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override def prettyName: String = "regexp_extract" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val classNamePattern = classOf[Pattern].getCanonicalName - val matcher = ctx.freshName("matcher") - val matchResult = ctx.freshName("matchResult") + val classNamePattern = inline"${classOf[Pattern].getCanonicalName}" + val matcher = JavaCode.variable(ctx.freshName("matcher"), classOf[java.util.regex.Matcher]) + val matchResult = JavaCode.variable(ctx.freshName("matchResult"), + classOf[java.util.regex.MatchResult]) - val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") - val termPattern = ctx.addMutableState(classNamePattern, "pattern") + val termLastRegex = JavaCode.global(ctx.addMutableState("UTF8String", "lastRegex"), + classOf[UTF8String]) + val termPattern = JavaCode.global(ctx.addMutableState(classNamePattern.code, "pattern"), + classOf[Pattern]) val setEvNotNull = if (nullable) { - s"${ev.isNull} = false;" + code"${ev.isNull} = false;" } else { - "" + EmptyBlock } nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { - s""" + code""" if (!$regexp.equals($termLastRegex)) { // regex value changed $termLastRegex = $regexp.clone(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 9823b2fc5ad9..e9784f5278f0 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -88,24 +88,24 @@ case class ConcatWs(children: Seq[Expression]) val separator = evals.head val strings = evals.tail val numArgs = strings.length - val args = ctx.freshName("args") + val args = JavaCode.variable(ctx.freshName("args"), classOf[Array[UTF8String]]) val inputs = strings.zipWithIndex.map { case (eval, index) => if (eval.isNull != "true") { - s""" + code""" ${eval.code} if (!${eval.isNull}) { $args[$index] = ${eval.value}; } """ } else { - "" + EmptyBlock } } val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcatWs", - extraArguments = ("UTF8String[]", args) :: Nil) + extraArguments = (args) :: Nil) ev.copy(code""" UTF8String[] $args = new UTF8String[$numArgs]; ${separator.code} @@ -114,31 +114,31 @@ case class ConcatWs(children: Seq[Expression]) boolean ${ev.isNull} = ${ev.value} == null; """) } else { - val array = ctx.freshName("array") - val varargNum = ctx.freshName("varargNum") - val idxVararg = ctx.freshName("idxInVararg") + val array = JavaCode.variable(ctx.freshName("array"), classOf[Array[UTF8String]]) + val varargNum = JavaCode.variable(ctx.freshName("varargNum"), IntegerType) + val idxVararg = JavaCode.variable(ctx.freshName("idxInVararg"), IntegerType) val evals = children.map(_.genCode(ctx)) val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) => child.dataType match { case StringType => - ("", // we count all the StringType arguments num at once below. + (EmptyBlock, // we count all the StringType arguments num at once below. if (eval.isNull == "true") { - "" + EmptyBlock } else { - s"$array[$idxVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" + code"$array[$idxVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" }) case _: ArrayType => - val size = ctx.freshName("n") + val size = JavaCode.variable(ctx.freshName("n"), IntegerType) if (eval.isNull == "true") { - ("", "") + (EmptyBlock, EmptyBlock) } else { - (s""" + (code""" if (!${eval.isNull}) { $varargNum += ${eval.value}.numElements(); } """, - s""" + code""" if (!${eval.isNull}) { final int $size = ${eval.value}.numElements(); for (int j = 0; j < $size; j ++) { @@ -150,7 +150,7 @@ case class ConcatWs(children: Seq[Expression]) } }.unzip - val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code.toString)) + val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code)) val varargCounts = ctx.splitExpressionsWithCurrentInputs( expressions = varargCount, @@ -162,19 +162,21 @@ case class ConcatWs(children: Seq[Expression]) |$body |return $varargNum; """.stripMargin, - foldFunctions = _.map(funcCall => s"$varargNum += $funcCall;").mkString("\n")) + foldFunctions = funcCalls => + Blocks(funcCalls.map(funcCall => code"$varargNum += $funcCall;"))) val varargBuilds = ctx.splitExpressionsWithCurrentInputs( expressions = varargBuild, funcName = "varargBuildsConcatWs", - extraArguments = ("UTF8String []", array) :: ("int", idxVararg) :: Nil, + extraArguments = (array) :: (idxVararg) :: Nil, returnType = "int", makeSplitFunction = body => s""" |$body |return $idxVararg; """.stripMargin, - foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) + foldFunctions = funcCalls => + Blocks(funcCalls.map(funcCall => code"$idxVararg = $funcCall;"))) ev.copy( code""" @@ -250,13 +252,14 @@ case class Elt(children: Seq[Expression]) extends Expression { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val index = indexExpr.genCode(ctx) val inputs = inputExprs.map(_.genCode(ctx)) - val indexVal = ctx.freshName("index") - val indexMatched = ctx.freshName("eltIndexMatched") + val indexVal = JavaCode.variable(ctx.freshName("index"), IntegerType) + val indexMatched = JavaCode.variable(ctx.freshName("eltIndexMatched"), BooleanType) - val inputVal = ctx.addMutableState(CodeGenerator.javaType(dataType), "inputVal") + val inputVal = JavaCode.global( + ctx.addMutableState(CodeGenerator.javaType(dataType), "inputVal"), dataType) val assignInputValue = inputs.zipWithIndex.map { case (eval, index) => - s""" + code""" |if ($indexVal == ${index + 1}) { | ${eval.code} | $inputVal = ${eval.isNull} ? null : ${eval.value}; @@ -269,35 +272,38 @@ case class Elt(children: Seq[Expression]) extends Expression { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = assignInputValue, funcName = "eltFunc", - extraArguments = ("int", indexVal) :: Nil, + extraArguments = (indexVal) :: Nil, returnType = CodeGenerator.JAVA_BOOLEAN, makeSplitFunction = body => s""" - |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; + |${indexMatched.javaType} $indexMatched = false; |do { | $body |} while (false); |return $indexMatched; """.stripMargin, - foldFunctions = _.map { funcCall => - s""" - |$indexMatched = $funcCall; - |if ($indexMatched) { - | continue; - |} - """.stripMargin - }.mkString) + foldFunctions = funcCalls => + Blocks(funcCalls.map { funcCall => + code""" + |$indexMatched = $funcCall; + |if ($indexMatched) { + | continue; + |} + """ + }) + ) + val javaType = inline"${CodeGenerator.javaType(dataType)}" ev.copy( code""" |${index.code} |final int $indexVal = ${index.value}; - |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; + |boolean $indexMatched = false; |$inputVal = null; |do { | $codes |} while (false); - |final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal; + |final $javaType ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } @@ -332,7 +338,7 @@ case class Upper(child: Expression) override def convert(v: UTF8String): UTF8String = v.toUpperCase override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") + defineCodeGen(ctx, ev, c => code"($c).toUpperCase()") } } @@ -351,7 +357,7 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx override def convert(v: UTF8String): UTF8String = v.toLowerCase override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") + defineCodeGen(ctx, ev, c => code"($c).toLowerCase()") } } @@ -375,7 +381,7 @@ abstract class StringPredicate extends BinaryExpression case class Contains(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") + defineCodeGen(ctx, ev, (c1, c2) => code"($c1).contains($c2)") } } @@ -385,7 +391,7 @@ case class Contains(left: Expression, right: Expression) extends StringPredicate case class StartsWith(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") + defineCodeGen(ctx, ev, (c1, c2) => code"($c1).startsWith($c2)") } } @@ -395,7 +401,7 @@ case class StartsWith(left: Expression, right: Expression) extends StringPredica case class EndsWith(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") + defineCodeGen(ctx, ev, (c1, c2) => code"($c1).endsWith($c2)") } } @@ -432,7 +438,7 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (src, search, replace) => { - s"""${ev.value} = $src.replace($search, $replace);""" + code"""${ev.value} = $src.replace($search, $replace);""" }) } @@ -495,17 +501,20 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val classNameDict = classOf[JMap[Character, Character]].getCanonicalName - val termLastMatching = ctx.addMutableState("UTF8String", "lastMatching") - val termLastReplace = ctx.addMutableState("UTF8String", "lastReplace") - val termDict = ctx.addMutableState(classNameDict, "dict") + val termLastMatching = JavaCode.global(ctx.addMutableState("UTF8String", "lastMatching"), + classOf[UTF8String]) + val termLastReplace = JavaCode.global(ctx.addMutableState("UTF8String", "lastReplace"), + classOf[UTF8String]) + val termDict = JavaCode.variable(ctx.addMutableState(classNameDict, "dict"), + classOf[JMap[Character, Character]]) nullSafeCodeGen(ctx, ev, (src, matching, replace) => { val check = if (matchingExpr.foldable && replaceExpr.foldable) { - s"$termDict == null" + code"$termDict == null" } else { - s"!$matching.equals($termLastMatching) || !$replace.equals($termLastReplace)" + code"!$matching.equals($termLastMatching) || !$replace.equals($termLastReplace)" } - s"""if ($check) { + code"""if ($check) { // Not all of them is literal or matching or replace value changed $termLastMatching = $matching.clone(); $termLastReplace = $replace.clone(); @@ -550,7 +559,7 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (word, set) => - s"${ev.value} = $set.findInSet($word);" + code"${ev.value} = $set.findInSet($word);" ) } @@ -666,7 +675,7 @@ case class StringTrim( } else { val trimString = evals(1) val getTrimFunction = - s""" + code""" if (${trimString.isNull}) { ${ev.isNull} = true; } else { @@ -766,7 +775,7 @@ case class StringTrimLeft( } else { val trimString = evals(1) val getTrimLeftFunction = - s""" + code""" if (${trimString.isNull}) { ${ev.isNull} = true; } else { @@ -868,7 +877,7 @@ case class StringTrimRight( } else { val trimString = evals(1) val getTrimRightFunction = - s""" + code""" if (${trimString.isNull}) { ${ev.isNull} = true; } else { @@ -918,7 +927,7 @@ case class StringInstr(str: Expression, substr: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (l, r) => - s"($l).indexOf($r, 0) + 1") + code"($l).indexOf($r, 0) + 1") } } @@ -958,7 +967,7 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)") + defineCodeGen(ctx, ev, (str, delim, count) => code"$str.subStringIndex($delim, $count)") } } @@ -1078,7 +1087,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)") + defineCodeGen(ctx, ev, (str, len, pad) => code"$str.lpad($len, $pad)") } override def prettyName: String = "lpad" @@ -1111,7 +1120,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)") + defineCodeGen(ctx, ev, (str, len, pad) => code"$str.rpad($len, $pad)") } override def prettyName: String = "rpad" @@ -1326,18 +1335,19 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val pattern = children.head.genCode(ctx) val argListGen = children.tail.map(x => (x.dataType, x.genCode(ctx))) - val argList = ctx.freshName("argLists") + val argList = JavaCode.variable(ctx.freshName("argLists"), classOf[Array[Object]]) val numArgLists = argListGen.length val argListCode = argListGen.zipWithIndex.map { case(v, index) => + val boxedTypeName = inline"${CodeGenerator.boxedType(v._1)}" val value = if (CodeGenerator.boxedType(v._1) != CodeGenerator.javaType(v._1)) { // Java primitives get boxed in order to allow null values. - s"(${v._2.isNull}) ? (${CodeGenerator.boxedType(v._1)}) null : " + - s"new ${CodeGenerator.boxedType(v._1)}(${v._2.value})" + code"(${v._2.isNull}) ? ($boxedTypeName) null : " + + code"new $boxedTypeName(${v._2.value})" } else { - s"(${v._2.isNull}) ? null : ${v._2.value}" + code"(${v._2.isNull}) ? null : ${v._2.value}" } - s""" + code""" ${v._2.code} $argList[$index] = $value; """ @@ -1345,19 +1355,21 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val argListCodes = ctx.splitExpressionsWithCurrentInputs( expressions = argListCode, funcName = "valueFormatString", - extraArguments = ("Object[]", argList) :: Nil) - - val form = ctx.freshName("formatter") - val formatter = classOf[java.util.Formatter].getName - val sb = ctx.freshName("sb") - val stringBuffer = classOf[StringBuffer].getName + extraArguments = (argList) :: Nil) + + val form = JavaCode.variable(ctx.freshName("formatter"), classOf[java.util.Formatter]) + val formatter = inline"${classOf[java.util.Formatter].getName}" + val sb = JavaCode.variable(ctx.freshName("sb"), classOf[StringBuffer]) + val stringBuffer = inline"${classOf[StringBuffer].getName}" + val localClass = inline"${classOf[Locale].getName}" + val javaType = inline"${CodeGenerator.javaType(dataType)}" ev.copy(code = code""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { $stringBuffer $sb = new $stringBuffer(); - $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); + $formatter $form = new $formatter($sb, $localClass.US); Object[] $argList = new Object[$numArgLists]; $argListCodes $form.format(${pattern.value}.toString(), $argList); @@ -1391,7 +1403,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI string.asInstanceOf[UTF8String].toLowerCase.toTitleCase } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()") + defineCodeGen(ctx, ev, str => code"$str.toLowerCase().toTitleCase()") } } @@ -1420,7 +1432,7 @@ case class StringRepeat(str: Expression, times: Expression) override def prettyName: String = "repeat" override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)") + defineCodeGen(ctx, ev, (l, r) => code"($l).repeat($r)") } } @@ -1447,7 +1459,7 @@ case class StringSpace(child: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (length) => - s"""${ev.value} = UTF8String.blankString(($length < 0) ? 0 : $length);""") + code"""${ev.value} = UTF8String.blankString(($length < 0) ? 0 : $length);""") } override def prettyName: String = "space" @@ -1496,11 +1508,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - + val byteArrayType = inline"${classOf[ByteArray].getName}" defineCodeGen(ctx, ev, (string, pos, len) => { str.dataType match { - case StringType => s"$string.substringSQL($pos, $len)" - case BinaryType => s"${classOf[ByteArray].getName}.subStringSQL($string, $pos, $len)" + case StringType => code"$string.substringSQL($pos, $len)" + case BinaryType => code"$byteArrayType.subStringSQL($string, $pos, $len)" } }) } @@ -1577,8 +1589,8 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastIn override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { - case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()") - case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") + case StringType => defineCodeGen(ctx, ev, c => code"($c).numChars()") + case BinaryType => defineCodeGen(ctx, ev, c => code"($c).length") } } } @@ -1604,8 +1616,8 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { - case StringType => defineCodeGen(ctx, ev, c => s"($c).numBytes() * 8") - case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length * 8") + case StringType => defineCodeGen(ctx, ev, c => code"($c).numBytes() * 8") + case BinaryType => defineCodeGen(ctx, ev, c => code"($c).length * 8") } } @@ -1634,8 +1646,8 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitC override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { - case StringType => defineCodeGen(ctx, ev, c => s"($c).numBytes()") - case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") + case StringType => defineCodeGen(ctx, ev, c => code"($c).numBytes()") + case BinaryType => defineCodeGen(ctx, ev, c => code"($c).length") } } @@ -1663,7 +1675,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (left, right) => - s"${ev.value} = $left.levenshteinDistance($right);") + code"${ev.value} = $left.levenshteinDistance($right);") } } @@ -1686,7 +1698,7 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT override def nullSafeEval(input: Any): Any = input.asInstanceOf[UTF8String].soundex() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"$c.soundex()") + defineCodeGen(ctx, ev, c => code"$c.soundex()") } } @@ -1718,8 +1730,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (child) => { - val bytes = ctx.freshName("bytes") - s""" + val bytes = JavaCode.variable(ctx.freshName("bytes"), ByteType) + code""" byte[] $bytes = $child.getBytes(); if ($bytes.length > 0) { ${ev.value} = (int) $bytes[0]; @@ -1761,7 +1773,7 @@ case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInput override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, lon => { - s""" + code""" if ($lon < 0) { ${ev.value} = UTF8String.EMPTY_UTF8; } else if (($lon & 0xFF) == 0) { @@ -1798,7 +1810,7 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (child) => { - s"""${ev.value} = UTF8String.fromBytes( + code"""${ev.value} = UTF8String.fromBytes( org.apache.commons.codec.binary.Base64.encodeBase64($child)); """}) } @@ -1824,7 +1836,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (child) => { - s""" + code""" ${ev.value} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString()); """}) } @@ -1859,7 +1871,7 @@ case class Decode(bin: Expression, charset: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (bytes, charset) => - s""" + code""" try { ${ev.value} = UTF8String.fromString(new String($bytes, $charset.toString())); } catch (java.io.UnsupportedEncodingException e) { @@ -1898,7 +1910,7 @@ case class Encode(value: Expression, charset: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (string, charset) => - s""" + code""" try { ${ev.value} = $string.toString().getBytes($charset.toString()); } catch (java.io.UnsupportedEncodingException e) { @@ -1994,10 +2006,10 @@ case class FormatNumber(x: Expression, d: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (num, d) => { - def typeHelper(p: String): String = { + def typeHelper(p: ExprValue): Block = { x.dataType match { - case _ : DecimalType => s"""$p.toJavaBigDecimal()""" - case _ => s"$p" + case _ : DecimalType => code"""$p.toJavaBigDecimal()""" + case _ => code"$p" } } @@ -2008,15 +2020,15 @@ case class FormatNumber(x: Expression, d: Expression) // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') // as a decimal separator. val usLocale = "US" - val i = ctx.freshName("i") - val dFormat = ctx.freshName("dFormat") - val lastDValue = - ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") - val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") - val numberFormat = ctx.addMutableState(df, "numberFormat", - v => s"""$v = new $df("", new $dfs($l.$usLocale));""") - - s""" + val i = JavaCode.variable(ctx.freshName("i"), IntegerType) + val lastDValue = JavaCode.global( + ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;"), IntegerType) + val pattern = JavaCode.global(ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();"), + IntegerType) + val numberFormat = JavaCode.global(ctx.addMutableState(df, "numberFormat", + v => s"""$v = new $df("", new $dfs($l.$usLocale));"""), classOf[DecimalFormat]) + + code""" if ($d >= 0) { $pattern.delete(0, $pattern.length()); if ($d != $lastDValue) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 0acd3b490447..5aa5c29f1eb1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -22,6 +22,8 @@ import java.sql.Timestamp import org.apache.spark.sql.catalyst.analysis.TypeCoercion._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.expressions.codegen.JavaCode import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} @@ -1374,13 +1376,13 @@ object TypeCoercionSuite { extends BinaryOperator with Unevaluable { override def dataType: DataType = NullType override def inputType: AbstractDataType = AnyDataType - override def symbol: String = "anytype" + override def symbol: JavaCode = inline"anytype" } case class NumericTypeBinaryOperator(left: Expression, right: Expression) extends BinaryOperator with Unevaluable { override def dataType: DataType = NullType override def inputType: AbstractDataType = NumericType - override def symbol: String = "numerictype" + override def symbol: JavaCode = inline"numerictype" } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index d2c6420eadb2..4356c1691071 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -23,10 +23,10 @@ import org.apache.spark.sql.types.{BooleanType, IntegerType} class CodeBlockSuite extends SparkFunSuite { - test("Block interpolates string and ExprValue inputs") { + test("Block interpolates ExprValue inputs") { val isNull = JavaCode.isNullVariable("expr1_isNull") - val stringLiteral = "false" - val code = code"boolean $isNull = $stringLiteral;" + val booleanLiteral = JavaCode.literal("false", BooleanType) + val code = code"boolean $isNull = $booleanLiteral;" assert(code.toString == "boolean expr1_isNull = false;") } @@ -75,7 +75,6 @@ class CodeBlockSuite extends SparkFunSuite { val value1 = JavaCode.variable("expr1", IntegerType) val isNull2 = JavaCode.isNullVariable("expr2_isNull") val value2 = JavaCode.variable("expr2", IntegerType) - val literal = JavaCode.literal("100", IntegerType) val code = code""" @@ -83,7 +82,7 @@ class CodeBlockSuite extends SparkFunSuite { |int $value1 = -1;""".stripMargin + code""" |boolean $isNull2 = true; - |int $value2 = $literal;""".stripMargin + |int $value2 = 100;""".stripMargin val expected = """ @@ -95,16 +94,8 @@ class CodeBlockSuite extends SparkFunSuite { assert(code.toString == expected) val exprValues = code.exprValues - assert(exprValues.size == 5) - assert(exprValues === Set(isNull1, value1, isNull2, value2, literal)) - } - - test("Throws exception when interpolating unexcepted object in code block") { - val obj = Tuple2(1, 1) - val e = intercept[IllegalArgumentException] { - code"$obj" - } - assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}")) + assert(exprValues.size == 4) + assert(exprValues === Set(isNull1, value1, isNull2, value2)) } test("replace expr values in code block") { @@ -119,7 +110,7 @@ class CodeBlockSuite extends SparkFunSuite { | int $exprInFunc = $expr + 1; |}""".stripMargin - val aliasedParam = JavaCode.variable("aliased", expr.javaType) + val aliasedParam = JavaCode.variable("aliased", IntegerType) val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map { case _: SimpleExprValue => aliasedParam case other => other diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 84d0ba7bef64..476c5e573214 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -118,7 +118,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("+", "1", "*", "2", "-", "3", "4") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformDown { - case b: BinaryOperator => actual += b.symbol; b + case b: BinaryOperator => actual += b.symbol.toString; b case l: Literal => actual += l.toString; l } @@ -130,7 +130,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformUp { - case b: BinaryOperator => actual += b.symbol; b + case b: BinaryOperator => actual += b.symbol.toString; b case l: Literal => actual += l.toString; l } @@ -181,7 +181,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression foreachUp { - case b: BinaryOperator => actual += b.symbol; + case b: BinaryOperator => actual += b.symbol.toString; case l: Literal => actual += l.toString; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 48abad907865..49365748eaab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, IntegerType} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -46,18 +47,18 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { */ private def genCodeColumnVector( ctx: CodegenContext, - columnVar: String, - ordinal: String, + columnVar: ExprValue, + ordinal: ExprValue, dataType: DataType, nullable: Boolean): ExprCode = { - val javaType = CodeGenerator.javaType(dataType) + val javaType = inline"${CodeGenerator.javaType(dataType)}" val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { JavaCode.isNullVariable(ctx.freshName("isNull")) } else { FalseLiteral } - val valueVar = ctx.freshName("value") + val valueVar = JavaCode.variable(ctx.freshName("value"), dataType) val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" val code = code"${ctx.registerComment(str)}" + (if (nullable) { code""" @@ -101,7 +102,8 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { Seq.fill(output.indices.size)(classOf[ColumnVector].getName)) val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map { case (columnVectorClz, i) => - val name = ctx.addMutableState(columnVectorClz, s"colInstance$i") + val name = JavaCode.global(ctx.addMutableState(columnVectorClz, s"colInstance$i"), + classOf[ColumnVector]) (name, s"$name = ($columnVectorClz) $batch.column($i);") }.unzip @@ -120,7 +122,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { |}""".stripMargin) ctx.currentVars = null - val rowidx = ctx.freshName("rowIdx") + val rowidx = JavaCode.variable(ctx.freshName("rowIdx"), IntegerType) val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } @@ -157,7 +159,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val numOutputRows = metricTerm(ctx, "numOutputRows") val row = ctx.freshName("row") - ctx.INPUT_ROW = row + ctx.INPUT_ROW = JavaCode.variable(row, classOf[InternalRow]) ctx.currentVars = null // Always provide `outputVars`, so that the framework can help us build unsafe row if the input // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 61c14fee0933..f38319e839c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, JavaCode} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.datasources._ @@ -113,7 +113,7 @@ case class RowDataSourceScanExec( val exprRows = output.zipWithIndex.map{ case (a, i) => BoundReference(i, a.dataType, a.nullable) } - val row = ctx.freshName("row") + val row = JavaCode.variable(ctx.freshName("row"), classOf[InternalRow]) ctx.INPUT_ROW = row ctx.currentVars = null val columnsRowInput = exprRows.map(_.genCode(ctx)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 5b4edf5136e3..b9edda089a31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -151,11 +151,11 @@ case class ExpandExec( // This column is the same across all output rows. Just generate code for it here. BindReferences.bindReference(firstExpr, child.output).genCode(ctx) } else { - val isNull = ctx.freshName("isNull") - val value = ctx.freshName("value") + val isNull = JavaCode.isNullVariable(ctx.freshName("isNull")) + val value = JavaCode.variable(ctx.freshName("value"), firstExpr.dataType) val code = code""" |boolean $isNull = true; - |${CodeGenerator.javaType(firstExpr.dataType)} $value = + |${inline"${CodeGenerator.javaType(firstExpr.dataType)}"} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin ExprCode( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 2549b9e1537a..55ba8ee7731d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types._ @@ -163,10 +164,10 @@ case class GenerateExec( val data = e.genCode(ctx) // Generate looping variables. - val index = ctx.freshName("index") + val index = JavaCode.variable(ctx.freshName("index"), IntegerType) // Add a check if the generate outer flag is true. - val checks = optionalCode(outer, s"($index == -1)") + val checks = optionalCode(outer, code"($index == -1)") // Add position val position = if (e.position) { @@ -185,13 +186,13 @@ case class GenerateExec( val (initMapData, updateRowData, values) = e.collectionType match { case ArrayType(st: StructType, nullable) if e.inline => val row = codeGenAccessor(ctx, data.value, "col", index, st, nullable, checks) - val fieldChecks = checks ++ optionalCode(nullable, row.isNull) + val fieldChecks = checks ++ optionalCode(nullable, code"${row.isNull}") val columns = st.fields.toSeq.zipWithIndex.map { case (f, i) => codeGenAccessor( ctx, row.value, s"st_col${i}", - i.toString, + JavaCode.variable(i.toString, IntegerType), f.dataType, f.nullable, fieldChecks) @@ -203,8 +204,8 @@ case class GenerateExec( case MapType(keyType, valueType, valueContainsNull) => // Materialize the key and the value arrays before we enter the loop. - val keyArray = ctx.freshName("keyArray") - val valueArray = ctx.freshName("valueArray") + val keyArray = JavaCode.variable(ctx.freshName("keyArray"), classOf[ArrayData]) + val valueArray = JavaCode.variable(ctx.freshName("valueArray"), classOf[ArrayData]) val initArrayData = s""" |ArrayData $keyArray = ${data.isNull} ? null : ${data.value}.keyArray(); @@ -251,16 +252,17 @@ case class GenerateExec( val data = e.genCode(ctx) // Generate looping variables. - val iterator = ctx.freshName("iterator") - val hasNext = ctx.freshName("hasNext") - val current = ctx.freshName("row") + val iterator = JavaCode.variable(ctx.freshName("iterator"), classOf[Iterator[InternalRow]]) + val hasNext = JavaCode.variable(ctx.freshName("hasNext"), BooleanType) + val current = JavaCode.variable(ctx.freshName("row"), classOf[InternalRow]) // Add a check if the generate outer flag is true. - val checks = optionalCode(outer, s"!$hasNext") + val checks = optionalCode(outer, code"!$hasNext") val values = e.dataType match { case ArrayType(st: StructType, nullable) => st.fields.toSeq.zipWithIndex.map { case (f, i) => - codeGenAccessor(ctx, current, s"st_col${i}", s"$i", f.dataType, f.nullable, checks) + val index = JavaCode.variable(s"$i", IntegerType) + codeGenAccessor(ctx, current, s"st_col${i}", index, f.dataType, f.nullable, checks) } } @@ -301,21 +303,22 @@ case class GenerateExec( */ private def codeGenAccessor( ctx: CodegenContext, - source: String, + source: ExprValue, name: String, - index: String, + index: ExprValue, dt: DataType, nullable: Boolean, - initialChecks: Seq[String]): ExprCode = { - val value = ctx.freshName(name) - val javaType = CodeGenerator.javaType(dt) + initialChecks: Seq[Block]): ExprCode = { + val value = JavaCode.variable(ctx.freshName(name), dt) + val javaType = inline"${CodeGenerator.javaType(dt)}" val getter = CodeGenerator.getValue(source, dt, index) - val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)") + val checks = initialChecks ++ optionalCode(nullable, code"$source.isNullAt($index)") if (checks.nonEmpty) { - val isNull = ctx.freshName("isNull") + val isNull = JavaCode.isNullVariable(ctx.freshName("isNull")) + val checkBlock = checks.reduceLeft((left, right) => code"$left || $right") val code = code""" - |boolean $isNull = ${checks.mkString(" || ")}; + |boolean $isNull = $checkBlock; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt)) @@ -324,8 +327,8 @@ case class GenerateExec( } } - private def optionalCode(condition: Boolean, code: => String): Seq[String] = { + private def optionalCode(condition: Boolean, code: => Block): Seq[Block] = { if (condition) Seq(code) - else Seq.empty + else Seq(EmptyBlock) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 372dc3db36ce..92feac99470a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -120,7 +120,7 @@ trait CodegenSupport extends SparkPlan { } val evaluateInputs = evaluateVariables(colVars) // generate the code to create a UnsafeRow - ctx.INPUT_ROW = row + ctx.INPUT_ROW = JavaCode.variable(row, classOf[InternalRow]) ctx.currentVars = colVars val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) val code = code""" @@ -149,7 +149,7 @@ trait CodegenSupport extends SparkPlan { } else { assert(row != null, "outputVars and row cannot both be null.") ctx.currentVars = null - ctx.INPUT_ROW = row + ctx.INPUT_ROW = JavaCode.variable(row, classOf[InternalRow]) output.zipWithIndex.map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable).genCode(ctx) } @@ -259,8 +259,8 @@ trait CodegenSupport extends SparkPlan { * Returns source code to evaluate all the variables, and clear the code of them, to prevent * them to be evaluated twice. */ - protected def evaluateVariables(variables: Seq[ExprCode]): String = { - val evaluate = variables.filter(_.code.nonEmpty).map(_.code.toString).mkString("\n") + protected def evaluateVariables(variables: Seq[ExprCode]): Block = { + val evaluate = Blocks(variables.filter(_.code.nonEmpty).map(_.code)) variables.foreach(_.code = EmptyBlock) evaluate } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8c7b2c187ccc..2126a4a4e2ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.vectorized.MutableColumnarRow -import org.apache.spark.sql.types.{DecimalType, StringType, StructType} +import org.apache.spark.sql.types.{BooleanType, DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils @@ -187,8 +187,12 @@ case class HashAggregateExec( val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) bufVars = initExpr.map { e => - val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") + val isNull = JavaCode.global( + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull"), + BooleanType) + val value = JavaCode.global( + ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue"), + e.dataType) // The initial expression should not access any column val ev = e.genCode(ctx) val initVars = code""" @@ -215,13 +219,13 @@ case class HashAggregateExec( val resultVars = resultExpressions.map { e => BindReferences.bindReference(e, aggregateAttributes).genCode(ctx) } - (resultVars, s""" + (resultVars, code""" |$evaluateAggResults |${evaluateVariables(resultVars)} """.stripMargin) } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { // output the aggregate buffer directly - (bufVars, "") + (bufVars, EmptyBlock) } else { // no aggregate function, the result should be literals val resultVars = resultExpressions.map(_.genCode(ctx)) @@ -250,7 +254,7 @@ case class HashAggregateExec( | $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); | | // output the result - | ${genResult.trim} + | ${genResult.code} | | $numOutput.add(1); | ${consume(ctx, resultVars).trim} @@ -451,12 +455,12 @@ case class HashAggregateExec( if (modes.contains(Final) || modes.contains(Complete)) { // generate output using resultExpressions ctx.currentVars = null - ctx.INPUT_ROW = keyTerm + ctx.INPUT_ROW = JavaCode.variable(keyTerm, classOf[InternalRow]) val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => BoundReference(i, e.dataType, e.nullable).genCode(ctx) } val evaluateKeyVars = evaluateVariables(keyVars) - ctx.INPUT_ROW = bufferTerm + ctx.INPUT_ROW = JavaCode.variable(bufferTerm, classOf[InternalRow]) val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) => BoundReference(i, e.dataType, e.nullable).genCode(ctx) } @@ -487,13 +491,13 @@ case class HashAggregateExec( ctx.currentVars = null - ctx.INPUT_ROW = keyTerm + ctx.INPUT_ROW = JavaCode.variable(keyTerm, classOf[InternalRow]) val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => BoundReference(i, e.dataType, e.nullable).genCode(ctx) } val evaluateKeyVars = evaluateVariables(keyVars) - ctx.INPUT_ROW = bufferTerm + ctx.INPUT_ROW = JavaCode.variable(bufferTerm, classOf[InternalRow]) val resultBufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) => BoundReference(i, e.dataType, e.nullable).genCode(ctx) } @@ -511,7 +515,7 @@ case class HashAggregateExec( """ } else { // generate result based on grouping key - ctx.INPUT_ROW = keyTerm + ctx.INPUT_ROW = JavaCode.variable(keyTerm, classOf[InternalRow]) ctx.currentVars = null val eval = resultExpressions.map{ e => BindReferences.bindReference(e, groupingAttributes).genCode(ctx) @@ -679,7 +683,7 @@ case class HashAggregateExec( def outputFromVectorizedMap: String = { val row = ctx.freshName("fastHashMapRow") ctx.currentVars = null - ctx.INPUT_ROW = row + ctx.INPUT_ROW = JavaCode.variable(row, classOf[InternalRow]) val generateKeyRow = GenerateUnsafeProjection.createCode(ctx, groupingKeySchema.toAttributes.zipWithIndex .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) } @@ -830,7 +834,8 @@ case class HashAggregateExec( ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input val updateRowInRegularHashMap: String = { - ctx.INPUT_ROW = unsafeRowBuffer + val unsafeRowBufferVar = JavaCode.variable(unsafeRowBuffer, classOf[InternalRow]) + ctx.INPUT_ROW = unsafeRowBufferVar val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") @@ -839,7 +844,7 @@ case class HashAggregateExec( } val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + CodeGenerator.updateColumn(unsafeRowBufferVar, dt, i, ev, updateExpr(i).nullable) } s""" |// common sub-expressions @@ -853,7 +858,8 @@ case class HashAggregateExec( val updateRowInHashMap: String = { if (isFastHashMapEnabled) { - ctx.INPUT_ROW = fastRowBuffer + val fastRowBufferVar = JavaCode.variable(fastRowBuffer, classOf[InternalRow]) + ctx.INPUT_ROW = fastRowBufferVar val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") @@ -863,7 +869,7 @@ case class HashAggregateExec( val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType CodeGenerator.updateColumn( - fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled) + fastRowBufferVar, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled) } // If fast hash map is on, we first generate code to update row in fast hash map, if the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index e1c85823259b..9657226067ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -47,8 +47,12 @@ abstract class HashMapGenerator( val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) initExpr.map { e => - val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") + val isNull = JavaCode.variable( + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull"), + BooleanType) + val value = JavaCode.variable( + ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue"), + e.dataType) val ev = e.genCode(ctx) val initVars = code""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index d5508275c48c..38f5111eaad2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, JavaCode} import org.apache.spark.sql.types._ /** @@ -114,8 +114,10 @@ class RowBasedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, CodeGenerator.getValue("row", - key.dataType, ordinal.toString()), key.name)})""" + val rowValue = JavaCode.variable("row", classOf[UnsafeRow]) + val keyValue = JavaCode.variable(key.name, key.dataType) + s"""(${ctx.genEqual(key.dataType, CodeGenerator.getValue(rowValue, + key.dataType, ordinal.toString()), keyValue)})""" }.mkString(" && ") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 7b3580cecc60..bc4cd5cbc933 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, JavaCode} import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -51,6 +51,8 @@ class VectorizedHashMapGenerator( extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName, groupingKeySchema, bufferSchema) { + val numRows = JavaCode.variable("numRows", IntegerType) + override protected def initializeAggregateHashMap(): String = { val generatedSchema: String = s"new org.apache.spark.sql.types.StructType()" + @@ -87,7 +89,7 @@ class VectorizedHashMapGenerator( | private double loadFactor = 0.5; | private int numBuckets = (int) (capacity / loadFactor); | private int maxSteps = 2; - | private int numRows = 0; + | private int $numRows = 0; | private org.apache.spark.sql.types.StructType schema = $generatedSchema | private org.apache.spark.sql.types.StructType aggregateBufferSchema = | $generatedAggBufferSchema @@ -127,9 +129,11 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - val value = CodeGenerator.getValueFromVector(s"vectors[$ordinal]", key.dataType, - "buckets[idx]") - s"(${ctx.genEqual(key.dataType, value, key.name)})" + val vector = JavaCode.variable(s"vectors[$ordinal]", classOf[OnHeapColumnVector]) + val bucket = JavaCode.variable("buckets[idx]", IntegerType) + val value = CodeGenerator.getValueFromVector(vector, key.dataType, bucket) + val keyValue = JavaCode.variable(key.name, key.dataType) + s"(${ctx.genEqual(key.dataType, value, keyValue)})" }.mkString(" && ") } @@ -183,14 +187,18 @@ class VectorizedHashMapGenerator( def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - CodeGenerator.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) + val vector = JavaCode.variable(s"vectors[$ordinal]", classOf[OnHeapColumnVector]) + val keyValue = JavaCode.variable(key.name, key.dataType) + CodeGenerator.setValue(vector, numRows, key.dataType, keyValue).code } } def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = { bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - CodeGenerator.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", - key.dataType, buffVars(ordinal), nullable = true) + val vector = JavaCode.variable(s"vectors[${groupingKeys.length + ordinal}]", + classOf[OnHeapColumnVector]) + CodeGenerator.updateColumn(vector, numRows, key.dataType, buffVars(ordinal), + nullable = true).code } } @@ -202,7 +210,7 @@ class VectorizedHashMapGenerator( | while (step < maxSteps) { | // Return bucket index if it's either an empty slot or already contains the key | if (buckets[idx] == -1) { - | if (numRows < capacity) { + | if ($numRows < capacity) { | | // Initialize aggregate keys | ${genCodeToSetKeys(groupingKeys).mkString("\n")} @@ -212,7 +220,7 @@ class VectorizedHashMapGenerator( | // Initialize aggregate values | ${genCodeToSetAggBuffers(bufferValues).mkString("\n")} | - | buckets[idx] = numRows++; + | buckets[idx] = $numRows++; | aggBufferRow.rowId = buckets[idx]; | return aggBufferRow; | } else { @@ -235,7 +243,7 @@ class VectorizedHashMapGenerator( protected def generateRowIterator(): String = { s""" |public java.util.Iterator<${classOf[InternalRow].getName}> rowIterator() { - | batch.setNumRows(numRows); + | batch.setNumRows($numRows); | return batch.rowIterator(); |} """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 0da0e8610c39..127347c7a1a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -172,18 +172,18 @@ case class BroadcastHashJoinExec( /** * Generates the code for variable of build side. */ - private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = { + private def genBuildSideVars(ctx: CodegenContext, matched: ExprValue): Seq[ExprCode] = { ctx.currentVars = null - ctx.INPUT_ROW = matched + ctx.INPUT_ROW = JavaCode.variable(matched, classOf[InternalRow]) buildPlan.output.zipWithIndex.map { case (a, i) => val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx) if (joinType.isInstanceOf[InnerLike]) { ev } else { // the variables are needed even there is no matched rows - val isNull = ctx.freshName("isNull") - val value = ctx.freshName("value") - val javaType = CodeGenerator.javaType(a.dataType) + val isNull = JavaCode.isNullVariable(ctx.freshName("isNull")) + val value = JavaCode.variable(ctx.freshName("value"), a.dataType) + val javaType = inline"${CodeGenerator.javaType(a.dataType)}" val code = code""" |boolean $isNull = true; |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; @@ -205,7 +205,7 @@ case class BroadcastHashJoinExec( private def getJoinCondition( ctx: CodegenContext, input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = { - val matched = ctx.freshName("matched") + val matched = JavaCode.variable(ctx.freshName("matched"), classOf[UnsafeRow]) val buildVars = genBuildSideVars(ctx, matched) val checkCondition = if (condition.isDefined) { val expr = condition.get @@ -281,7 +281,7 @@ case class BroadcastHashJoinExec( private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) - val matched = ctx.freshName("matched") + val matched = JavaCode.variable(ctx.freshName("matched"), classOf[UnsafeRow]) val buildVars = genBuildSideVars(ctx, matched) val numOutput = metricTerm(ctx, "numOutputRows") @@ -469,7 +469,7 @@ case class BroadcastHashJoinExec( val numOutput = metricTerm(ctx, "numOutputRows") val existsVar = ctx.freshName("exists") - val matched = ctx.freshName("matched") + val matched = JavaCode.variable(ctx.freshName("matched"), classOf[UnsafeRow]) val buildVars = genBuildSideVars(ctx, matched) val checkCondition = if (condition.isDefined) { val expr = condition.get diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index f4b9d132122e..ca5c20df7f65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.types.BooleanType import org.apache.spark.util.collection.BitSet /** @@ -391,7 +392,7 @@ case class SortMergeJoinExec( row: String, keys: Seq[Expression], input: Seq[Attribute]): Seq[ExprCode] = { - ctx.INPUT_ROW = row + ctx.INPUT_ROW = JavaCode.variable(row, classOf[InternalRow]) ctx.currentVars = null keys.map(BindReferences.bindReference(_, input).genCode(ctx)) } @@ -420,11 +421,15 @@ case class SortMergeJoinExec( * Generate a function to scan both left and right to find a match, returns the term for * matched one row from left side and buffered rows from right side. */ - private def genScanner(ctx: CodegenContext): (String, String) = { + private def genScanner(ctx: CodegenContext): (ExprValue, ExprValue) = { // Create class member for next row from both sides. // Inline mutable state since not many join operations in a task - val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) - val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) + val leftRow = JavaCode.global( + ctx.addMutableState("InternalRow", "leftRow", forceInline = true), + classOf[InternalRow]) + val rightRow = JavaCode.global( + ctx.addMutableState("InternalRow", "rightRow", forceInline = true), + classOf[InternalRow]) // Create variables for join keys from both sides. val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) @@ -441,8 +446,9 @@ case class SortMergeJoinExec( val inMemoryThreshold = getInMemoryThreshold // Inline mutable state since not many join operations in a task - val matches = ctx.addMutableState(clsName, "matches", - v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) + val matches = JavaCode.global(ctx.addMutableState(clsName, "matches", + v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true), + classOf[ExternalAppendOnlyUnsafeRowArray]) // Copy the left keys as class members so they could be used in next function call. val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -512,22 +518,24 @@ case class SortMergeJoinExec( * the variables should be declared separately from accessing the columns, we can't use the * codegen of BoundReference here. */ - private def createLeftVars(ctx: CodegenContext, leftRow: String): (Seq[ExprCode], Seq[String]) = { - ctx.INPUT_ROW = leftRow + private def createLeftVars( + ctx: CodegenContext, + leftRow: ExprValue): (Seq[ExprCode], Seq[Block]) = { + ctx.INPUT_ROW = JavaCode.variable(leftRow, classOf[InternalRow]) left.output.zipWithIndex.map { case (a, i) => - val value = ctx.freshName("value") + val value = JavaCode.variable(ctx.freshName("value"), a.dataType) val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString) - val javaType = CodeGenerator.javaType(a.dataType) + val javaType = inline"${CodeGenerator.javaType(a.dataType)}" val defaultValue = CodeGenerator.defaultValue(a.dataType) if (a.nullable) { - val isNull = ctx.freshName("isNull") + val isNull = JavaCode.isNullVariable(ctx.freshName("isNull")) val code = code""" |$isNull = $leftRow.isNullAt($i); |$value = $isNull ? $defaultValue : ($valueCode); """.stripMargin val leftVarsDecl = - s""" + code""" |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin @@ -535,7 +543,7 @@ case class SortMergeJoinExec( leftVarsDecl) } else { val code = code"$value = $valueCode;" - val leftVarsDecl = s"""$javaType $value = $defaultValue;""" + val leftVarsDecl = code"""$javaType $value = $defaultValue;""" (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) } }.unzip @@ -546,7 +554,7 @@ case class SortMergeJoinExec( * part are accessed inside the loop. */ private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = { - ctx.INPUT_ROW = rightRow + ctx.INPUT_ROW = JavaCode.variable(rightRow, classOf[InternalRow]) right.output.zipWithIndex.map { case (a, i) => BoundReference(i, a.dataType, a.nullable).genCode(ctx) } @@ -561,7 +569,7 @@ case class SortMergeJoinExec( */ private def splitVarsByCondition( attributes: Seq[Attribute], - variables: Seq[ExprCode]): (String, String) = { + variables: Seq[ExprCode]): (Block, Block) = { if (condition.isDefined) { val condRefs = condition.get.references val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) => @@ -571,7 +579,7 @@ case class SortMergeJoinExec( val afterCond = evaluateVariables(notUsed.map(_._2)) (beforeCond, afterCond) } else { - (evaluateVariables(variables), "") + (evaluateVariables(variables), EmptyBlock) } } @@ -595,19 +603,19 @@ case class SortMergeJoinExec( val numOutput = metricTerm(ctx, "numOutputRows") val (beforeLoop, condCheck) = if (condition.isDefined) { // Split the code of creating variables based on whether it's used by condition or not. - val loaded = ctx.freshName("loaded") + val loaded = JavaCode.variable(ctx.freshName("loaded"), BooleanType) val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) // Generate code for condition ctx.currentVars = leftVars ++ rightVars val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) // evaluate the columns those used by condition before loop - val before = s""" + val before = code""" |boolean $loaded = false; |$leftBefore """.stripMargin - val checking = s""" + val checking = code""" |$rightBefore |${cond.code} |if (${cond.isNull} || !${cond.value}) continue; @@ -619,17 +627,17 @@ case class SortMergeJoinExec( """.stripMargin (before, checking) } else { - (evaluateVariables(leftVars), "") + (evaluateVariables(leftVars), EmptyBlock) } s""" |while (findNextInnerJoinRows($leftInput, $rightInput)) { | ${leftVarDecl.mkString("\n")} - | ${beforeLoop.trim} + | ${beforeLoop} | scala.collection.Iterator $iterator = $matches.generateIterator(); | while ($iterator.hasNext()) { | InternalRow $rightRow = (InternalRow) $iterator.next(); - | ${condCheck.trim} + | ${condCheck} | $numOutput.add(1); | ${consume(ctx, leftVars ++ rightVars)} | } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 8280a3ce3984..26caf4735da2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -315,7 +315,7 @@ case class EmptyGenerator() extends Generator { override def elementSchema: StructType = new StructType().add("id", IntegerType) override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val iteratorClass = classOf[Iterator[_]].getName + val iteratorClass = inline"${classOf[Iterator[_]].getName}" ev.copy(code = code"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") } From 2a85388276d7227d04c24110756236bc9da063d4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 9 Jun 2018 11:32:00 +0000 Subject: [PATCH 2/2] Second pass. --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../expressions/codegen/CodeGenerator.scala | 2 +- .../codegen/GenerateSafeProjection.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 2 +- .../expressions/collectionOperations.scala | 31 ++-- .../expressions/maskExpressions.scala | 154 +++++++++--------- .../expressions/stringExpressions.scala | 30 ++-- 7 files changed, 120 insertions(+), 107 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 13eb5982fe16..bd90ed928c61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -54,9 +54,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) } else { assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.") val javaType = inline"${CodeGenerator.javaType(dataType)}" - val value = JavaCode.expression( - CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString), - dataType) + val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = code""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e77c38cb4be8..58737be7e7f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1376,7 +1376,7 @@ object CodeGenerator extends Logging { val msg = s"failed to compile: $e" logError(msg, e) val maxLines = SQLConf.get.loggingMaxLinesForCodegen - println(s"\n${CodeFormatter.format(code, maxLines)}") + logInfo(s"\n${CodeFormatter.format(code, maxLines)}") throw new CompileException(msg, e.getLocation) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 7cf8618fbb43..08861d4e66af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -57,7 +57,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => val converter = convertToSafe( ctx, - JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt), + CodeGenerator.getValue(tmpInput, dt, i.toString), dt) code""" if (!$tmpInput.isNullAt($i)) { @@ -96,7 +96,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val elementConverter = convertToSafe( ctx, - JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType), + CodeGenerator.getValue(tmpInput, elementType, index), elementType) val code = code""" final ArrayData $tmpInput = $input; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 81fd8a27a454..af2d2f05fc0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -57,7 +57,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => ExprCode( JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"), - JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) + CodeGenerator.getValue(tmpInput, dt, i.toString)) } val rowWriterClass = classOf[UnsafeRowWriter].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d6cdb4597621..63fde0c20202 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2159,12 +2159,12 @@ case class ArrayRemove(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (arr, value) => { - val numsToRemove = ctx.freshName("numsToRemove") - val newArraySize = ctx.freshName("newArraySize") - val i = ctx.freshName("i") + val numsToRemove = JavaCode.variable(ctx.freshName("numsToRemove"), IntegerType) + val newArraySize = JavaCode.variable(ctx.freshName("newArraySize"), IntegerType) + val i = JavaCode.variable(ctx.freshName("i"), IntegerType) val getValue = CodeGenerator.getValue(arr, elementType, i) val isEqual = ctx.genEqual(elementType, value, getValue) - s""" + code""" |int $numsToRemove = 0; |for (int $i = 0; $i < $arr.numElements(); $i ++) { | if (!$arr.isNullAt($i) && $isEqual) { @@ -2180,17 +2180,17 @@ case class ArrayRemove(left: Expression, right: Expression) def genCodeForResult( ctx: CodegenContext, ev: ExprCode, - inputArray: String, - value: String, - newArraySize: String): String = { - val values = ctx.freshName("values") - val i = ctx.freshName("i") - val pos = ctx.freshName("pos") + inputArray: ExprValue, + value: ExprValue, + newArraySize: ExprValue): Block = { + val values = JavaCode.variable(ctx.freshName("values"), classOf[Array[Object]]) + val i = JavaCode.variable(ctx.freshName("i"), IntegerType) + val pos = JavaCode.variable(ctx.freshName("pos"), IntegerType) val getValue = CodeGenerator.getValue(inputArray, elementType, i) val isEqual = ctx.genEqual(elementType, value, getValue) if (!CodeGenerator.isPrimitiveType(elementType)) { - val arrayClass = classOf[GenericArrayData].getName - s""" + val arrayClass = inline"${classOf[GenericArrayData].getName}" + code""" |int $pos = 0; |Object[] $values = new Object[$newArraySize]; |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { @@ -2208,9 +2208,10 @@ case class ArrayRemove(left: Expression, right: Expression) |${ev.value} = new $arrayClass($values); """.stripMargin } else { - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - s""" - |${ctx.createUnsafeArray(values, newArraySize, elementType, s" $prettyName failed.")} + val primitiveValueTypeName = inline"${CodeGenerator.primitiveTypeName(elementType)}" + val errorMsg = new LiteralValue(s" $prettyName failed.", classOf[String]) + code""" + |${ctx.createUnsafeArray(values, newArraySize, elementType, errorMsg)} |int $pos = 0; |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { | if ($inputArray.isNullAt($i)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala index 276a57266a6e..572b32502cf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -22,7 +22,8 @@ import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.MaskExpressionsUtils._ import org.apache.spark.sql.catalyst.expressions.MaskLike._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{Block, CodegenContext, CodeGenerator, ExprCode, ExprValue, JavaCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -36,23 +37,25 @@ trait MaskLike { protected lazy val lowerReplacement: Int = getReplacementChar(lower, defaultMaskedLowercase) protected lazy val digitReplacement: Int = getReplacementChar(digit, defaultMaskedDigit) - protected val maskUtilsClassName: String = classOf[MaskExpressionsUtils].getName + protected val maskUtilsClassName: Block = inline"${classOf[MaskExpressionsUtils].getName}" - def inputStringLengthCode(inputString: String, length: String): String = { - s"${CodeGenerator.JAVA_INT} $length = $inputString.codePointCount(0, $inputString.length());" + def inputStringLengthCode(inputString: ExprValue, length: ExprValue): Block = { + val intType = inline"${CodeGenerator.JAVA_INT}" + code"$intType $length = $inputString.codePointCount(0, $inputString.length());" } def appendMaskedToStringBuilderCode( ctx: CodegenContext, - sb: String, - inputString: String, - offset: String, - numChars: String): String = { - val i = ctx.freshName("i") - val codePoint = ctx.freshName("codePoint") - s""" - |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { - | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); + sb: ExprValue, + inputString: ExprValue, + offset: ExprValue, + numChars: JavaCode): Block = { + val i = JavaCode.variable(ctx.freshName("i"), IntegerType) + val codePoint = JavaCode.variable(ctx.freshName("codePoint"), IntegerType) + val intType = inline"${CodeGenerator.JAVA_INT}" + code""" + |for ($intType $i = 0; $i < $numChars; $i++) { + | $intType $codePoint = $inputString.codePointAt($offset); | $sb.appendCodePoint($maskUtilsClassName.transformChar($codePoint, | $upperReplacement, $lowerReplacement, | $digitReplacement, $defaultMaskedOther)); @@ -63,15 +66,16 @@ trait MaskLike { def appendUnchangedToStringBuilderCode( ctx: CodegenContext, - sb: String, - inputString: String, - offset: String, - numChars: String): String = { - val i = ctx.freshName("i") - val codePoint = ctx.freshName("codePoint") - s""" - |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { - | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); + sb: ExprValue, + inputString: ExprValue, + offset: ExprValue, + numChars: JavaCode): Block = { + val i = JavaCode.variable(ctx.freshName("i"), IntegerType) + val codePoint = JavaCode.variable(ctx.freshName("codePoint"), IntegerType) + val intType = inline"${CodeGenerator.JAVA_INT}" + code""" + |for ($intType $i = 0; $i < $numChars; $i++) { + | $intType $codePoint = $inputString.codePointAt($offset); | $sb.appendCodePoint($codePoint); | $offset += Character.charCount($codePoint); |} @@ -179,16 +183,17 @@ case class Mask(child: Expression, upper: String, lower: String, digit: String) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - s""" + nullSafeCodeGen(ctx, ev, (input: ExprValue) => { + val sb = JavaCode.variable(ctx.freshName("sb"), classOf[StringBuilder]) + val length = JavaCode.variable(ctx.freshName("length"), IntegerType) + val offset = JavaCode.variable(ctx.freshName("offset"), IntegerType) + val inputString = JavaCode.variable(ctx.freshName("inputString"), classOf[String]) + val intType = inline"${CodeGenerator.JAVA_INT}" + code""" |String $inputString = $input.toString(); |${inputStringLengthCode(inputString, length)} |StringBuilder $sb = new StringBuilder($length); - |${CodeGenerator.JAVA_INT} $offset = 0; + |$intType $offset = 0; |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, length)} |${ev.value} = UTF8String.fromString($sb.toString()); """.stripMargin @@ -256,21 +261,22 @@ case class MaskFirstN( } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val endOfMask = ctx.freshName("endOfMask") - s""" + nullSafeCodeGen(ctx, ev, (input: ExprValue) => { + val sb = JavaCode.variable(ctx.freshName("sb"), classOf[StringBuilder]) + val length = JavaCode.variable(ctx.freshName("length"), IntegerType) + val offset = JavaCode.variable(ctx.freshName("offset"), IntegerType) + val inputString = JavaCode.variable(ctx.freshName("inputString"), classOf[String]) + val endOfMask = JavaCode.variable(ctx.freshName("endOfMask"), IntegerType) + val intType = inline"${CodeGenerator.JAVA_INT}" + code""" |String $inputString = $input.toString(); |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $length ? $length : $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; + |$intType $endOfMask = $charCount > $length ? $length : $charCount; + |$intType $offset = 0; |StringBuilder $sb = new StringBuilder($length); |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} |${appendUnchangedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $endOfMask")} + ctx, sb, inputString, offset, code"$length - $endOfMask")} |${ev.value} = UTF8String.fromString($sb.toString()); |""".stripMargin }) @@ -339,22 +345,22 @@ case class MaskLastN( } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val startOfMask = ctx.freshName("startOfMask") - s""" + nullSafeCodeGen(ctx, ev, (input: ExprValue) => { + val sb = JavaCode.variable(ctx.freshName("sb"), classOf[StringBuilder]) + val length = JavaCode.variable(ctx.freshName("length"), IntegerType) + val offset = JavaCode.variable(ctx.freshName("offset"), IntegerType) + val inputString = JavaCode.variable(ctx.freshName("inputString"), classOf[String]) + val startOfMask = JavaCode.variable(ctx.freshName("startOfMask"), IntegerType) + val intType = inline"${CodeGenerator.JAVA_INT}" + code""" |String $inputString = $input.toString(); |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $length ? - | 0 : $length - $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; + |$intType $startOfMask = $charCount >= $length ? 0 : $length - $charCount; + |$intType $offset = 0; |StringBuilder $sb = new StringBuilder($length); |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} |${appendMaskedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $startOfMask")} + ctx, sb, inputString, offset, code"$length - $startOfMask")} |${ev.value} = UTF8String.fromString($sb.toString()); |""".stripMargin }) @@ -423,21 +429,22 @@ case class MaskShowFirstN( } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val startOfMask = ctx.freshName("startOfMask") - s""" + nullSafeCodeGen(ctx, ev, (input: ExprValue) => { + val sb = JavaCode.variable(ctx.freshName("sb"), classOf[StringBuilder]) + val length = JavaCode.variable(ctx.freshName("length"), IntegerType) + val offset = JavaCode.variable(ctx.freshName("offset"), IntegerType) + val inputString = JavaCode.variable(ctx.freshName("inputString"), classOf[String]) + val startOfMask = JavaCode.variable(ctx.freshName("startOfMask"), IntegerType) + val intType = inline"${CodeGenerator.JAVA_INT}" + code""" |String $inputString = $input.toString(); |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $length ? $length : $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; + |$intType $startOfMask = $charCount > $length ? $length : $charCount; + |$intType $offset = 0; |StringBuilder $sb = new StringBuilder($length); |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} |${appendMaskedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $startOfMask")} + ctx, sb, inputString, offset, code"$length - $startOfMask")} |${ev.value} = UTF8String.fromString($sb.toString()); |""".stripMargin }) @@ -506,21 +513,22 @@ case class MaskShowLastN( } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val endOfMask = ctx.freshName("endOfMask") - s""" + nullSafeCodeGen(ctx, ev, (input: ExprValue) => { + val sb = JavaCode.variable(ctx.freshName("sb"), classOf[StringBuilder]) + val length = JavaCode.variable(ctx.freshName("length"), IntegerType) + val offset = JavaCode.variable(ctx.freshName("offset"), IntegerType) + val inputString = JavaCode.variable(ctx.freshName("inputString"), classOf[String]) + val endOfMask = JavaCode.variable(ctx.freshName("endOfMask"), IntegerType) + val intType = inline"${CodeGenerator.JAVA_INT}" + code""" |String $inputString = $input.toString(); |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $length ? 0 : $length - $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; + |$intType $endOfMask = $charCount >= $length ? 0 : $length - $charCount; + |$intType $offset = 0; |StringBuilder $sb = new StringBuilder($length); |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} |${appendUnchangedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $endOfMask")} + ctx, sb, inputString, offset, code"$length - $endOfMask")} |${ev.value} = UTF8String.fromString($sb.toString()); |""".stripMargin }) @@ -553,9 +561,9 @@ case class MaskHash(child: Expression) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val digestUtilsClass = classOf[DigestUtils].getName.stripSuffix("$") - s""" + nullSafeCodeGen(ctx, ev, (input: ExprValue) => { + val digestUtilsClass = inline"${classOf[DigestUtils].getName.stripSuffix("$")}" + code""" |${ev.value} = UTF8String.fromString($digestUtilsClass.md5Hex($input.toString())); |""".stripMargin }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 19ebd895fc85..935262437476 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -2045,20 +2045,24 @@ case class FormatNumber(x: Expression, d: Expression) // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') // as a decimal separator. val usLocale = "US" - val numberFormat = ctx.addMutableState(df, "numberFormat", - v => s"""$v = new $df("", new $dfs($l.$usLocale));""") + val numberFormat = JavaCode.global(ctx.addMutableState(df, "numberFormat", + v => s"""$v = new $df("", new $dfs($l.$usLocale));"""), + classOf[DecimalFormat]) right.dataType match { case IntegerType => - val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") - val i = ctx.freshName("i") - val lastDValue = - ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") - s""" + val pattern = JavaCode.global( + ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();"), + classOf[StringBuffer]) + val i = JavaCode.variable(ctx.freshName("i"), IntegerType) + val lastDValue = JavaCode.global( + ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;"), + IntegerType) + code""" if ($d >= 0) { $pattern.delete(0, $pattern.length()); if ($d != $lastDValue) { - $pattern.append("$defaultFormat"); + $pattern.append("${inline"$defaultFormat"}"); if ($d > 0) { $pattern.append("."); @@ -2076,14 +2080,16 @@ case class FormatNumber(x: Expression, d: Expression) } """ case StringType => - val lastDValue = ctx.addMutableState("String", "lastDValue", v => s"""$v = null;""") - val dValue = ctx.freshName("dValue") - s""" + val lastDValue = JavaCode.global( + ctx.addMutableState("String", "lastDValue", v => s"""$v = null;"""), + classOf[String]) + val dValue = JavaCode.variable(ctx.freshName("dValue"), classOf[String]) + code""" String $dValue = $d.toString(); if (!$dValue.equals($lastDValue)) { $lastDValue = $dValue; if ($dValue.isEmpty()) { - $numberFormat.applyLocalizedPattern("$defaultFormat"); + $numberFormat.applyLocalizedPattern("${inline"$defaultFormat"}"); } else { $numberFormat.applyLocalizedPattern($dValue); }