Skip to content

Commit c4da534

Browse files
Davies Liuliancheng
authored andcommitted
[SPARK-10990] [SPARK-11018] [SQL] improve unrolling of complex types
This PR improve the unrolling and read of complex types in columnar cache: 1) Using UnsafeProjection to do serialization of complex types, so they will not be serialized three times (two for actualSize) 2) Copy the bytes from UnsafeRow/UnsafeArrayData to ByteBuffer directly, avoiding the immediate byte[] 3) Using the underlying array in ByteBuffer to create UTF8String/UnsafeRow/UnsafeArrayData without copy. Combine these optimizations, we can reduce the unrolling time from 25s to 21s (20% less), reduce the scanning time from 3.5s to 2.5s (28% less). ``` df = sqlContext.read.parquet(path) t = time.time() df.cache() df.count() print 'unrolling', time.time() - t for i in range(10): t = time.time() print df.select("*")._jdf.queryExecution().toRdd().count() print time.time() - t ``` The schema is ``` root |-- a: struct (nullable = true) | |-- b: long (nullable = true) | |-- c: string (nullable = true) |-- d: array (nullable = true) | |-- element: long (containsNull = true) |-- e: map (nullable = true) | |-- key: long | |-- value: string (valueContainsNull = true) ``` Now the columnar cache depends on that UnsafeProjection support all the data types (including UDT), this PR also fix that. Author: Davies Liu <[email protected]> Closes #9016 from davies/complex2.
1 parent f97e932 commit c4da534

File tree

12 files changed

+188
-140
lines changed

12 files changed

+188
-140
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import java.math.BigDecimal;
2121
import java.math.BigInteger;
22+
import java.nio.ByteBuffer;
2223

2324
import org.apache.spark.sql.types.*;
2425
import org.apache.spark.unsafe.Platform;
@@ -145,6 +146,8 @@ public Object get(int ordinal, DataType dataType) {
145146
return getArray(ordinal);
146147
} else if (dataType instanceof MapType) {
147148
return getMap(ordinal);
149+
} else if (dataType instanceof UserDefinedType) {
150+
return get(ordinal, ((UserDefinedType)dataType).sqlType());
148151
} else {
149152
throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
150153
}
@@ -306,6 +309,15 @@ public void writeToMemory(Object target, long targetOffset) {
306309
Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
307310
}
308311

312+
public void writeTo(ByteBuffer buffer) {
313+
assert(buffer.hasArray());
314+
byte[] target = buffer.array();
315+
int offset = buffer.arrayOffset();
316+
int pos = buffer.position();
317+
writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
318+
buffer.position(pos + sizeInBytes);
319+
}
320+
309321
@Override
310322
public UnsafeArrayData copy() {
311323
UnsafeArrayData arrayCopy = new UnsafeArrayData();

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.io.*;
2121
import java.math.BigDecimal;
2222
import java.math.BigInteger;
23+
import java.nio.ByteBuffer;
2324
import java.util.Arrays;
2425
import java.util.Collections;
2526
import java.util.HashSet;
@@ -326,6 +327,8 @@ public Object get(int ordinal, DataType dataType) {
326327
return getArray(ordinal);
327328
} else if (dataType instanceof MapType) {
328329
return getMap(ordinal);
330+
} else if (dataType instanceof UserDefinedType) {
331+
return get(ordinal, ((UserDefinedType)dataType).sqlType());
329332
} else {
330333
throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
331334
}
@@ -602,6 +605,15 @@ public void writeToMemory(Object target, long targetOffset) {
602605
Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
603606
}
604607

608+
public void writeTo(ByteBuffer buffer) {
609+
assert (buffer.hasArray());
610+
byte[] target = buffer.array();
611+
int offset = buffer.arrayOffset();
612+
int pos = buffer.position();
613+
writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
614+
buffer.position(pos + sizeInBytes);
615+
}
616+
605617
@Override
606618
public void writeExternal(ObjectOutput out) throws IOException {
607619
byte[] bytes = getBytes();

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ class CodeGenContext {
129129
case _: ArrayType => s"$input.getArray($ordinal)"
130130
case _: MapType => s"$input.getMap($ordinal)"
131131
case NullType => "null"
132+
case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal)
132133
case _ => s"($jt)$input.get($ordinal, null)"
133134
}
134135
}
@@ -143,6 +144,7 @@ class CodeGenContext {
143144
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
144145
// The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes)
145146
case StringType => s"$row.update($ordinal, $value.clone())"
147+
case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
146148
case _ => s"$row.update($ordinal, $value)"
147149
}
148150
}
@@ -177,6 +179,7 @@ class CodeGenContext {
177179
case _: MapType => "MapData"
178180
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
179181
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
182+
case udt: UserDefinedType[_] => javaType(udt.sqlType)
180183
case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
181184
case ObjectType(cls) => cls.getName
182185
case _ => "Object"
@@ -222,6 +225,7 @@ class CodeGenContext {
222225
case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
223226
case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
224227
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
228+
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
225229
case other => s"$c1.equals($c2)"
226230
}
227231

@@ -255,6 +259,7 @@ class CodeGenContext {
255259
addNewFunction(compareFunc, funcCode)
256260
s"this.$compareFunc($c1, $c2)"
257261
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
262+
case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
258263
case _ =>
259264
throw new IllegalArgumentException("cannot generate compare code for un-comparable type")
260265
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
124124
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
125125
// UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe.
126126
case StringType => GeneratedExpressionCode("", "false", s"$input.clone()")
127+
case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
127128
case _ => GeneratedExpressionCode("", "false", input)
128129
}
129130

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
3939
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
4040
case t: ArrayType if canSupport(t.elementType) => true
4141
case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
42+
case dt: OpenHashSetUDT => false // it's not a standard UDT
43+
case udt: UserDefinedType[_] => canSupport(udt.sqlType)
4244
case _ => false
4345
}
4446

@@ -77,7 +79,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
7779
ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();")
7880

7981
val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
80-
case ((input, dt), index) =>
82+
case ((input, dataType), index) =>
83+
val dt = dataType match {
84+
case udt: UserDefinedType[_] => udt.sqlType
85+
case other => other
86+
}
8187
val tmpCursor = ctx.freshName("tmpCursor")
8288

8389
val setNull = dt match {
@@ -167,15 +173,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
167173
val index = ctx.freshName("index")
168174
val element = ctx.freshName("element")
169175

170-
val jt = ctx.javaType(elementType)
176+
val et = elementType match {
177+
case udt: UserDefinedType[_] => udt.sqlType
178+
case other => other
179+
}
180+
181+
val jt = ctx.javaType(et)
171182

172-
val fixedElementSize = elementType match {
183+
val fixedElementSize = et match {
173184
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
174-
case _ if ctx.isPrimitiveType(jt) => elementType.defaultSize
185+
case _ if ctx.isPrimitiveType(jt) => et.defaultSize
175186
case _ => 0
176187
}
177188

178-
val writeElement = elementType match {
189+
val writeElement = et match {
179190
case t: StructType =>
180191
s"""
181192
$arrayWriter.setOffset($index);
@@ -194,13 +205,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
194205
${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
195206
"""
196207

197-
case _ if ctx.isPrimitiveType(elementType) =>
208+
case _ if ctx.isPrimitiveType(et) =>
198209
// Should we do word align?
199-
val dataSize = elementType.defaultSize
210+
val dataSize = et.defaultSize
200211

201212
s"""
202213
$arrayWriter.setOffset($index);
203-
${writePrimitiveType(ctx, element, elementType,
214+
${writePrimitiveType(ctx, element, et,
204215
s"$bufferHolder.buffer", s"$bufferHolder.cursor")}
205216
$bufferHolder.cursor += $dataSize;
206217
"""
@@ -237,7 +248,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
237248
if ($input.isNullAt($index)) {
238249
$arrayWriter.setNullAt($index);
239250
} else {
240-
final $jt $element = ${ctx.getValue(input, elementType, index)};
251+
final $jt $element = ${ctx.getValue(input, et, index)};
241252
$writeElement
242253
}
243254
}

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ package org.apache.spark.sql.columnar
1919

2020
import java.nio.{ByteBuffer, ByteOrder}
2121

22-
import org.apache.spark.sql.catalyst.InternalRow
23-
import org.apache.spark.sql.catalyst.expressions.MutableRow
22+
import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow}
2423
import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
2524
import org.apache.spark.sql.types._
2625

@@ -109,15 +108,15 @@ private[sql] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalTy
109108
with NullableColumnAccessor
110109

111110
private[sql] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType)
112-
extends BasicColumnAccessor[InternalRow](buffer, STRUCT(dataType))
111+
extends BasicColumnAccessor[UnsafeRow](buffer, STRUCT(dataType))
113112
with NullableColumnAccessor
114113

115114
private[sql] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType)
116-
extends BasicColumnAccessor[ArrayData](buffer, ARRAY(dataType))
115+
extends BasicColumnAccessor[UnsafeArrayData](buffer, ARRAY(dataType))
117116
with NullableColumnAccessor
118117

119118
private[sql] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType)
120-
extends BasicColumnAccessor[MapData](buffer, MAP(dataType))
119+
extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType))
121120
with NullableColumnAccessor
122121

123122
private[sql] object ColumnAccessor {

0 commit comments

Comments
 (0)