From 89d025225b557689389d16c207be8a25f5e82fa5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 12 Jun 2018 08:40:20 +0000 Subject: [PATCH 1/9] Convert strings in codegen to blocks. --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../spark/sql/catalyst/expressions/Cast.scala | 370 +++++++++--------- .../expressions/codegen/CodeGenerator.scala | 16 + .../expressions/codegen/javaCode.scala | 40 ++ 4 files changed, 251 insertions(+), 179 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index df3ab05e02c76..77582e10f9ff2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ @@ -53,7 +53,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) ev.copy(code = oev.code) } else { assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.") - val javaType = CodeGenerator.javaType(dataType) + val javaType = JavaCode.javaType(dataType) val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 699ea53b5df0f..628bf5972a49a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -625,25 +625,23 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) - ev.copy(code = + ev.copy(code = eval.code + code""" - ${eval.code} - // This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull} ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)} """) } // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` // in parameter list, because the returned code will be put in null safe evaluation region. - private[this] type CastFunction = (String, String, String) => String + private[this] type CastFunction = (ExprValue, ExprValue, ExprValue) => Block private[this] def nullSafeCastFunction( from: DataType, to: DataType, ctx: CodegenContext): CastFunction = to match { - case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;" - case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;" + case _ if from == NullType => (c, evPrim, evNull) => code"$evNull = true;" + case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;" case StringType => castToStringCode(from, ctx) case BinaryType => castToBinaryCode(from) case DateType => castToDateCode(from, ctx) @@ -664,18 +662,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) case udt: UserDefinedType[_] if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => - (c, evPrim, evNull) => s"$evPrim = $c;" + (c, evPrim, evNull) => code"$evPrim = $c;" case _: UserDefinedType[_] => throw new SparkException(s"Cannot cast $from to $to.") } // Since we need to cast input expressions recursively inside ComplexTypes, such as Map's // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. - private[this] def castCode(ctx: CodegenContext, input: String, inputIsNull: String, - result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = { - s""" + private[this] def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: ExprValue, + result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: CastFunction): Block = { + val javaType = JavaCode.javaType(resultType) + code""" boolean $resultIsNull = $inputIsNull; - ${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)}; + $javaType $result = ${CodeGenerator.defaultValue(resultType)}; if (!$inputIsNull) { ${cast(input, result, resultIsNull)} } @@ -684,22 +683,24 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeArrayToStringBuilder( et: DataType, - array: String, - buffer: String, - ctx: CodegenContext): String = { + array: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { val elementToStringCode = castToStringCode(et, ctx) val funcName = ctx.freshName("elementToString") - val elementToStringFunc = ctx.addNewFunction(funcName, + val element = JavaCode.variable("element", et) + val elementStr = JavaCode.variable("elementStr", StringType) + val elementToStringFunc = inline"${ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${CodeGenerator.javaType(et)} element) { - | UTF8String elementStr = null; - | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} + |private UTF8String $funcName(${CodeGenerator.javaType(et)} $element) { + | UTF8String $elementStr = null; + | ${elementToStringCode(element, elementStr, null /* resultIsNull won't be used */)} | return elementStr; |} - """.stripMargin) + """.stripMargin)}" - val loopIndex = ctx.freshName("loopIndex") - s""" + val loopIndex = ctx.freshName("loopIndex", IntegerType) + code""" |$buffer.append("["); |if ($array.numElements() > 0) { | if (!$array.isNullAt(0)) { @@ -720,31 +721,36 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeMapToStringBuilder( kt: DataType, vt: DataType, - map: String, - buffer: String, - ctx: CodegenContext): String = { + map: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { def dataToStringFunc(func: String, dataType: DataType) = { val funcName = ctx.freshName(func) val dataToStringCode = castToStringCode(dataType, ctx) + val data = JavaCode.variable("data", dataType) + val dataStr = JavaCode.variable("dataStr", StringType) ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) { - | UTF8String dataStr = null; - | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} + |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} $data) { + | UTF8String $dataStr = null; + | ${dataToStringCode(data, dataStr, null /* resultIsNull won't be used */)} | return dataStr; |} """.stripMargin) } - val keyToStringFunc = dataToStringFunc("keyToString", kt) - val valueToStringFunc = dataToStringFunc("valueToString", vt) - val loopIndex = ctx.freshName("loopIndex") - val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0") - val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0") - val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex) - val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex) - s""" + val keyToStringFunc = inline"${dataToStringFunc("keyToString", kt)}" + val valueToStringFunc = inline"${dataToStringFunc("valueToString", vt)}" + val loopIndex = ctx.freshName("loopIndex", IntegerType) + val mapKeyArray = JavaCode.expression(s"$map.keyArray()", classOf[ArrayData]) + val mapValueArray = JavaCode.expression(s"$map.valueArray()", classOf[ArrayData]) + val getMapFirstKey = CodeGenerator.getValue(mapKeyArray, kt, JavaCode.literal("0", IntegerType)) + val getMapFirstValue = CodeGenerator.getValue(mapValueArray, vt, + JavaCode.literal("0", IntegerType)) + val getMapKeyArray = CodeGenerator.getValue(mapKeyArray, kt, loopIndex) + val getMapValueArray = CodeGenerator.getValue(mapValueArray, vt, loopIndex) + code""" |$buffer.append("["); |if ($map.numElements() > 0) { | $buffer.append($keyToStringFunc($getMapFirstKey)); @@ -769,20 +775,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeStructToStringBuilder( st: Seq[DataType], - row: String, - buffer: String, - ctx: CodegenContext): String = { + row: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { val structToStringCode = st.zipWithIndex.map { case (ft, i) => val fieldToStringCode = castToStringCode(ft, ctx) - val field = ctx.freshName("field") - val fieldStr = ctx.freshName("fieldStr") - s""" - |${if (i != 0) s"""$buffer.append(",");""" else ""} + val field = ctx.freshName("field", ft) + val fieldStr = ctx.freshName("fieldStr", StringType) + val javaType = JavaCode.javaType(ft) + code""" + |${if (i != 0) code"""$buffer.append(",");""" else EmptyBlock} |if (!$row.isNullAt($i)) { - | ${if (i != 0) s"""$buffer.append(" ");""" else ""} + | ${if (i != 0) code"""$buffer.append(" ");""" else EmptyBlock} | | // Append $i field into the string buffer - | ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")}; + | $javaType $field = ${CodeGenerator.getValue(row, ft, s"$i")}; | UTF8String $fieldStr = null; | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} | $buffer.append($fieldStr); @@ -791,11 +798,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } val writeStructCode = ctx.splitExpressions( - expressions = structToStringCode, + expressions = structToStringCode.map(_.code), funcName = "fieldToString", - arguments = ("InternalRow", row) :: (classOf[UTF8StringBuilder].getName, buffer) :: Nil) + arguments = ("InternalRow", row.code) :: + (classOf[UTF8StringBuilder].getName, buffer.code) :: Nil) - s""" + code""" |$buffer.append("["); |$writeStructCode |$buffer.append("]"); @@ -805,20 +813,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => - (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);" + (c, evPrim, evNull) => code"$evPrim = UTF8String.fromBytes($c);" case DateType => - (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));""" case TimestampType => - val tz = ctx.addReferenceObj("timeZone", timeZone) - (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) + (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" case ArrayType(et, _) => (c, evPrim, evNull) => { - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val buffer = ctx.freshName("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.className(classOf[UTF8StringBuilder]) val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) - s""" + code""" |$bufferClass $buffer = new $bufferClass(); |$writeArrayElemCode; |$evPrim = $buffer.build(); @@ -826,10 +834,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case MapType(kt, vt, _) => (c, evPrim, evNull) => { - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val buffer = ctx.freshName("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.className(classOf[UTF8StringBuilder]) val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) - s""" + code""" |$bufferClass $buffer = new $bufferClass(); |$writeMapElemCode; |$evPrim = $buffer.build(); @@ -837,11 +845,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case StructType(fields) => (c, evPrim, evNull) => { - val row = ctx.freshName("row") - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val row = ctx.freshName("row", classOf[InternalRow]) + val buffer = ctx.freshName("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.className(classOf[UTF8StringBuilder]) val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) - s""" + code""" |InternalRow $row = $c; |$bufferClass $buffer = new $bufferClass(); |$writeStructCode @@ -850,26 +858,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx) case udt: UserDefinedType[_] => - val udtRef = ctx.addReferenceObj("udt", udt) + val udtRef = JavaCode.global(ctx.addReferenceObj("udt", udt), udt.sqlType) (c, evPrim, evNull) => { - s"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" + code"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" } case _ => - (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" + (c, evPrim, evNull) => code"$evPrim = UTF8String.fromString(String.valueOf($c));" } } private[this] def castToBinaryCode(from: DataType): CastFunction = from match { case StringType => - (c, evPrim, evNull) => s"$evPrim = $c.getBytes();" + (c, evPrim, evNull) => code"$evPrim = $c.getBytes();" } private[this] def castToDateCode( from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val intOpt = ctx.freshName("intOpt") - (c, evPrim, evNull) => s""" + val intOpt = ctx.freshName("intOpt", classOf[Option[Integer]]) + (c, evPrim, evNull) => code""" scala.Option $intOpt = org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c); if ($intOpt.isDefined()) { @@ -879,16 +887,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case TimestampType => - val tz = ctx.addReferenceObj("timeZone", timeZone) + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) (c, evPrim, evNull) => - s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);" + code"""$evPrim = + org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);""" case _ => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" } - private[this] def changePrecision(d: String, decimalType: DecimalType, - evPrim: String, evNull: String): String = - s""" + private[this] def changePrecision(d: ExprValue, decimalType: DecimalType, + evPrim: ExprValue, evNull: ExprValue): Block = + code""" if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { $evPrim = $d; } else { @@ -900,11 +909,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, target: DecimalType, ctx: CodegenContext): CastFunction = { - val tmp = ctx.freshName("tmpDecimal") + val tmp = ctx.freshName("tmpDecimal", classOf[Decimal]) from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); ${changePrecision(tmp, target, evPrim, evNull)} @@ -914,37 +923,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ case BooleanType => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = $c ? Decimal.apply(1) : Decimal.apply(0); ${changePrecision(tmp, target, evPrim, evNull)} """ case DateType => // date can't cast to decimal in Hive - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => // Note that we lose precision here. (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = Decimal.apply( scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); ${changePrecision(tmp, target, evPrim, evNull)} """ case DecimalType() => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = $c.clone(); ${changePrecision(tmp, target, evPrim, evNull)} """ case x: IntegralType => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = Decimal.apply((long) $c); ${changePrecision(tmp, target, evPrim, evNull)} """ case x: FractionalType => // All other numeric types can be represented precisely as Doubles (c, evPrim, evNull) => - s""" + code""" try { Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); ${changePrecision(tmp, target, evPrim, evNull)} @@ -959,10 +968,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val tz = ctx.addReferenceObj("timeZone", timeZone) - val longOpt = ctx.freshName("longOpt") + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) + val longOpt = ctx.freshName("longOpt", classOf[Option[Long]]) (c, evPrim, evNull) => - s""" + code""" scala.Option $longOpt = org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c, $tz); if ($longOpt.isDefined()) { @@ -972,18 +981,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1L : 0L;" case _: IntegralType => - (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${longToTimeStampCode(c)};" case DateType => - val tz = ctx.addReferenceObj("timeZone", timeZone) + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) (c, evPrim, evNull) => - s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;" + code"""$evPrim = + org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;""" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = ${decimalToTimestampCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${decimalToTimestampCode(c)};" case DoubleType => (c, evPrim, evNull) => - s""" + code""" if (Double.isNaN($c) || Double.isInfinite($c)) { $evNull = true; } else { @@ -992,7 +1002,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ case FloatType => (c, evPrim, evNull) => - s""" + code""" if (Float.isNaN($c) || Float.isInfinite($c)) { $evNull = true; } else { @@ -1004,7 +1014,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToIntervalCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s"""$evPrim = CalendarInterval.fromString($c.toString()); + code"""$evPrim = CalendarInterval.fromString($c.toString()); if(${evPrim} == null) { ${evNull} = true; } @@ -1012,18 +1022,22 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } - private[this] def decimalToTimestampCode(d: String): String = - s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" - private[this] def longToTimeStampCode(l: String): String = s"$l * 1000000L" - private[this] def timestampToIntegerCode(ts: String): String = - s"java.lang.Math.floor((double) $ts / 1000000L)" - private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0" + private[this] def decimalToTimestampCode(d: ExprValue): ExprValue = + JavaCode.expression( + s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()", + TimestampType) + private[this] def longToTimeStampCode(l: ExprValue): ExprValue = + JavaCode.expression(s"$l * 1000000L", TimestampType) + private[this] def timestampToIntegerCode(ts: ExprValue): ExprValue = + JavaCode.expression(s"java.lang.Math.floor((double) $ts / 1000000L)", IntegerType) + private[this] def timestampToDoubleCode(ts: ExprValue): ExprValue = + JavaCode.expression(s"$ts / 1000000.0", DoubleType) private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => - val stringUtils = StringUtils.getClass.getName.stripSuffix("$") + val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" (c, evPrim, evNull) => - s""" + code""" if ($stringUtils.isTrueString($c)) { $evPrim = true; } else if ($stringUtils.isFalseString($c)) { @@ -1033,21 +1047,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case TimestampType => - (c, evPrim, evNull) => s"$evPrim = $c != 0;" + (c, evPrim, evNull) => code"$evPrim = $c != 0;" case DateType => // Hive would return null when cast from date to boolean - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = !$c.isZero();" + (c, evPrim, evNull) => code"$evPrim = !$c.isZero();" case n: NumericType => - (c, evPrim, evNull) => s"$evPrim = $c != 0;" + (c, evPrim, evNull) => code"$evPrim = $c != 0;" } private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = ctx.freshName("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toByte($wrapper)) { $evPrim = (byte) $wrapper.value; @@ -1057,24 +1071,24 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? (byte) 1 : (byte) 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (byte) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (byte) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toByte();" + (c, evPrim, evNull) => code"$evPrim = $c.toByte();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (byte) $c;" + (c, evPrim, evNull) => code"$evPrim = (byte) $c;" } private[this] def castToShortCode( from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = ctx.freshName("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toShort($wrapper)) { $evPrim = (short) $wrapper.value; @@ -1084,22 +1098,22 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? (short) 1 : (short) 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (short) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (short) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toShort();" + (c, evPrim, evNull) => code"$evPrim = $c.toShort();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (short) $c;" + (c, evPrim, evNull) => code"$evPrim = (short) $c;" } private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = ctx.freshName("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toInt($wrapper)) { $evPrim = $wrapper.value; @@ -1109,23 +1123,23 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1 : 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (int) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (int) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toInt();" + (c, evPrim, evNull) => code"$evPrim = $c.toInt();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (int) $c;" + (c, evPrim, evNull) => code"$evPrim = (int) $c;" } private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("longWrapper") + val wrapper = ctx.freshName("longWrapper", classOf[UTF8String.LongWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); if ($c.toLong($wrapper)) { $evPrim = $wrapper.value; @@ -1135,21 +1149,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1L : 0L;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (long) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (long) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toLong();" + (c, evPrim, evNull) => code"$evPrim = $c.toLong();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (long) $c;" + (c, evPrim, evNull) => code"$evPrim = (long) $c;" } private[this] def castToFloatCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { $evPrim = Float.valueOf($c.toString()); } catch (java.lang.NumberFormatException e) { @@ -1157,21 +1171,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1.0f : 0.0f;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (float) (${timestampToDoubleCode(c)});" + (c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toFloat();" + (c, evPrim, evNull) => code"$evPrim = $c.toFloat();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (float) $c;" + (c, evPrim, evNull) => code"$evPrim = (float) $c;" } private[this] def castToDoubleCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { $evPrim = Double.valueOf($c.toString()); } catch (java.lang.NumberFormatException e) { @@ -1179,31 +1193,32 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1.0d : 0.0d;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = ${timestampToDoubleCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toDouble();" + (c, evPrim, evNull) => code"$evPrim = $c.toDouble();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (double) $c;" + (c, evPrim, evNull) => code"$evPrim = (double) $c;" } private[this] def castArrayCode( fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = { val elementCast = nullSafeCastFunction(fromType, toType, ctx) - val arrayClass = classOf[GenericArrayData].getName - val fromElementNull = ctx.freshName("feNull") - val fromElementPrim = ctx.freshName("fePrim") - val toElementNull = ctx.freshName("teNull") - val toElementPrim = ctx.freshName("tePrim") - val size = ctx.freshName("n") - val j = ctx.freshName("j") - val values = ctx.freshName("values") + val arrayClass = JavaCode.className(classOf[GenericArrayData]) + val fromElementNull = ctx.isNullFreshName("feNull") + val fromElementPrim = ctx.freshName("fePrim", fromType) + val toElementNull = ctx.isNullFreshName("teNull") + val toElementPrim = ctx.freshName("tePrim", toType) + val size = ctx.freshName("n", IntegerType) + val j = ctx.freshName("j", IntegerType) + val values = ctx.freshName("values", classOf[Array[Object]]) + val javaType = JavaCode.javaType(fromType) (c, evPrim, evNull) => - s""" + code""" final int $size = $c.numElements(); final Object[] $values = new Object[$size]; for (int $j = 0; $j < $size; $j ++) { @@ -1211,7 +1226,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $values[$j] = null; } else { boolean $fromElementNull = false; - ${CodeGenerator.javaType(fromType)} $fromElementPrim = + $javaType $fromElementPrim = ${CodeGenerator.getValue(c, fromType, j)}; ${castCode(ctx, fromElementPrim, fromElementNull, toElementPrim, toElementNull, toType, elementCast)} @@ -1230,23 +1245,23 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val keysCast = castArrayCode(from.keyType, to.keyType, ctx) val valuesCast = castArrayCode(from.valueType, to.valueType, ctx) - val mapClass = classOf[ArrayBasedMapData].getName + val mapClass = JavaCode.className(classOf[ArrayBasedMapData]) - val keys = ctx.freshName("keys") - val convertedKeys = ctx.freshName("convertedKeys") - val convertedKeysNull = ctx.freshName("convertedKeysNull") + val keys = ctx.freshName("keys", ArrayType(from.keyType)) + val convertedKeys = ctx.freshName("convertedKeys", ArrayType(to.keyType)) + val convertedKeysNull = ctx.isNullFreshName("convertedKeysNull") - val values = ctx.freshName("values") - val convertedValues = ctx.freshName("convertedValues") - val convertedValuesNull = ctx.freshName("convertedValuesNull") + val values = ctx.freshName("values", ArrayType(from.valueType)) + val convertedValues = ctx.freshName("convertedValues", ArrayType(to.valueType)) + val convertedValuesNull = ctx.isNullFreshName("convertedValuesNull") (c, evPrim, evNull) => - s""" + code""" final ArrayData $keys = $c.keyArray(); final ArrayData $values = $c.valueArray(); - ${castCode(ctx, keys, "false", + ${castCode(ctx, keys, FalseLiteral, convertedKeys, convertedKeysNull, ArrayType(to.keyType), keysCast)} - ${castCode(ctx, values, "false", + ${castCode(ctx, values, FalseLiteral, convertedValues, convertedValuesNull, ArrayType(to.valueType), valuesCast)} $evPrim = new $mapClass($convertedKeys, $convertedValues); @@ -1259,17 +1274,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val fieldsCasts = from.fields.zip(to.fields).map { case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } - val rowClass = classOf[GenericInternalRow].getName - val tmpResult = ctx.freshName("tmpResult") - val tmpInput = ctx.freshName("tmpInput") + val tmpResult = ctx.freshName("tmpResult", classOf[GenericInternalRow]) + val rowClass = JavaCode.className(classOf[GenericInternalRow]) + val tmpInput = ctx.freshName("tmpInput", classOf[InternalRow]) val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => - val fromFieldPrim = ctx.freshName("ffp") - val fromFieldNull = ctx.freshName("ffn") - val toFieldPrim = ctx.freshName("tfp") - val toFieldNull = ctx.freshName("tfn") - val fromType = CodeGenerator.javaType(from.fields(i).dataType) - s""" + val fromFieldPrim = ctx.freshName("ffp", from.fields(i).dataType) + val fromFieldNull = ctx.isNullFreshName("ffn") + val toFieldPrim = ctx.freshName("tfp", to.fields(i).dataType) + val toFieldNull = ctx.isNullFreshName("tfn") + val fromType = JavaCode.javaType(from.fields(i).dataType) + val setColumn = CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim) + code""" boolean $fromFieldNull = $tmpInput.isNullAt($i); if ($fromFieldNull) { $tmpResult.setNullAt($i); @@ -1281,18 +1297,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String if ($toFieldNull) { $tmpResult.setNullAt($i); } else { - ${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; + $setColumn; } } """ } val fieldsEvalCodes = ctx.splitExpressions( - expressions = fieldsEvalCode, + expressions = fieldsEvalCode.map(_.code), funcName = "castStruct", - arguments = ("InternalRow", tmpInput) :: (rowClass, tmpResult) :: Nil) + arguments = ("InternalRow", tmpInput.code) :: (rowClass.code, tmpResult.code) :: Nil) (input, result, resultIsNull) => - s""" + code""" final $rowClass $tmpResult = new $rowClass(${fieldsCasts.length}); final InternalRow $tmpInput = $input; $fieldsEvalCodes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 66315e5906253..388c9add5d204 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -579,6 +579,22 @@ class CodegenContext { s"${fullName}_$id" } + /** + * Creates an `ExprValue` representing a local java variable of required data type. + */ + def freshName(name: String, dt: DataType): VariableValue = JavaCode.variable(freshName(name), dt) + + /** + * Creates an `ExprValue` representing a local java variable of required data type. + */ + def freshName(name: String, javaClass: Class[_]): VariableValue = + JavaCode.variable(freshName(name), javaClass) + + /** + * Creates an `ExprValue` representing a local boolean java variable. + */ + def isNullFreshName(name: String): VariableValue = JavaCode.isNullVariable(freshName(name)) + /** * Generates code for equal expression in Java. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 250ce48d059e0..090204c95300c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -113,6 +113,21 @@ object JavaCode { def isNullExpression(code: String): SimpleExprValue = { expression(code, BooleanType) } + + /** + * Create an `InlineBlock` for Java Class name. + */ + def className(javaClass: Class[_]): InlineBlock = InlineBlock(javaClass.getName) + + /** + * Create an `InlineBlock` for Java Type name. + */ + def javaType(dataType: DataType): InlineBlock = InlineBlock(CodeGenerator.javaType(dataType)) + + /** + * Create an `InlineBlock` for boxed Java Type name. + */ + def boxedType(dataType: DataType): InlineBlock = InlineBlock(CodeGenerator.boxedType(dataType)) } /** @@ -155,6 +170,17 @@ object Block { val CODE_BLOCK_BUFFER_LENGTH: Int = 512 + /** + * A custom string interpolator which inlines all types of input arguments into a string without + * tracking any reference of `JavaCode` instances. + */ + implicit class InlineHelper(val sc: StringContext) extends AnyVal { + def inline(args: Any*): Block = { + val inlineString = sc.raw(args: _*) + InlineBlock(inlineString) + } + } + implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) implicit class BlockHelper(val sc: StringContext) extends AnyVal { @@ -233,6 +259,7 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends override def + (other: Block): Block = other match { case c: CodeBlock => Blocks(Seq(this, c)) + case i: InlineBlock => Blocks(Seq(this, i)) case b: Blocks => Blocks(Seq(this) ++ b.blocks) case EmptyBlock => this } @@ -244,6 +271,7 @@ case class Blocks(blocks: Seq[Block]) extends Block { override def + (other: Block): Block = other match { case c: CodeBlock => Blocks(blocks :+ c) + case i: InlineBlock => Blocks(blocks :+ i) case b: Blocks => Blocks(blocks ++ b.blocks) case EmptyBlock => this } @@ -256,6 +284,18 @@ object EmptyBlock extends Block with Serializable { override def + (other: Block): Block = other } +case class InlineBlock(block: String) extends Block { + override val code: String = block + override val exprValues: Set[ExprValue] = Set.empty + + override def + (other: Block): Block = other match { + case c: CodeBlock => Blocks(Seq(this, c)) + case i: InlineBlock => InlineBlock(block + i.block) + case b: Blocks => Blocks(Seq(this) ++ b.blocks) + case EmptyBlock => this + } +} + /** * A typed java fragment that must be a valid java expression. */ From 531faf4aad4c942efb826fe7ca2ba8a7f1e5cf3f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Jun 2018 05:40:15 +0000 Subject: [PATCH 2/9] Address comment. --- .../spark/sql/catalyst/expressions/Cast.scala | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 628bf5972a49a..e0beaaa1df9b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -626,9 +626,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) ev.copy(code = eval.code + - code""" - ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)} - """) + castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) } // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` @@ -1022,16 +1020,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } - private[this] def decimalToTimestampCode(d: ExprValue): ExprValue = - JavaCode.expression( - s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()", - TimestampType) - private[this] def longToTimeStampCode(l: ExprValue): ExprValue = - JavaCode.expression(s"$l * 1000000L", TimestampType) - private[this] def timestampToIntegerCode(ts: ExprValue): ExprValue = - JavaCode.expression(s"java.lang.Math.floor((double) $ts / 1000000L)", IntegerType) - private[this] def timestampToDoubleCode(ts: ExprValue): ExprValue = - JavaCode.expression(s"$ts / 1000000.0", DoubleType) + private[this] def decimalToTimestampCode(d: ExprValue): Block = { + val block = code"new java.math.BigDecimal(1000000L)" + code"($d.toBigDecimal().bigDecimal().multiply($block)).longValue()" + } + private[this] def longToTimeStampCode(l: ExprValue): Block = code"$l * 1000000L" + private[this] def timestampToIntegerCode(ts: ExprValue): Block = + code"java.lang.Math.floor((double) $ts / 1000000L)" + private[this] def timestampToDoubleCode(ts: ExprValue): Block = + code"$ts / 1000000.0" private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => From b592e66c030ba7c2d260c3be48c3b15139f40e5b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 14 Jun 2018 07:00:27 +0000 Subject: [PATCH 3/9] Rename api. --- .../spark/sql/catalyst/expressions/Cast.scala | 68 +++++++++---------- .../expressions/codegen/CodeGenerator.scala | 10 +-- 2 files changed, 37 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index e0beaaa1df9b3..e8b833195ff17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -697,7 +697,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |} """.stripMargin)}" - val loopIndex = ctx.freshName("loopIndex", IntegerType) + val loopIndex = ctx.freshVariable("loopIndex", IntegerType) code""" |$buffer.append("["); |if ($array.numElements() > 0) { @@ -740,7 +740,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val keyToStringFunc = inline"${dataToStringFunc("keyToString", kt)}" val valueToStringFunc = inline"${dataToStringFunc("valueToString", vt)}" - val loopIndex = ctx.freshName("loopIndex", IntegerType) + val loopIndex = ctx.freshVariable("loopIndex", IntegerType) val mapKeyArray = JavaCode.expression(s"$map.keyArray()", classOf[ArrayData]) val mapValueArray = JavaCode.expression(s"$map.valueArray()", classOf[ArrayData]) val getMapFirstKey = CodeGenerator.getValue(mapKeyArray, kt, JavaCode.literal("0", IntegerType)) @@ -778,8 +778,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String ctx: CodegenContext): Block = { val structToStringCode = st.zipWithIndex.map { case (ft, i) => val fieldToStringCode = castToStringCode(ft, ctx) - val field = ctx.freshName("field", ft) - val fieldStr = ctx.freshName("fieldStr", StringType) + val field = ctx.freshVariable("field", ft) + val fieldStr = ctx.freshVariable("fieldStr", StringType) val javaType = JavaCode.javaType(ft) code""" |${if (i != 0) code"""$buffer.append(",");""" else EmptyBlock} @@ -821,7 +821,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" case ArrayType(et, _) => (c, evPrim, evNull) => { - val buffer = ctx.freshName("buffer", classOf[UTF8StringBuilder]) + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) val bufferClass = JavaCode.className(classOf[UTF8StringBuilder]) val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) code""" @@ -832,7 +832,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case MapType(kt, vt, _) => (c, evPrim, evNull) => { - val buffer = ctx.freshName("buffer", classOf[UTF8StringBuilder]) + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) val bufferClass = JavaCode.className(classOf[UTF8StringBuilder]) val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) code""" @@ -843,8 +843,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case StructType(fields) => (c, evPrim, evNull) => { - val row = ctx.freshName("row", classOf[InternalRow]) - val buffer = ctx.freshName("buffer", classOf[UTF8StringBuilder]) + val row = ctx.freshVariable("row", classOf[InternalRow]) + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) val bufferClass = JavaCode.className(classOf[UTF8StringBuilder]) val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) code""" @@ -874,7 +874,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val intOpt = ctx.freshName("intOpt", classOf[Option[Integer]]) + val intOpt = ctx.freshVariable("intOpt", classOf[Option[Integer]]) (c, evPrim, evNull) => code""" scala.Option $intOpt = org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c); @@ -907,7 +907,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, target: DecimalType, ctx: CodegenContext): CastFunction = { - val tmp = ctx.freshName("tmpDecimal", classOf[Decimal]) + val tmp = ctx.freshVariable("tmpDecimal", classOf[Decimal]) from match { case StringType => (c, evPrim, evNull) => @@ -967,7 +967,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String ctx: CodegenContext): CastFunction = from match { case StringType => val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) - val longOpt = ctx.freshName("longOpt", classOf[Option[Long]]) + val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]]) (c, evPrim, evNull) => code""" scala.Option $longOpt = @@ -1056,7 +1056,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper", classOf[UTF8String.IntWrapper]) + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); @@ -1083,7 +1083,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper", classOf[UTF8String.IntWrapper]) + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); @@ -1108,7 +1108,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper", classOf[UTF8String.IntWrapper]) + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); @@ -1133,7 +1133,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("longWrapper", classOf[UTF8String.LongWrapper]) + val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper]) (c, evPrim, evNull) => code""" @@ -1205,13 +1205,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = { val elementCast = nullSafeCastFunction(fromType, toType, ctx) val arrayClass = JavaCode.className(classOf[GenericArrayData]) - val fromElementNull = ctx.isNullFreshName("feNull") - val fromElementPrim = ctx.freshName("fePrim", fromType) - val toElementNull = ctx.isNullFreshName("teNull") - val toElementPrim = ctx.freshName("tePrim", toType) - val size = ctx.freshName("n", IntegerType) - val j = ctx.freshName("j", IntegerType) - val values = ctx.freshName("values", classOf[Array[Object]]) + val fromElementNull = ctx.freshVariable("feNull", BooleanType) + val fromElementPrim = ctx.freshVariable("fePrim", fromType) + val toElementNull = ctx.freshVariable("teNull", BooleanType) + val toElementPrim = ctx.freshVariable("tePrim", toType) + val size = ctx.freshVariable("n", IntegerType) + val j = ctx.freshVariable("j", IntegerType) + val values = ctx.freshVariable("values", classOf[Array[Object]]) val javaType = JavaCode.javaType(fromType) (c, evPrim, evNull) => @@ -1244,13 +1244,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val mapClass = JavaCode.className(classOf[ArrayBasedMapData]) - val keys = ctx.freshName("keys", ArrayType(from.keyType)) - val convertedKeys = ctx.freshName("convertedKeys", ArrayType(to.keyType)) - val convertedKeysNull = ctx.isNullFreshName("convertedKeysNull") + val keys = ctx.freshVariable("keys", ArrayType(from.keyType)) + val convertedKeys = ctx.freshVariable("convertedKeys", ArrayType(to.keyType)) + val convertedKeysNull = ctx.freshVariable("convertedKeysNull", BooleanType) - val values = ctx.freshName("values", ArrayType(from.valueType)) - val convertedValues = ctx.freshName("convertedValues", ArrayType(to.valueType)) - val convertedValuesNull = ctx.isNullFreshName("convertedValuesNull") + val values = ctx.freshVariable("values", ArrayType(from.valueType)) + val convertedValues = ctx.freshVariable("convertedValues", ArrayType(to.valueType)) + val convertedValuesNull = ctx.freshVariable("convertedValuesNull", BooleanType) (c, evPrim, evNull) => code""" @@ -1271,15 +1271,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val fieldsCasts = from.fields.zip(to.fields).map { case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } - val tmpResult = ctx.freshName("tmpResult", classOf[GenericInternalRow]) + val tmpResult = ctx.freshVariable("tmpResult", classOf[GenericInternalRow]) val rowClass = JavaCode.className(classOf[GenericInternalRow]) - val tmpInput = ctx.freshName("tmpInput", classOf[InternalRow]) + val tmpInput = ctx.freshVariable("tmpInput", classOf[InternalRow]) val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => - val fromFieldPrim = ctx.freshName("ffp", from.fields(i).dataType) - val fromFieldNull = ctx.isNullFreshName("ffn") - val toFieldPrim = ctx.freshName("tfp", to.fields(i).dataType) - val toFieldNull = ctx.isNullFreshName("tfn") + val fromFieldPrim = ctx.freshVariable("ffp", from.fields(i).dataType) + val fromFieldNull = ctx.freshVariable("ffn", BooleanType) + val toFieldPrim = ctx.freshVariable("tfp", to.fields(i).dataType) + val toFieldNull = ctx.freshVariable("tfn", BooleanType) val fromType = JavaCode.javaType(from.fields(i).dataType) val setColumn = CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim) code""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 388c9add5d204..7140c3ab27ef2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -582,19 +582,15 @@ class CodegenContext { /** * Creates an `ExprValue` representing a local java variable of required data type. */ - def freshName(name: String, dt: DataType): VariableValue = JavaCode.variable(freshName(name), dt) + def freshVariable(name: String, dt: DataType): VariableValue = + JavaCode.variable(freshName(name), dt) /** * Creates an `ExprValue` representing a local java variable of required data type. */ - def freshName(name: String, javaClass: Class[_]): VariableValue = + def freshVariable(name: String, javaClass: Class[_]): VariableValue = JavaCode.variable(freshName(name), javaClass) - /** - * Creates an `ExprValue` representing a local boolean java variable. - */ - def isNullFreshName(name: String): VariableValue = JavaCode.isNullVariable(freshName(name)) - /** * Generates code for equal expression in Java. */ From a972e0ef694a9e39913f2ea859034cbc4d871f02 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 15 Jun 2018 03:59:06 +0000 Subject: [PATCH 4/9] Improve document. --- .../spark/sql/catalyst/expressions/codegen/javaCode.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 090204c95300c..d84c40514f742 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -171,8 +171,7 @@ object Block { val CODE_BLOCK_BUFFER_LENGTH: Int = 512 /** - * A custom string interpolator which inlines all types of input arguments into a string without - * tracking any reference of `JavaCode` instances. + * A custom string interpolator which inlines a string into code block. */ implicit class InlineHelper(val sc: StringContext) extends AnyVal { def inline(args: Any*): Block = { @@ -284,6 +283,10 @@ object EmptyBlock extends Block with Serializable { override def + (other: Block): Block = other } +/** + * A block inlines all types of input arguments into a string without + * tracking any reference of `JavaCode` instances. + */ case class InlineBlock(block: String) extends Block { override val code: String = block override val exprValues: Set[ExprValue] = Set.empty From 4b90551128353ac176968f77b10284a3b0d9eec7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 21 Jun 2018 09:27:54 +0000 Subject: [PATCH 5/9] Address comment. --- .../apache/spark/sql/catalyst/expressions/Cast.scala | 12 ++++++------ .../sql/catalyst/expressions/codegen/javaCode.scala | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index e8b833195ff17..7de5d7c7641d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -822,7 +822,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case ArrayType(et, _) => (c, evPrim, evNull) => { val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) - val bufferClass = JavaCode.className(classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) code""" |$bufferClass $buffer = new $bufferClass(); @@ -833,7 +833,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case MapType(kt, vt, _) => (c, evPrim, evNull) => { val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) - val bufferClass = JavaCode.className(classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) code""" |$bufferClass $buffer = new $bufferClass(); @@ -845,7 +845,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => { val row = ctx.freshVariable("row", classOf[InternalRow]) val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) - val bufferClass = JavaCode.className(classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) code""" |InternalRow $row = $c; @@ -1204,7 +1204,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castArrayCode( fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = { val elementCast = nullSafeCastFunction(fromType, toType, ctx) - val arrayClass = JavaCode.className(classOf[GenericArrayData]) + val arrayClass = JavaCode.javaType(classOf[GenericArrayData]) val fromElementNull = ctx.freshVariable("feNull", BooleanType) val fromElementPrim = ctx.freshVariable("fePrim", fromType) val toElementNull = ctx.freshVariable("teNull", BooleanType) @@ -1242,7 +1242,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val keysCast = castArrayCode(from.keyType, to.keyType, ctx) val valuesCast = castArrayCode(from.valueType, to.valueType, ctx) - val mapClass = JavaCode.className(classOf[ArrayBasedMapData]) + val mapClass = JavaCode.javaType(classOf[ArrayBasedMapData]) val keys = ctx.freshVariable("keys", ArrayType(from.keyType)) val convertedKeys = ctx.freshVariable("convertedKeys", ArrayType(to.keyType)) @@ -1272,7 +1272,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } val tmpResult = ctx.freshVariable("tmpResult", classOf[GenericInternalRow]) - val rowClass = JavaCode.className(classOf[GenericInternalRow]) + val rowClass = JavaCode.javaType(classOf[GenericInternalRow]) val tmpInput = ctx.freshVariable("tmpInput", classOf[InternalRow]) val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index d84c40514f742..654f1d0691d47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -117,7 +117,7 @@ object JavaCode { /** * Create an `InlineBlock` for Java Class name. */ - def className(javaClass: Class[_]): InlineBlock = InlineBlock(javaClass.getName) + def javaType(javaClass: Class[_]): InlineBlock = InlineBlock(javaClass.getName) /** * Create an `InlineBlock` for Java Type name. From 7b0ee94753454300584f085c934cb4994f59055c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 Jul 2018 07:05:32 +0000 Subject: [PATCH 6/9] Update comment. --- .../spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9c939834d1cbc..b0832d888b75d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -586,7 +586,7 @@ class CodegenContext { JavaCode.variable(freshName(name), dt) /** - * Creates an `ExprValue` representing a local java variable of required data type. + * Creates an `ExprValue` representing a local java variable of required Java class. */ def freshVariable(name: String, javaClass: Class[_]): VariableValue = JavaCode.variable(freshName(name), javaClass) From f1f218068bc1e4c147c14dbc56c874c4c7d7cc4b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 Jul 2018 07:25:47 +0000 Subject: [PATCH 7/9] Address comment. --- .../org/apache/spark/sql/catalyst/expressions/Cast.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f6cbfd4071aed..d0773064f02ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -748,7 +748,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val dataToStringCode = castToStringCode(dataType, ctx) val data = JavaCode.variable("data", dataType) val dataStr = JavaCode.variable("dataStr", StringType) - ctx.addNewFunction(funcName, + val functionCall = ctx.addNewFunction(funcName, s""" |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} $data) { | UTF8String $dataStr = null; @@ -756,10 +756,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String | return dataStr; |} """.stripMargin) + inline"$functionCall" } - val keyToStringFunc = inline"${dataToStringFunc("keyToString", kt)}" - val valueToStringFunc = inline"${dataToStringFunc("valueToString", vt)}" + val keyToStringFunc = dataToStringFunc("keyToString", kt) + val valueToStringFunc = dataToStringFunc("valueToString", vt) val loopIndex = ctx.freshVariable("loopIndex", IntegerType) val mapKeyArray = JavaCode.expression(s"$map.keyArray()", classOf[ArrayData]) val mapValueArray = JavaCode.expression(s"$map.valueArray()", classOf[ArrayData]) From 807d8d44f950b8a588065b15bb7fa6a5db753075 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 Jul 2018 19:41:33 +0000 Subject: [PATCH 8/9] Address comment. --- .../spark/sql/catalyst/expressions/codegen/javaCode.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index a0a54f971b18f..3aa3bd872d3e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -221,9 +221,8 @@ object Block { EmptyBlock } else { args.foreach { - case _: ExprValue | _: Inline => + case _: ExprValue | _: Inline | _: Block => case _: Int | _: Long | _: Float | _: Double | _: String => - case _: Block => case other => throw new IllegalArgumentException( s"Can not interpolate ${other.getClass.getName} into code block.") } From 508e091f53084deefc35001ce8d89455ca549e53 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 12 Aug 2018 18:46:26 +0000 Subject: [PATCH 9/9] Address comment. --- .../scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d0773064f02ae..c5c2ee62c5cd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1042,7 +1042,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } private[this] def decimalToTimestampCode(d: ExprValue): Block = { - val block = code"new java.math.BigDecimal(1000000L)" + val block = inline"new java.math.BigDecimal(1000000L)" code"($d.toBigDecimal().bigDecimal().multiply($block)).longValue()" } private[this] def longToTimeStampCode(l: ExprValue): Block = code"$l * 1000000L"