Skip to content

Commit 821b8db

Browse files
committed
add unsafe array
1 parent 16b928c commit 821b8db

File tree

10 files changed

+842
-15
lines changed

10 files changed

+842
-15
lines changed
Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions;
19+
20+
import org.apache.spark.sql.catalyst.InternalRow;
21+
import org.apache.spark.sql.types.*;
22+
import org.apache.spark.unsafe.PlatformDependent;
23+
import org.apache.spark.unsafe.array.ByteArrayMethods;
24+
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
25+
import org.apache.spark.unsafe.types.CalendarInterval;
26+
import org.apache.spark.unsafe.types.UTF8String;
27+
28+
import java.math.BigDecimal;
29+
import java.math.BigInteger;
30+
31+
// todo: doc
32+
// todo: there is a lof of duplicated code between UnsafeRow and UnsafeArrayData.
33+
public class UnsafeArrayData extends ArrayData {
34+
35+
private Object baseObject;
36+
private long baseOffset;
37+
38+
// The number of elements in this array
39+
private int numElements;
40+
41+
// The size of this array's backing data, in bytes
42+
private int sizeInBytes;
43+
44+
private int getElementOffset(int ordinal) {
45+
return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + ordinal * 4L);
46+
}
47+
48+
private int getElementSize(int offset, int ordinal) {
49+
if (ordinal == numElements - 1) {
50+
return sizeInBytes - offset;
51+
} else {
52+
return Math.abs(getElementOffset(ordinal + 1)) - offset;
53+
}
54+
}
55+
56+
private void assertIndexIsValid(int ordinal) {
57+
assert ordinal >= 0 : "ordinal (" + ordinal + ") should >= 0";
58+
assert ordinal < numElements : "ordinal (" + ordinal + ") should < " + numElements;
59+
}
60+
61+
/**
62+
* Construct a new UnsafeArrayData. The resulting UnsafeArrayData won't be usable until
63+
* `pointTo()` has been called, since the value returned by this constructor is equivalent
64+
* to a null pointer.
65+
*/
66+
public UnsafeArrayData() { }
67+
68+
public Object getBaseObject() { return baseObject; }
69+
public long getBaseOffset() { return baseOffset; }
70+
public int getSizeInBytes() { return sizeInBytes; }
71+
72+
@Override
73+
public int numElements() { return numElements; }
74+
75+
/**
76+
* Update this UnsafeArrayData to point to different backing data.
77+
*
78+
* @param baseObject the base object
79+
* @param baseOffset the offset within the base object
80+
* @param sizeInBytes the size of this row's backing data, in bytes
81+
*/
82+
public void pointTo(Object baseObject, long baseOffset, int numElements, int sizeInBytes) {
83+
assert numElements >= 0 : "numElements (" + numElements + ") should >= 0";
84+
this.numElements = numElements;
85+
this.baseObject = baseObject;
86+
this.baseOffset = baseOffset;
87+
this.sizeInBytes = sizeInBytes;
88+
}
89+
90+
@Override
91+
public boolean isNullAt(int ordinal) {
92+
assertIndexIsValid(ordinal);
93+
return getElementOffset(ordinal) < 0;
94+
}
95+
96+
@Override
97+
public Object get(int ordinal, DataType dataType) {
98+
if (isNullAt(ordinal) || dataType instanceof NullType) {
99+
return null;
100+
} else if (dataType instanceof BooleanType) {
101+
return getBoolean(ordinal);
102+
} else if (dataType instanceof ByteType) {
103+
return getByte(ordinal);
104+
} else if (dataType instanceof ShortType) {
105+
return getShort(ordinal);
106+
} else if (dataType instanceof IntegerType) {
107+
return getInt(ordinal);
108+
} else if (dataType instanceof LongType) {
109+
return getLong(ordinal);
110+
} else if (dataType instanceof FloatType) {
111+
return getFloat(ordinal);
112+
} else if (dataType instanceof DoubleType) {
113+
return getDouble(ordinal);
114+
} else if (dataType instanceof DecimalType) {
115+
DecimalType dt = (DecimalType) dataType;
116+
return getDecimal(ordinal, dt.precision(), dt.scale());
117+
} else if (dataType instanceof DateType) {
118+
return getInt(ordinal);
119+
} else if (dataType instanceof TimestampType) {
120+
return getLong(ordinal);
121+
} else if (dataType instanceof BinaryType) {
122+
return getBinary(ordinal);
123+
} else if (dataType instanceof StringType) {
124+
return getUTF8String(ordinal);
125+
} else if (dataType instanceof CalendarIntervalType) {
126+
return getInterval(ordinal);
127+
} else if (dataType instanceof StructType) {
128+
return getStruct(ordinal, ((StructType) dataType).size());
129+
} else if (dataType instanceof ArrayType) {
130+
return getArray(ordinal);
131+
} else if (dataType instanceof MapType) {
132+
return getMap(ordinal);
133+
} else {
134+
throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
135+
}
136+
}
137+
138+
@Override
139+
public boolean getBoolean(int ordinal) {
140+
assertIndexIsValid(ordinal);
141+
final int offset = getElementOffset(ordinal);
142+
if (offset < 0) {
143+
return false;
144+
} else {
145+
return PlatformDependent.UNSAFE.getBoolean(baseObject, baseOffset + offset);
146+
}
147+
}
148+
149+
@Override
150+
public byte getByte(int ordinal) {
151+
assertIndexIsValid(ordinal);
152+
final int offset = getElementOffset(ordinal);
153+
if (offset < 0) {
154+
return 0;
155+
} else {
156+
return PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + offset);
157+
}
158+
}
159+
160+
@Override
161+
public short getShort(int ordinal) {
162+
assertIndexIsValid(ordinal);
163+
final int offset = getElementOffset(ordinal);
164+
if (offset < 0) {
165+
return 0;
166+
} else {
167+
return PlatformDependent.UNSAFE.getShort(baseObject, baseOffset + offset);
168+
}
169+
}
170+
171+
@Override
172+
public int getInt(int ordinal) {
173+
assertIndexIsValid(ordinal);
174+
final int offset = getElementOffset(ordinal);
175+
if (offset < 0) {
176+
return 0;
177+
} else {
178+
return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
179+
}
180+
}
181+
182+
@Override
183+
public long getLong(int ordinal) {
184+
assertIndexIsValid(ordinal);
185+
final int offset = getElementOffset(ordinal);
186+
if (offset < 0) {
187+
return 0;
188+
} else {
189+
return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset);
190+
}
191+
}
192+
193+
@Override
194+
public float getFloat(int ordinal) {
195+
assertIndexIsValid(ordinal);
196+
final int offset = getElementOffset(ordinal);
197+
if (offset < 0) {
198+
return 0;
199+
} else {
200+
return PlatformDependent.UNSAFE.getFloat(baseObject, baseOffset + offset);
201+
}
202+
}
203+
204+
@Override
205+
public double getDouble(int ordinal) {
206+
assertIndexIsValid(ordinal);
207+
final int offset = getElementOffset(ordinal);
208+
if (offset < 0) {
209+
return 0;
210+
} else {
211+
return PlatformDependent.UNSAFE.getDouble(baseObject, baseOffset + offset);
212+
}
213+
}
214+
215+
@Override
216+
public Decimal getDecimal(int ordinal, int precision, int scale) {
217+
assertIndexIsValid(ordinal);
218+
final int offset = getElementOffset(ordinal);
219+
if (offset < 0) {
220+
return null;
221+
} else {
222+
if (precision <= Decimal.MAX_LONG_DIGITS()) {
223+
final long value = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset);
224+
return Decimal.apply(value, precision, scale);
225+
} else {
226+
final byte[] bytes = getBinary(ordinal);
227+
final BigInteger bigInteger = new BigInteger(bytes);
228+
final BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
229+
return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale);
230+
}
231+
}
232+
}
233+
234+
@Override
235+
public UTF8String getUTF8String(int ordinal) {
236+
final byte[] bytes = getBinary(ordinal);
237+
if (bytes == null) {
238+
return null;
239+
} else {
240+
return UTF8String.fromBytes(bytes);
241+
}
242+
}
243+
244+
@Override
245+
public byte[] getBinary(int ordinal) {
246+
assertIndexIsValid(ordinal);
247+
final int offset = getElementOffset(ordinal);
248+
if (offset < 0) {
249+
return null;
250+
} else {
251+
final int size = getElementSize(offset, ordinal);
252+
final byte[] bytes = new byte[size];
253+
PlatformDependent.copyMemory(
254+
baseObject,
255+
baseOffset + offset,
256+
bytes,
257+
PlatformDependent.BYTE_ARRAY_OFFSET,
258+
size);
259+
return bytes;
260+
}
261+
}
262+
263+
@Override
264+
public CalendarInterval getInterval(int ordinal) {
265+
assertIndexIsValid(ordinal);
266+
final int offset = getElementOffset(ordinal);
267+
if (offset < 0) {
268+
return null;
269+
} else {
270+
final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset);
271+
final long microseconds =
272+
PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8);
273+
return new CalendarInterval(months, microseconds);
274+
}
275+
}
276+
277+
@Override
278+
public InternalRow getStruct(int ordinal, int numFields) {
279+
assertIndexIsValid(ordinal);
280+
final int offset = getElementOffset(ordinal);
281+
if (offset < 0) {
282+
return null;
283+
} else {
284+
final int size = getElementSize(offset, ordinal);
285+
final UnsafeRow row = new UnsafeRow();
286+
row.pointTo(baseObject, baseOffset + offset, numFields, size);
287+
return row;
288+
}
289+
}
290+
291+
@Override
292+
public ArrayData getArray(int ordinal) {
293+
assertIndexIsValid(ordinal);
294+
final int offset = getElementOffset(ordinal);
295+
if (offset < 0) {
296+
return null;
297+
} else {
298+
final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
299+
final int size = getElementSize(offset, ordinal);
300+
final UnsafeArrayData array = new UnsafeArrayData();
301+
// Skip the first 4 bytes.
302+
array.pointTo(baseObject, baseOffset + offset + 4, numElements, size - 4);
303+
return array;
304+
}
305+
}
306+
307+
@Override
308+
public MapData getMap(int ordinal) {
309+
return null;
310+
}
311+
312+
@Override
313+
public int hashCode() {
314+
return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
315+
}
316+
317+
@Override
318+
public boolean equals(Object other) {
319+
if (other instanceof UnsafeArrayData) {
320+
UnsafeArrayData o = (UnsafeArrayData) other;
321+
return (sizeInBytes == o.sizeInBytes) &&
322+
ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
323+
sizeInBytes);
324+
}
325+
return false;
326+
}
327+
328+
public void writeToMemory(Object target, long targetOffset) {
329+
PlatformDependent.copyMemory(
330+
baseObject,
331+
baseOffset,
332+
target,
333+
targetOffset,
334+
sizeInBytes
335+
);
336+
}
337+
}

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,10 @@ public Object get(int ordinal, DataType dataType) {
291291
return getInterval(ordinal);
292292
} else if (dataType instanceof StructType) {
293293
return getStruct(ordinal, ((StructType) dataType).size());
294+
} else if (dataType instanceof ArrayType) {
295+
return getArray(ordinal);
296+
} else if (dataType instanceof MapType) {
297+
return getMap(ordinal);
294298
} else {
295299
throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
296300
}
@@ -420,6 +424,23 @@ public UnsafeRow getStruct(int ordinal, int numFields) {
420424
}
421425
}
422426

427+
@Override
428+
public ArrayData getArray(int ordinal) {
429+
if (isNullAt(ordinal)) {
430+
return null;
431+
} else {
432+
assertIndexIsValid(ordinal);
433+
final long offsetAndSize = getLong(ordinal);
434+
final int offset = (int) (offsetAndSize >> 32);
435+
final int size = (int) (offsetAndSize & ((1L << 32) - 1));
436+
final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
437+
final UnsafeArrayData array = new UnsafeArrayData();
438+
// Skip the first 4 bytes.
439+
array.pointTo(baseObject, baseOffset + offset + 4, numElements, size - 4);
440+
return array;
441+
}
442+
}
443+
423444
/**
424445
* Copies this row, returning a self-contained UnsafeRow that stores its data in an internal
425446
* byte array rather than referencing data stored in a data page.

0 commit comments

Comments
 (0)