Skip to content

Commit 49adf26

Browse files
committed
add unsafe map
1 parent 20d1039 commit 49adf26

File tree

10 files changed

+265
-24
lines changed

10 files changed

+265
-24
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -277,17 +277,17 @@ public ArrayData getArray(int ordinal) {
277277
assertIndexIsValid(ordinal);
278278
final int offset = getElementOffset(ordinal);
279279
if (offset < 0) return null;
280-
final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
281280
final int size = getElementSize(offset, ordinal);
282-
final UnsafeArrayData array = new UnsafeArrayData();
283-
// Skip the first 4 bytes.
284-
array.pointTo(baseObject, baseOffset + offset + 4, numElements, size - 4);
285-
return array;
281+
return UnsafeReaders.readArray(baseObject, baseOffset + offset, size);
286282
}
287283

288284
@Override
289285
public MapData getMap(int ordinal) {
290-
return null;
286+
assertIndexIsValid(ordinal);
287+
final int offset = getElementOffset(ordinal);
288+
if (offset < 0) return null;
289+
final int size = getElementSize(offset, ordinal);
290+
return UnsafeReaders.readMap(baseObject, baseOffset + offset, size);
291291
}
292292

293293
@Override
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions;
19+
20+
import org.apache.spark.sql.types.ArrayData;
21+
import org.apache.spark.sql.types.MapData;
22+
23+
/**
24+
* An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
25+
*
26+
* Currently we just use 2 UnsafeArrayData to represent UnsafeMapData.
27+
*/
28+
public class UnsafeMapData extends MapData {
29+
30+
public final UnsafeArrayData keys;
31+
public final UnsafeArrayData values;
32+
// The number of elements (key/value pairs) in this map
33+
private int numElements;
34+
35+
public UnsafeMapData(UnsafeArrayData keys, UnsafeArrayData values) {
36+
assert keys.numElements() == values.numElements();
37+
this.numElements = keys.numElements();
38+
this.keys = keys;
39+
this.values = values;
40+
}
41+
42+
@Override
43+
public int numElements() {
44+
return numElements;
45+
}
46+
47+
@Override
48+
public ArrayData keyArray() {
49+
return keys;
50+
}
51+
52+
@Override
53+
public ArrayData valueArray() {
54+
return values;
55+
}
56+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions;
19+
20+
import org.apache.spark.unsafe.PlatformDependent;
21+
22+
public class UnsafeReaders {
23+
24+
public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) {
25+
// Read the number of elements from first 4 bytes.
26+
final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset);
27+
final UnsafeArrayData array = new UnsafeArrayData();
28+
// Skip the first 4 bytes.
29+
array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4);
30+
return array;
31+
}
32+
33+
public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) {
34+
// Read the number of elements from first 4 bytes.
35+
final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset);
36+
// Read the numBytes of the key array from the second 4 bytes.
37+
final int keyArraySize = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + 4);
38+
final int valueArraySize = numBytes - 8 - keyArraySize;
39+
40+
final UnsafeArrayData keyArray = new UnsafeArrayData();
41+
keyArray.pointTo(baseObject, baseOffset + 8, numElements, keyArraySize);
42+
43+
final UnsafeArrayData valueArray = new UnsafeArrayData();
44+
valueArray.pointTo(baseObject, baseOffset + 8 + keyArraySize, numElements, valueArraySize);
45+
46+
return new UnsafeMapData(keyArray, valueArray);
47+
}
48+
}

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,6 @@ public Decimal getDecimal(int ordinal, int precision, int scale) {
365365

366366
@Override
367367
public UTF8String getUTF8String(int ordinal) {
368-
assertIndexIsValid(ordinal);
369368
if (isNullAt(ordinal)) return null;
370369
final long offsetAndSize = getLong(ordinal);
371370
final int offset = (int) (offsetAndSize >> 32);
@@ -375,7 +374,6 @@ public UTF8String getUTF8String(int ordinal) {
375374

376375
@Override
377376
public byte[] getBinary(int ordinal) {
378-
assertIndexIsValid(ordinal);
379377
if (isNullAt(ordinal)) {
380378
return null;
381379
} else {
@@ -430,19 +428,25 @@ public ArrayData getArray(int ordinal) {
430428
final long offsetAndSize = getLong(ordinal);
431429
final int offset = (int) (offsetAndSize >> 32);
432430
final int size = (int) (offsetAndSize & ((1L << 32) - 1));
433-
final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
434-
final UnsafeArrayData array = new UnsafeArrayData();
435-
// Skip the first 4 bytes.
436-
array.pointTo(baseObject, baseOffset + offset + 4, numElements, size - 4);
437-
return array;
431+
return UnsafeReaders.readArray(baseObject, baseOffset + offset, size);
432+
}
433+
}
434+
435+
@Override
436+
public MapData getMap(int ordinal) {
437+
if (isNullAt(ordinal)) {
438+
return null;
439+
} else {
440+
final long offsetAndSize = getLong(ordinal);
441+
final int offset = (int) (offsetAndSize >> 32);
442+
final int size = (int) (offsetAndSize & ((1L << 32) - 1));
443+
return UnsafeReaders.readMap(baseObject, baseOffset + offset, size);
438444
}
439445
}
440446

441447
/**
442448
* Copies this row, returning a self-contained UnsafeRow that stores its data in an internal
443449
* byte array rather than referencing data stored in a data page.
444-
* <p>
445-
* This method is only supported on UnsafeRows that do not use ObjectPools.
446450
*/
447451
@Override
448452
public UnsafeRow copy() {

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import org.apache.spark.sql.catalyst.InternalRow;
2121
import org.apache.spark.sql.types.Decimal;
22+
import org.apache.spark.sql.types.MapData;
2223
import org.apache.spark.unsafe.PlatformDependent;
2324
import org.apache.spark.unsafe.array.ByteArrayMethods;
2425
import org.apache.spark.unsafe.types.ByteArray;
@@ -215,4 +216,44 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeArrayDa
215216
return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
216217
}
217218
}
219+
220+
public static class MapWriter {
221+
222+
public static int getSize(UnsafeMapData input) {
223+
// we need extra 8 bytes to store number of elements and numBytes of key array.
224+
final int sizeInBytes = 4 + 4 + input.keys.getSizeInBytes() + input.values.getSizeInBytes();
225+
return ByteArrayMethods.roundNumberOfBytesToNearestWord(sizeInBytes);
226+
}
227+
228+
public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData input) {
229+
final long offset = target.getBaseOffset() + cursor;
230+
final UnsafeArrayData keyArray = input.keys;
231+
final UnsafeArrayData valueArray = input.values;
232+
final int keysNumBytes = keyArray.getSizeInBytes();
233+
final int valuesNumBytes = valueArray.getSizeInBytes();
234+
final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;
235+
236+
// write the number of elements into first 4 bytes.
237+
PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements());
238+
// write the numBytes of key array into second 4 bytes.
239+
PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset + 4, keysNumBytes);
240+
241+
// zero-out the padding bytes
242+
if ((numBytes & 0x07) > 0) {
243+
PlatformDependent.UNSAFE.putLong(
244+
target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
245+
}
246+
247+
// Write the bytes of key array to the variable length portion.
248+
keyArray.writeToMemory(target.getBaseObject(), offset + 8);
249+
250+
// Write the bytes of value array to the variable length portion.
251+
valueArray.writeToMemory(target.getBaseObject(), offset + 8 + keysNumBytes);
252+
253+
// Set the fixed length portion.
254+
target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
255+
256+
return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
257+
}
258+
}
218259
}

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ public static void writeToMemory(
3535
int numBytes) {
3636

3737
// zero-out the padding bytes
38-
if ((numBytes & 0x07) > 0) {
39-
PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L);
40-
}
38+
// if ((numBytes & 0x07) > 0) {
39+
// PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L);
40+
// }
4141

4242
// Write the UnsafeData to the target memory.
4343
PlatformDependent.copyMemory(
@@ -171,7 +171,39 @@ public static int write(Object targetObject, long targetOffset, UnsafeArrayData
171171
writeToMemory(input.getBaseObject(), input.getBaseOffset(),
172172
targetObject, targetOffset + 4, numBytes);
173173

174-
return getRoundedSize(numBytes) + 4;
174+
return getRoundedSize(numBytes + 4);
175+
}
176+
}
177+
178+
public static class MapWriter {
179+
180+
public static int getSize(UnsafeMapData input) {
181+
// we need extra 8 bytes to store number of elements and numBytes of key array.
182+
final int sizeInBytes = 4 + 4 + input.keys.getSizeInBytes() + input.values.getSizeInBytes();
183+
return getRoundedSize(sizeInBytes);
184+
}
185+
186+
public static int write(Object targetObject, long targetOffset, UnsafeMapData input) {
187+
final UnsafeArrayData keyArray = input.keys;
188+
final UnsafeArrayData valueArray = input.values;
189+
final int keysNumBytes = keyArray.getSizeInBytes();
190+
final int valuesNumBytes = valueArray.getSizeInBytes();
191+
final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;
192+
193+
// write the number of elements into first 4 bytes.
194+
PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements());
195+
// write the numBytes of key array into second 4 bytes.
196+
PlatformDependent.UNSAFE.putInt(targetObject, targetOffset + 4, keysNumBytes);
197+
198+
// Write the bytes of key array to the variable length portion.
199+
writeToMemory(keyArray.getBaseObject(), keyArray.getBaseOffset(),
200+
targetObject, targetOffset + 8, keysNumBytes);
201+
202+
// Write the bytes of value array to the variable length portion.
203+
writeToMemory(valueArray.getBaseObject(), valueArray.getBaseOffset(),
204+
targetObject, targetOffset + 8 + keysNumBytes, valuesNumBytes);
205+
206+
return getRoundedSize(numBytes);
175207
}
176208
}
177209
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ import org.apache.spark.sql.types._
2323
case class FromUnsafe(child: Expression) extends UnaryExpression
2424
with ExpectsInputTypes with CodegenFallback {
2525

26-
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, StructType))
26+
override def inputTypes: Seq[AbstractDataType] =
27+
Seq(TypeCollection(ArrayType, StructType, MapType))
2728

2829
override def dataType: DataType = child.dataType
2930

@@ -38,19 +39,25 @@ case class FromUnsafe(child: Expression) extends UnaryExpression
3839
}
3940
new GenericInternalRow(result)
4041

41-
case ArrayType(elemnentType, _) =>
42+
case ArrayType(elementType, _) =>
4243
val array = value.asInstanceOf[UnsafeArrayData]
4344
val length = array.numElements()
4445
val result = new Array[Any](length)
4546
var i = 0
4647
while (i < length) {
4748
if (!array.isNullAt(i)) {
48-
result(i) = convert(array.get(i, elemnentType), elemnentType)
49+
result(i) = convert(array.get(i, elementType), elementType)
4950
}
5051
i += 1
5152
}
5253
new GenericArrayData(result)
5354

55+
case MapType(kt, vt, _) =>
56+
val map = value.asInstanceOf[UnsafeMapData]
57+
val safeKeyArray = convert(map.keys, ArrayType(kt)).asInstanceOf[GenericArrayData]
58+
val safeValueArray = convert(map.values, ArrayType(vt)).asInstanceOf[GenericArrayData]
59+
new ArrayBasedMapData(safeKeyArray, safeValueArray)
60+
5461
case _ => value
5562
}
5663

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ object FromUnsafeProjection {
156156
// todo: this is quite slow, maybe remove this whole projection after removing the generic getter of
157157
// InternalRow?
158158
b.dataType match {
159-
case _: StructType | _: ArrayType => FromUnsafe(b)
159+
case _: StructType | _: ArrayType | _: MapType => FromUnsafe(b)
160160
case _ => b
161161
}
162162
}))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
306306
classOf[CalendarInterval].getName,
307307
classOf[ArrayData].getName,
308308
classOf[UnsafeArrayData].getName,
309-
classOf[MapData].getName
309+
classOf[MapData].getName,
310+
classOf[UnsafeMapData].getName
310311
))
311312
evaluator.setExtendedClass(classOf[GeneratedClass])
312313
try {

0 commit comments

Comments
 (0)