Skip to content

Commit 2a85388

Browse files
committed
Second pass.
1 parent d80a948 commit 2a85388

File tree

7 files changed

+120
-107
lines changed

7 files changed

+120
-107
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
5454
} else {
5555
assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
5656
val javaType = inline"${CodeGenerator.javaType(dataType)}"
57-
val value = JavaCode.expression(
58-
CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString),
59-
dataType)
57+
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
6058
if (nullable) {
6159
ev.copy(code =
6260
code"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1376,7 +1376,7 @@ object CodeGenerator extends Logging {
13761376
val msg = s"failed to compile: $e"
13771377
logError(msg, e)
13781378
val maxLines = SQLConf.get.loggingMaxLinesForCodegen
1379-
println(s"\n${CodeFormatter.format(code, maxLines)}")
1379+
logInfo(s"\n${CodeFormatter.format(code, maxLines)}")
13801380
throw new CompileException(msg, e.getLocation)
13811381
}
13821382

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
5757
val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
5858
val converter = convertToSafe(
5959
ctx,
60-
JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt),
60+
CodeGenerator.getValue(tmpInput, dt, i.toString),
6161
dt)
6262
code"""
6363
if (!$tmpInput.isNullAt($i)) {
@@ -96,7 +96,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
9696

9797
val elementConverter = convertToSafe(
9898
ctx,
99-
JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType),
99+
CodeGenerator.getValue(tmpInput, elementType, index),
100100
elementType)
101101
val code = code"""
102102
final ArrayData $tmpInput = $input;

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
5757
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
5858
ExprCode(
5959
JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"),
60-
JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt))
60+
CodeGenerator.getValue(tmpInput, dt, i.toString))
6161
}
6262

6363
val rowWriterClass = classOf[UnsafeRowWriter].getName

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2159,12 +2159,12 @@ case class ArrayRemove(left: Expression, right: Expression)
21592159

21602160
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
21612161
nullSafeCodeGen(ctx, ev, (arr, value) => {
2162-
val numsToRemove = ctx.freshName("numsToRemove")
2163-
val newArraySize = ctx.freshName("newArraySize")
2164-
val i = ctx.freshName("i")
2162+
val numsToRemove = JavaCode.variable(ctx.freshName("numsToRemove"), IntegerType)
2163+
val newArraySize = JavaCode.variable(ctx.freshName("newArraySize"), IntegerType)
2164+
val i = JavaCode.variable(ctx.freshName("i"), IntegerType)
21652165
val getValue = CodeGenerator.getValue(arr, elementType, i)
21662166
val isEqual = ctx.genEqual(elementType, value, getValue)
2167-
s"""
2167+
code"""
21682168
|int $numsToRemove = 0;
21692169
|for (int $i = 0; $i < $arr.numElements(); $i ++) {
21702170
| if (!$arr.isNullAt($i) && $isEqual) {
@@ -2180,17 +2180,17 @@ case class ArrayRemove(left: Expression, right: Expression)
21802180
def genCodeForResult(
21812181
ctx: CodegenContext,
21822182
ev: ExprCode,
2183-
inputArray: String,
2184-
value: String,
2185-
newArraySize: String): String = {
2186-
val values = ctx.freshName("values")
2187-
val i = ctx.freshName("i")
2188-
val pos = ctx.freshName("pos")
2183+
inputArray: ExprValue,
2184+
value: ExprValue,
2185+
newArraySize: ExprValue): Block = {
2186+
val values = JavaCode.variable(ctx.freshName("values"), classOf[Array[Object]])
2187+
val i = JavaCode.variable(ctx.freshName("i"), IntegerType)
2188+
val pos = JavaCode.variable(ctx.freshName("pos"), IntegerType)
21892189
val getValue = CodeGenerator.getValue(inputArray, elementType, i)
21902190
val isEqual = ctx.genEqual(elementType, value, getValue)
21912191
if (!CodeGenerator.isPrimitiveType(elementType)) {
2192-
val arrayClass = classOf[GenericArrayData].getName
2193-
s"""
2192+
val arrayClass = inline"${classOf[GenericArrayData].getName}"
2193+
code"""
21942194
|int $pos = 0;
21952195
|Object[] $values = new Object[$newArraySize];
21962196
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
@@ -2208,9 +2208,10 @@ case class ArrayRemove(left: Expression, right: Expression)
22082208
|${ev.value} = new $arrayClass($values);
22092209
""".stripMargin
22102210
} else {
2211-
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
2212-
s"""
2213-
|${ctx.createUnsafeArray(values, newArraySize, elementType, s" $prettyName failed.")}
2211+
val primitiveValueTypeName = inline"${CodeGenerator.primitiveTypeName(elementType)}"
2212+
val errorMsg = new LiteralValue(s" $prettyName failed.", classOf[String])
2213+
code"""
2214+
|${ctx.createUnsafeArray(values, newArraySize, elementType, errorMsg)}
22142215
|int $pos = 0;
22152216
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
22162217
| if ($inputArray.isNullAt($i)) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala

Lines changed: 81 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ import org.apache.commons.codec.digest.DigestUtils
2222
import org.apache.spark.sql.AnalysisException
2323
import org.apache.spark.sql.catalyst.expressions.MaskExpressionsUtils._
2424
import org.apache.spark.sql.catalyst.expressions.MaskLike._
25-
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
25+
import org.apache.spark.sql.catalyst.expressions.codegen.{Block, CodegenContext, CodeGenerator, ExprCode, ExprValue, JavaCode}
26+
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2627
import org.apache.spark.sql.types._
2728
import org.apache.spark.unsafe.types.UTF8String
2829

@@ -36,23 +37,25 @@ trait MaskLike {
3637
protected lazy val lowerReplacement: Int = getReplacementChar(lower, defaultMaskedLowercase)
3738
protected lazy val digitReplacement: Int = getReplacementChar(digit, defaultMaskedDigit)
3839

39-
protected val maskUtilsClassName: String = classOf[MaskExpressionsUtils].getName
40+
protected val maskUtilsClassName: Block = inline"${classOf[MaskExpressionsUtils].getName}"
4041

41-
def inputStringLengthCode(inputString: String, length: String): String = {
42-
s"${CodeGenerator.JAVA_INT} $length = $inputString.codePointCount(0, $inputString.length());"
42+
def inputStringLengthCode(inputString: ExprValue, length: ExprValue): Block = {
43+
val intType = inline"${CodeGenerator.JAVA_INT}"
44+
code"$intType $length = $inputString.codePointCount(0, $inputString.length());"
4345
}
4446

4547
def appendMaskedToStringBuilderCode(
4648
ctx: CodegenContext,
47-
sb: String,
48-
inputString: String,
49-
offset: String,
50-
numChars: String): String = {
51-
val i = ctx.freshName("i")
52-
val codePoint = ctx.freshName("codePoint")
53-
s"""
54-
|for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) {
55-
| ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset);
49+
sb: ExprValue,
50+
inputString: ExprValue,
51+
offset: ExprValue,
52+
numChars: JavaCode): Block = {
53+
val i = JavaCode.variable(ctx.freshName("i"), IntegerType)
54+
val codePoint = JavaCode.variable(ctx.freshName("codePoint"), IntegerType)
55+
val intType = inline"${CodeGenerator.JAVA_INT}"
56+
code"""
57+
|for ($intType $i = 0; $i < $numChars; $i++) {
58+
| $intType $codePoint = $inputString.codePointAt($offset);
5659
| $sb.appendCodePoint($maskUtilsClassName.transformChar($codePoint,
5760
| $upperReplacement, $lowerReplacement,
5861
| $digitReplacement, $defaultMaskedOther));
@@ -63,15 +66,16 @@ trait MaskLike {
6366

6467
def appendUnchangedToStringBuilderCode(
6568
ctx: CodegenContext,
66-
sb: String,
67-
inputString: String,
68-
offset: String,
69-
numChars: String): String = {
70-
val i = ctx.freshName("i")
71-
val codePoint = ctx.freshName("codePoint")
72-
s"""
73-
|for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) {
74-
| ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset);
69+
sb: ExprValue,
70+
inputString: ExprValue,
71+
offset: ExprValue,
72+
numChars: JavaCode): Block = {
73+
val i = JavaCode.variable(ctx.freshName("i"), IntegerType)
74+
val codePoint = JavaCode.variable(ctx.freshName("codePoint"), IntegerType)
75+
val intType = inline"${CodeGenerator.JAVA_INT}"
76+
code"""
77+
|for ($intType $i = 0; $i < $numChars; $i++) {
78+
| $intType $codePoint = $inputString.codePointAt($offset);
7579
| $sb.appendCodePoint($codePoint);
7680
| $offset += Character.charCount($codePoint);
7781
|}
@@ -179,16 +183,17 @@ case class Mask(child: Expression, upper: String, lower: String, digit: String)
179183
}
180184

181185
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
182-
nullSafeCodeGen(ctx, ev, (input: String) => {
183-
val sb = ctx.freshName("sb")
184-
val length = ctx.freshName("length")
185-
val offset = ctx.freshName("offset")
186-
val inputString = ctx.freshName("inputString")
187-
s"""
186+
nullSafeCodeGen(ctx, ev, (input: ExprValue) => {
187+
val sb = JavaCode.variable(ctx.freshName("sb"), classOf[StringBuilder])
188+
val length = JavaCode.variable(ctx.freshName("length"), IntegerType)
189+
val offset = JavaCode.variable(ctx.freshName("offset"), IntegerType)
190+
val inputString = JavaCode.variable(ctx.freshName("inputString"), classOf[String])
191+
val intType = inline"${CodeGenerator.JAVA_INT}"
192+
code"""
188193
|String $inputString = $input.toString();
189194
|${inputStringLengthCode(inputString, length)}
190195
|StringBuilder $sb = new StringBuilder($length);
191-
|${CodeGenerator.JAVA_INT} $offset = 0;
196+
|$intType $offset = 0;
192197
|${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, length)}
193198
|${ev.value} = UTF8String.fromString($sb.toString());
194199
""".stripMargin
@@ -256,21 +261,22 @@ case class MaskFirstN(
256261
}
257262

258263
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
259-
nullSafeCodeGen(ctx, ev, (input: String) => {
260-
val sb = ctx.freshName("sb")
261-
val length = ctx.freshName("length")
262-
val offset = ctx.freshName("offset")
263-
val inputString = ctx.freshName("inputString")
264-
val endOfMask = ctx.freshName("endOfMask")
265-
s"""
264+
nullSafeCodeGen(ctx, ev, (input: ExprValue) => {
265+
val sb = JavaCode.variable(ctx.freshName("sb"), classOf[StringBuilder])
266+
val length = JavaCode.variable(ctx.freshName("length"), IntegerType)
267+
val offset = JavaCode.variable(ctx.freshName("offset"), IntegerType)
268+
val inputString = JavaCode.variable(ctx.freshName("inputString"), classOf[String])
269+
val endOfMask = JavaCode.variable(ctx.freshName("endOfMask"), IntegerType)
270+
val intType = inline"${CodeGenerator.JAVA_INT}"
271+
code"""
266272
|String $inputString = $input.toString();
267273
|${inputStringLengthCode(inputString, length)}
268-
|${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $length ? $length : $charCount;
269-
|${CodeGenerator.JAVA_INT} $offset = 0;
274+
|$intType $endOfMask = $charCount > $length ? $length : $charCount;
275+
|$intType $offset = 0;
270276
|StringBuilder $sb = new StringBuilder($length);
271277
|${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)}
272278
|${appendUnchangedToStringBuilderCode(
273-
ctx, sb, inputString, offset, s"$length - $endOfMask")}
279+
ctx, sb, inputString, offset, code"$length - $endOfMask")}
274280
|${ev.value} = UTF8String.fromString($sb.toString());
275281
|""".stripMargin
276282
})
@@ -339,22 +345,22 @@ case class MaskLastN(
339345
}
340346

341347
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
342-
nullSafeCodeGen(ctx, ev, (input: String) => {
343-
val sb = ctx.freshName("sb")
344-
val length = ctx.freshName("length")
345-
val offset = ctx.freshName("offset")
346-
val inputString = ctx.freshName("inputString")
347-
val startOfMask = ctx.freshName("startOfMask")
348-
s"""
348+
nullSafeCodeGen(ctx, ev, (input: ExprValue) => {
349+
val sb = JavaCode.variable(ctx.freshName("sb"), classOf[StringBuilder])
350+
val length = JavaCode.variable(ctx.freshName("length"), IntegerType)
351+
val offset = JavaCode.variable(ctx.freshName("offset"), IntegerType)
352+
val inputString = JavaCode.variable(ctx.freshName("inputString"), classOf[String])
353+
val startOfMask = JavaCode.variable(ctx.freshName("startOfMask"), IntegerType)
354+
val intType = inline"${CodeGenerator.JAVA_INT}"
355+
code"""
349356
|String $inputString = $input.toString();
350357
|${inputStringLengthCode(inputString, length)}
351-
|${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $length ?
352-
| 0 : $length - $charCount;
353-
|${CodeGenerator.JAVA_INT} $offset = 0;
358+
|$intType $startOfMask = $charCount >= $length ? 0 : $length - $charCount;
359+
|$intType $offset = 0;
354360
|StringBuilder $sb = new StringBuilder($length);
355361
|${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)}
356362
|${appendMaskedToStringBuilderCode(
357-
ctx, sb, inputString, offset, s"$length - $startOfMask")}
363+
ctx, sb, inputString, offset, code"$length - $startOfMask")}
358364
|${ev.value} = UTF8String.fromString($sb.toString());
359365
|""".stripMargin
360366
})
@@ -423,21 +429,22 @@ case class MaskShowFirstN(
423429
}
424430

425431
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
426-
nullSafeCodeGen(ctx, ev, (input: String) => {
427-
val sb = ctx.freshName("sb")
428-
val length = ctx.freshName("length")
429-
val offset = ctx.freshName("offset")
430-
val inputString = ctx.freshName("inputString")
431-
val startOfMask = ctx.freshName("startOfMask")
432-
s"""
432+
nullSafeCodeGen(ctx, ev, (input: ExprValue) => {
433+
val sb = JavaCode.variable(ctx.freshName("sb"), classOf[StringBuilder])
434+
val length = JavaCode.variable(ctx.freshName("length"), IntegerType)
435+
val offset = JavaCode.variable(ctx.freshName("offset"), IntegerType)
436+
val inputString = JavaCode.variable(ctx.freshName("inputString"), classOf[String])
437+
val startOfMask = JavaCode.variable(ctx.freshName("startOfMask"), IntegerType)
438+
val intType = inline"${CodeGenerator.JAVA_INT}"
439+
code"""
433440
|String $inputString = $input.toString();
434441
|${inputStringLengthCode(inputString, length)}
435-
|${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $length ? $length : $charCount;
436-
|${CodeGenerator.JAVA_INT} $offset = 0;
442+
|$intType $startOfMask = $charCount > $length ? $length : $charCount;
443+
|$intType $offset = 0;
437444
|StringBuilder $sb = new StringBuilder($length);
438445
|${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)}
439446
|${appendMaskedToStringBuilderCode(
440-
ctx, sb, inputString, offset, s"$length - $startOfMask")}
447+
ctx, sb, inputString, offset, code"$length - $startOfMask")}
441448
|${ev.value} = UTF8String.fromString($sb.toString());
442449
|""".stripMargin
443450
})
@@ -506,21 +513,22 @@ case class MaskShowLastN(
506513
}
507514

508515
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
509-
nullSafeCodeGen(ctx, ev, (input: String) => {
510-
val sb = ctx.freshName("sb")
511-
val length = ctx.freshName("length")
512-
val offset = ctx.freshName("offset")
513-
val inputString = ctx.freshName("inputString")
514-
val endOfMask = ctx.freshName("endOfMask")
515-
s"""
516+
nullSafeCodeGen(ctx, ev, (input: ExprValue) => {
517+
val sb = JavaCode.variable(ctx.freshName("sb"), classOf[StringBuilder])
518+
val length = JavaCode.variable(ctx.freshName("length"), IntegerType)
519+
val offset = JavaCode.variable(ctx.freshName("offset"), IntegerType)
520+
val inputString = JavaCode.variable(ctx.freshName("inputString"), classOf[String])
521+
val endOfMask = JavaCode.variable(ctx.freshName("endOfMask"), IntegerType)
522+
val intType = inline"${CodeGenerator.JAVA_INT}"
523+
code"""
516524
|String $inputString = $input.toString();
517525
|${inputStringLengthCode(inputString, length)}
518-
|${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $length ? 0 : $length - $charCount;
519-
|${CodeGenerator.JAVA_INT} $offset = 0;
526+
|$intType $endOfMask = $charCount >= $length ? 0 : $length - $charCount;
527+
|$intType $offset = 0;
520528
|StringBuilder $sb = new StringBuilder($length);
521529
|${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)}
522530
|${appendUnchangedToStringBuilderCode(
523-
ctx, sb, inputString, offset, s"$length - $endOfMask")}
531+
ctx, sb, inputString, offset, code"$length - $endOfMask")}
524532
|${ev.value} = UTF8String.fromString($sb.toString());
525533
|""".stripMargin
526534
})
@@ -553,9 +561,9 @@ case class MaskHash(child: Expression)
553561
}
554562

555563
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
556-
nullSafeCodeGen(ctx, ev, (input: String) => {
557-
val digestUtilsClass = classOf[DigestUtils].getName.stripSuffix("$")
558-
s"""
564+
nullSafeCodeGen(ctx, ev, (input: ExprValue) => {
565+
val digestUtilsClass = inline"${classOf[DigestUtils].getName.stripSuffix("$")}"
566+
code"""
559567
|${ev.value} = UTF8String.fromString($digestUtilsClass.md5Hex($input.toString()));
560568
|""".stripMargin
561569
})

0 commit comments

Comments
 (0)