Skip to content

Commit 5b965d6

Browse files
Davies Liu (davies)
authored and committed
[SPARK-9644] [SQL] Support update DecimalType with precision > 18 in UnsafeRow
In order to support updating a variable-length (actually fixed-length) object, the space should be preserved even when it is null. Also, we can't call setNullAt(i) for it anymore, because setNullAt(i) would remove the offset of the preserved space; we should call setDecimal(i, null, precision) instead. After this, we can do hash-based aggregation on DecimalType with precision > 18. In a test, this decreased the end-to-end run time of an aggregation query from 37 seconds (sort-based) to 24 seconds (hash-based). cc rxin Author: Davies Liu <[email protected]> Closes apache#7978 from davies/update_decimal and squashes the following commits: bed8100 [Davies Liu] isSettable -> isMutable 923c9eb [Davies Liu] address comments and fix bug 385891d [Davies Liu] Merge branch 'master' of github.com:apache/spark into update_decimal 36a1872 [Davies Liu] fix tests cd6c524 [Davies Liu] support set decimal with precision > 18
1 parent aead18f commit 5b965d6

File tree

10 files changed

+183
-61
lines changed

10 files changed

+183
-61
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,11 @@ public static int calculateBitSetWidthInBytes(int numFields) {
6565
/**
6666
* Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types)
6767
*/
68-
public static final Set<DataType> settableFieldTypes;
68+
public static final Set<DataType> mutableFieldTypes;
6969

70-
// DecimalType(precision <= 18) is settable
70+
// DecimalType is also mutable
7171
static {
72-
settableFieldTypes = Collections.unmodifiableSet(
72+
mutableFieldTypes = Collections.unmodifiableSet(
7373
new HashSet<>(
7474
Arrays.asList(new DataType[] {
7575
NullType,
@@ -87,12 +87,16 @@ public static int calculateBitSetWidthInBytes(int numFields) {
8787

8888
public static boolean isFixedLength(DataType dt) {
8989
if (dt instanceof DecimalType) {
90-
return ((DecimalType) dt).precision() < Decimal.MAX_LONG_DIGITS();
90+
return ((DecimalType) dt).precision() <= Decimal.MAX_LONG_DIGITS();
9191
} else {
92-
return settableFieldTypes.contains(dt);
92+
return mutableFieldTypes.contains(dt);
9393
}
9494
}
9595

96+
public static boolean isMutable(DataType dt) {
97+
return mutableFieldTypes.contains(dt) || dt instanceof DecimalType;
98+
}
99+
96100
//////////////////////////////////////////////////////////////////////////////
97101
// Private fields and methods
98102
//////////////////////////////////////////////////////////////////////////////
@@ -238,17 +242,45 @@ public void setFloat(int ordinal, float value) {
238242
PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
239243
}
240244

245+
/**
246+
* Updates the decimal column.
247+
*
248+
* Note: In order to support update a decimal with precision > 18, CAN NOT call
249+
* setNullAt() for this column.
250+
*/
241251
@Override
242252
public void setDecimal(int ordinal, Decimal value, int precision) {
243253
assertIndexIsValid(ordinal);
244-
if (value == null) {
245-
setNullAt(ordinal);
246-
} else {
247-
if (precision <= Decimal.MAX_LONG_DIGITS()) {
254+
if (precision <= Decimal.MAX_LONG_DIGITS()) {
255+
// compact format
256+
if (value == null) {
257+
setNullAt(ordinal);
258+
} else {
248259
setLong(ordinal, value.toUnscaledLong());
260+
}
261+
} else {
262+
// fixed length
263+
long cursor = getLong(ordinal) >>> 32;
264+
assert cursor > 0 : "invalid cursor " + cursor;
265+
// zero-out the bytes
266+
PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor, 0L);
267+
PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor + 8, 0L);
268+
269+
if (value == null) {
270+
setNullAt(ordinal);
271+
// keep the offset for future update
272+
PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), cursor << 32);
249273
} else {
250-
// TODO(davies): support update decimal (hold a bounded space even it's null)
251-
throw new UnsupportedOperationException();
274+
275+
final BigInteger integer = value.toJavaBigDecimal().unscaledValue();
276+
final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer,
277+
PlatformDependent.BIG_INTEGER_MAG_OFFSET);
278+
assert(mag.length <= 4);
279+
280+
// Write the bytes to the variable length portion.
281+
PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET,
282+
baseObject, baseOffset + cursor, mag.length * 4);
283+
setLong(ordinal, (cursor << 32) | ((long) (((integer.signum() + 1) << 8) + mag.length)));
252284
}
253285
}
254286
}
@@ -343,6 +375,8 @@ public double getDouble(int ordinal) {
343375
return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal));
344376
}
345377

378+
private static byte[] EMPTY = new byte[0];
379+
346380
@Override
347381
public Decimal getDecimal(int ordinal, int precision, int scale) {
348382
if (isNullAt(ordinal)) {
@@ -351,10 +385,20 @@ public Decimal getDecimal(int ordinal, int precision, int scale) {
351385
if (precision <= Decimal.MAX_LONG_DIGITS()) {
352386
return Decimal.apply(getLong(ordinal), precision, scale);
353387
} else {
354-
byte[] bytes = getBinary(ordinal);
355-
BigInteger bigInteger = new BigInteger(bytes);
356-
BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
357-
return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale);
388+
long offsetAndSize = getLong(ordinal);
389+
long offset = offsetAndSize >>> 32;
390+
int signum = ((int) (offsetAndSize & 0xfff) >> 8);
391+
assert signum >=0 && signum <= 2 : "invalid signum " + signum;
392+
int size = (int) (offsetAndSize & 0xff);
393+
int[] mag = new int[size];
394+
PlatformDependent.copyMemory(baseObject, baseOffset + offset,
395+
mag, PlatformDependent.INT_ARRAY_OFFSET, size * 4);
396+
397+
// create a BigInteger using signum and mag
398+
BigInteger v = new BigInteger(0, EMPTY); // create the initial object
399+
PlatformDependent.UNSAFE.putInt(v, PlatformDependent.BIG_INTEGER_SIGNUM_OFFSET, signum - 1);
400+
PlatformDependent.UNSAFE.putObjectVolatile(v, PlatformDependent.BIG_INTEGER_MAG_OFFSET, mag);
401+
return Decimal.apply(new BigDecimal(v, scale), precision, scale);
358402
}
359403
}
360404

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions;
1919

20+
import java.math.BigInteger;
21+
2022
import org.apache.spark.sql.catalyst.InternalRow;
2123
import org.apache.spark.sql.types.Decimal;
22-
import org.apache.spark.sql.types.MapData;
2324
import org.apache.spark.unsafe.PlatformDependent;
2425
import org.apache.spark.unsafe.array.ByteArrayMethods;
2526
import org.apache.spark.unsafe.types.ByteArray;
@@ -47,29 +48,41 @@ public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input
4748

4849
/** Writer for Decimal with precision larger than 18. */
4950
public static class DecimalWriter {
50-
51+
private static final int SIZE = 16;
5152
public static int getSize(Decimal input) {
5253
// bounded size
53-
return 16;
54+
return SIZE;
5455
}
5556

5657
public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) {
58+
final Object base = target.getBaseObject();
5759
final long offset = target.getBaseOffset() + cursor;
58-
final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
59-
final int numBytes = bytes.length;
60-
assert(numBytes <= 16);
61-
6260
// zero-out the bytes
63-
PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L);
64-
PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L);
61+
PlatformDependent.UNSAFE.putLong(base, offset, 0L);
62+
PlatformDependent.UNSAFE.putLong(base, offset + 8, 0L);
63+
64+
if (input == null) {
65+
target.setNullAt(ordinal);
66+
// keep the offset and length for update
67+
int fieldOffset = UnsafeRow.calculateBitSetWidthInBytes(target.numFields()) + ordinal * 8;
68+
PlatformDependent.UNSAFE.putLong(base, target.getBaseOffset() + fieldOffset,
69+
((long) cursor) << 32);
70+
return SIZE;
71+
}
6572

66-
// Write the bytes to the variable length portion.
67-
PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET,
68-
target.getBaseObject(), offset, numBytes);
73+
final BigInteger integer = input.toJavaBigDecimal().unscaledValue();
74+
int signum = integer.signum() + 1;
75+
final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer,
76+
PlatformDependent.BIG_INTEGER_MAG_OFFSET);
77+
assert(mag.length <= 4);
6978

79+
// Write the bytes to the variable length portion.
80+
PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET,
81+
base, target.getBaseOffset() + cursor, mag.length * 4);
7082
// Set the fixed length portion.
71-
target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
72-
return 16;
83+
target.setLong(ordinal, (((long) cursor) << 32) | ((long) ((signum << 8) + mag.length)));
84+
85+
return SIZE;
7386
}
7487
}
7588

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer
2121

2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
24+
import org.apache.spark.sql.types.DecimalType
2425

2526
// MutableProjection is not accessible in Java
2627
abstract class BaseMutableProjection extends MutableProjection
@@ -43,14 +44,26 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
4344
case (NoOp, _) => ""
4445
case (e, i) =>
4546
val evaluationCode = e.gen(ctx)
46-
evaluationCode.code +
47+
if (e.dataType.isInstanceOf[DecimalType]) {
48+
// Can't call setNullAt on DecimalType, because we need to keep the offset
4749
s"""
50+
${evaluationCode.code}
51+
if (${evaluationCode.isNull}) {
52+
${ctx.setColumn("mutableRow", e.dataType, i, null)};
53+
} else {
54+
${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
55+
}
56+
"""
57+
} else {
58+
s"""
59+
${evaluationCode.code}
4860
if (${evaluationCode.isNull}) {
4961
mutableRow.setNullAt($i);
5062
} else {
5163
${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
5264
}
5365
"""
66+
}
5467
}
5568
// collect projections into blocks as function has 64kb codesize limit in JVM
5669
val projectionBlocks = new ArrayBuffer[String]()

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
4545

4646
/** Returns true iff we support this data type. */
4747
def canSupport(dataType: DataType): Boolean = dataType match {
48+
case NullType => true
4849
case t: AtomicType => true
4950
case _: CalendarIntervalType => true
5051
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
51-
case NullType => true
5252
case t: ArrayType if canSupport(t.elementType) => true
5353
case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
5454
case _ => false
5555
}
5656

5757
def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match {
5858
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
59-
s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))"
59+
s" + $DecimalWriter.getSize(${ev.primitive})"
6060
case StringType =>
6161
s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
6262
case BinaryType =>
@@ -76,41 +76,41 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
7676
ctx: CodeGenContext,
7777
fieldType: DataType,
7878
ev: GeneratedExpressionCode,
79-
primitive: String,
79+
target: String,
8080
index: Int,
8181
cursor: String): String = fieldType match {
8282
case _ if ctx.isPrimitiveType(fieldType) =>
83-
s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}"
83+
s"${ctx.setColumn(target, fieldType, index, ev.primitive)}"
8484
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
8585
s"""
8686
// make sure Decimal object has the same scale as DecimalType
8787
if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) {
88-
$CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive});
88+
$CompactDecimalWriter.write($target, $index, $cursor, ${ev.primitive});
8989
} else {
90-
$primitive.setNullAt($index);
90+
$target.setNullAt($index);
9191
}
9292
"""
9393
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
9494
s"""
9595
// make sure Decimal object has the same scale as DecimalType
9696
if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) {
97-
$cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive});
97+
$cursor += $DecimalWriter.write($target, $index, $cursor, ${ev.primitive});
9898
} else {
99-
$primitive.setNullAt($index);
99+
$cursor += $DecimalWriter.write($target, $index, $cursor, null);
100100
}
101101
"""
102102
case StringType =>
103-
s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})"
103+
s"$cursor += $StringWriter.write($target, $index, $cursor, ${ev.primitive})"
104104
case BinaryType =>
105-
s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})"
105+
s"$cursor += $BinaryWriter.write($target, $index, $cursor, ${ev.primitive})"
106106
case CalendarIntervalType =>
107-
s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})"
107+
s"$cursor += $IntervalWriter.write($target, $index, $cursor, ${ev.primitive})"
108108
case _: StructType =>
109-
s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})"
109+
s"$cursor += $StructWriter.write($target, $index, $cursor, ${ev.primitive})"
110110
case _: ArrayType =>
111-
s"$cursor += $ArrayWriter.write($primitive, $index, $cursor, ${ev.primitive})"
111+
s"$cursor += $ArrayWriter.write($target, $index, $cursor, ${ev.primitive})"
112112
case _: MapType =>
113-
s"$cursor += $MapWriter.write($primitive, $index, $cursor, ${ev.primitive})"
113+
s"$cursor += $MapWriter.write($target, $index, $cursor, ${ev.primitive})"
114114
case NullType => ""
115115
case _ =>
116116
throw new UnsupportedOperationException(s"Not supported DataType: $fieldType")
@@ -146,13 +146,24 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
146146

147147
val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) =>
148148
val update = genFieldWriter(ctx, dt, ev, output, i, cursor)
149-
s"""
150-
if (${ev.isNull}) {
151-
$output.setNullAt($i);
152-
} else {
153-
$update;
154-
}
155-
"""
149+
if (dt.isInstanceOf[DecimalType]) {
150+
// Can't call setNullAt() for DecimalType
151+
s"""
152+
if (${ev.isNull}) {
153+
$cursor += $DecimalWriter.write($output, $i, $cursor, null);
154+
} else {
155+
$update;
156+
}
157+
"""
158+
} else {
159+
s"""
160+
if (${ev.isNull}) {
161+
$output.setNullAt($i);
162+
} else {
163+
$update;
164+
}
165+
"""
166+
}
156167
}.mkString("\n")
157168

158169
val code = s"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.sql.Row
2121
import org.apache.spark.sql.catalyst.InternalRow
2222
import org.apache.spark.sql.types._
23-
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
2423

2524
/**
2625
* An extended interface to [[InternalRow]] that allows the values for each column to be updated.
@@ -39,6 +38,13 @@ abstract class MutableRow extends InternalRow {
3938
def setLong(i: Int, value: Long): Unit = { update(i, value) }
4039
def setFloat(i: Int, value: Float): Unit = { update(i, value) }
4140
def setDouble(i: Int, value: Double): Unit = { update(i, value) }
41+
42+
/**
43+
* Update the decimal column at `i`.
44+
*
45+
* Note: In order to support update decimal with precision > 18 in UnsafeRow,
46+
* CAN NOT call setNullAt() for decimal column on UnsafeRow, call setDecimal(i, null, precision).
47+
*/
4248
def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) }
4349
}
4450

0 commit comments

Comments
 (0)