Skip to content

Commit c5f745e

Browse files
cloud-fandavies
authored andcommitted
[SPARK-13072] [SQL] simplify and improve murmur3 hash expression codegen
simplify(remove several unnecessary local variables) the generated code of hash expression, and avoid null check if possible. generated code comparison for `hash(int, double, string, array<string>)`: **before:** ``` public UnsafeRow apply(InternalRow i) { /* hash(input[0, int],input[1, double],input[2, string],input[3, array<int>],42) */ int value1 = 42; /* input[0, int] */ int value3 = i.getInt(0); if (!false) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1); } /* input[1, double] */ double value5 = i.getDouble(1); if (!false) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1); } /* input[2, string] */ boolean isNull6 = i.isNullAt(2); UTF8String value7 = isNull6 ? null : (i.getUTF8String(2)); if (!isNull6) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1); } /* input[3, array<int>] */ boolean isNull8 = i.isNullAt(3); ArrayData value9 = isNull8 ? null : (i.getArray(3)); if (!isNull8) { int result10 = value1; for (int index11 = 0; index11 < value9.numElements(); index11++) { if (!value9.isNullAt(index11)) { final int element12 = value9.getInt(index11); result10 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element12, result10); } } value1 = result10; } } ``` **after:** ``` public UnsafeRow apply(InternalRow i) { /* hash(input[0, int],input[1, double],input[2, string],input[3, array<int>],42) */ int value1 = 42; /* input[0, int] */ int value3 = i.getInt(0); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1); /* input[1, double] */ double value5 = i.getDouble(1); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1); /* input[2, string] */ boolean isNull6 = i.isNullAt(2); UTF8String value7 = isNull6 ? null : (i.getUTF8String(2)); if (!isNull6) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1); } /* input[3, array<int>] */ boolean isNull8 = i.isNullAt(3); ArrayData value9 = isNull8 ? null : (i.getArray(3)); if (!isNull8) { for (int index10 = 0; index10 < value9.numElements(); index10++) { final int element11 = value9.getInt(index10); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element11, value1); } } rowWriter14.write(0, value1); return result12; } ``` Author: Wenchen Fan <[email protected]> Closes #10974 from cloud-fan/codegen.
1 parent e4c1162 commit c5f745e

File tree

1 file changed

+69
-86
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions

1 file changed

+69
-86
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 69 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -325,36 +325,62 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
325325

326326
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
327327
ev.isNull = "false"
328-
val childrenHash = children.zipWithIndex.map {
329-
case (child, dt) =>
330-
val childGen = child.gen(ctx)
331-
val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx)
332-
s"""
333-
${childGen.code}
334-
if (!${childGen.isNull}) {
335-
${childHash.code}
336-
${ev.value} = ${childHash.value};
337-
}
338-
"""
328+
val childrenHash = children.map { child =>
329+
val childGen = child.gen(ctx)
330+
childGen.code + generateNullCheck(child.nullable, childGen.isNull) {
331+
computeHash(childGen.value, child.dataType, ev.value, ctx)
332+
}
339333
}.mkString("\n")
334+
340335
s"""
341336
int ${ev.value} = $seed;
342337
$childrenHash
343338
"""
344339
}
345340

341+
private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = {
342+
if (nullable) {
343+
s"""
344+
if (!$isNull) {
345+
$execution
346+
}
347+
"""
348+
} else {
349+
"\n" + execution
350+
}
351+
}
352+
353+
private def nullSafeElementHash(
354+
input: String,
355+
index: String,
356+
nullable: Boolean,
357+
elementType: DataType,
358+
result: String,
359+
ctx: CodegenContext): String = {
360+
val element = ctx.freshName("element")
361+
362+
generateNullCheck(nullable, s"$input.isNullAt($index)") {
363+
s"""
364+
final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
365+
${computeHash(element, elementType, result, ctx)}
366+
"""
367+
}
368+
}
369+
346370
private def computeHash(
347371
input: String,
348372
dataType: DataType,
349-
seed: String,
350-
ctx: CodegenContext): ExprCode = {
373+
result: String,
374+
ctx: CodegenContext): String = {
351375
val hasher = classOf[Murmur3_x86_32].getName
352-
def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)")
353-
def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)")
354-
def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v)
376+
377+
def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);"
378+
def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);"
379+
def hashBytes(b: String): String =
380+
s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);"
355381

356382
dataType match {
357-
case NullType => inlineValue(seed)
383+
case NullType => ""
358384
case BooleanType => hashInt(s"$input ? 1 : 0")
359385
case ByteType | ShortType | IntegerType | DateType => hashInt(input)
360386
case LongType | TimestampType => hashLong(input)
@@ -365,91 +391,48 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
365391
hashLong(s"$input.toUnscaledLong()")
366392
} else {
367393
val bytes = ctx.freshName("bytes")
368-
val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();"
369-
val offset = "Platform.BYTE_ARRAY_OFFSET"
370-
val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)"
371-
ExprCode(code, "false", result)
394+
s"""
395+
final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
396+
${hashBytes(bytes)}
397+
"""
372398
}
373399
case CalendarIntervalType =>
374-
val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)"
375-
val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)"
376-
inlineValue(monthsHash)
377-
case BinaryType =>
378-
val offset = "Platform.BYTE_ARRAY_OFFSET"
379-
inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)")
400+
val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)"
401+
s"$result = $hasher.hashInt($input.months, $microsecondsHash);"
402+
case BinaryType => hashBytes(input)
380403
case StringType =>
381404
val baseObject = s"$input.getBaseObject()"
382405
val baseOffset = s"$input.getBaseOffset()"
383406
val numBytes = s"$input.numBytes()"
384-
inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)")
407+
s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);"
385408

386-
case ArrayType(et, _) =>
387-
val result = ctx.freshName("result")
409+
case ArrayType(et, containsNull) =>
388410
val index = ctx.freshName("index")
389-
val element = ctx.freshName("element")
390-
val elementHash = computeHash(element, et, result, ctx)
391-
val code =
392-
s"""
393-
int $result = $seed;
394-
for (int $index = 0; $index < $input.numElements(); $index++) {
395-
if (!$input.isNullAt($index)) {
396-
final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)};
397-
${elementHash.code}
398-
$result = ${elementHash.value};
399-
}
400-
}
401-
"""
402-
ExprCode(code, "false", result)
411+
s"""
412+
for (int $index = 0; $index < $input.numElements(); $index++) {
413+
${nullSafeElementHash(input, index, containsNull, et, result, ctx)}
414+
}
415+
"""
403416

404-
case MapType(kt, vt, _) =>
405-
val result = ctx.freshName("result")
417+
case MapType(kt, vt, valueContainsNull) =>
406418
val index = ctx.freshName("index")
407419
val keys = ctx.freshName("keys")
408420
val values = ctx.freshName("values")
409-
val key = ctx.freshName("key")
410-
val value = ctx.freshName("value")
411-
val keyHash = computeHash(key, kt, result, ctx)
412-
val valueHash = computeHash(value, vt, result, ctx)
413-
val code =
414-
s"""
415-
int $result = $seed;
416-
final ArrayData $keys = $input.keyArray();
417-
final ArrayData $values = $input.valueArray();
418-
for (int $index = 0; $index < $input.numElements(); $index++) {
419-
final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)};
420-
${keyHash.code}
421-
$result = ${keyHash.value};
422-
if (!$values.isNullAt($index)) {
423-
final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)};
424-
${valueHash.code}
425-
$result = ${valueHash.value};
426-
}
427-
}
428-
"""
429-
ExprCode(code, "false", result)
421+
s"""
422+
final ArrayData $keys = $input.keyArray();
423+
final ArrayData $values = $input.valueArray();
424+
for (int $index = 0; $index < $input.numElements(); $index++) {
425+
${nullSafeElementHash(keys, index, false, kt, result, ctx)}
426+
${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)}
427+
}
428+
"""
430429

431430
case StructType(fields) =>
432-
val result = ctx.freshName("result")
433-
val fieldsHash = fields.map(_.dataType).zipWithIndex.map {
434-
case (dt, index) =>
435-
val field = ctx.freshName("field")
436-
val fieldHash = computeHash(field, dt, result, ctx)
437-
s"""
438-
if (!$input.isNullAt($index)) {
439-
final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)};
440-
${fieldHash.code}
441-
$result = ${fieldHash.value};
442-
}
443-
"""
431+
fields.zipWithIndex.map { case (field, index) =>
432+
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
444433
}.mkString("\n")
445-
val code =
446-
s"""
447-
int $result = $seed;
448-
$fieldsHash
449-
"""
450-
ExprCode(code, "false", result)
451434

452-
case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx)
435+
case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx)
453436
}
454437
}
455438
}

0 commit comments

Comments
 (0)