From 8465f2f43c38f131a3f0fc33b5c387c46b6483b5 Mon Sep 17 00:00:00 2001 From: Zotov Yuriy Date: Thu, 3 Mar 2016 15:11:51 +0300 Subject: [PATCH 1/4] new classes --- .../unsafe/memory/ByteArrayMemoryBlock.java | 66 ++++++++++++++++++ .../unsafe/memory/IntArrayMemoryBlock.java | 66 ++++++++++++++++++ .../unsafe/memory/LongArrayMemoryBlock.java | 66 ++++++++++++++++++ .../unsafe/memory/OffHeapMemoryBlock.java | 41 +++++++++++ .../codegen/MemoryBlockHolder.java | 68 +++++++++++++++++++ 5 files changed, 307 insertions(+) create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/IntArrayMemoryBlock.java create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/LongArrayMemoryBlock.java create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/MemoryBlockHolder.java diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java new file mode 100644 index 0000000000000..4f9bf48bd48ea --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + */ +public class ByteArrayMemoryBlock extends MemoryLocation implements MemoryBlock { + + private final long length; + + /** + * Optional page number; used when this MemoryBlock represents a page allocated by a + * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, + * which lives in a different package. + */ + private int pageNumber = -1; + + public ByteArrayMemoryBlock(byte[] obj, long offset, long length) { + super(obj, offset); + this.length = length; + } + + /** + * Returns the size of the memory block. + */ + public long size() { + return length; + } + + @Override + public void setPageNumber(int aPageNum) { + this.pageNumber = aPageNum; + } + + @Override + public int getPageNumber() { + return this.pageNumber; + } + + public byte[] getByteArray() { return (byte[])this.obj; } + + /** + * Creates a memory block pointing to the memory used by the byte array. 
+ */ + public static ByteArrayMemoryBlock fromByteArray(final byte[] array) { + return new ByteArrayMemoryBlock(array, Platform.BYTE_ARRAY_OFFSET, array.length); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/IntArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/IntArrayMemoryBlock.java new file mode 100644 index 0000000000000..5e93c7a2ef2b7 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/IntArrayMemoryBlock.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + */ +public class IntArrayMemoryBlock extends MemoryLocation implements MemoryBlock { + + private final long size; + + /** + * Optional page number; used when this MemoryBlock represents a page allocated by a + * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, + * which lives in a different package. + */ + private int pageNumber = -1; + + public IntArrayMemoryBlock(int[] obj, long offset, long size) { + super(obj, offset); + this.size = size; + } + + /** + * Returns the size of the memory block. + */ + public long size() { + return size; + } + + @Override + public void setPageNumber(int aPageNum) { + this.pageNumber = aPageNum; + } + + @Override + public int getPageNumber() { + return this.pageNumber; + } + + public int[] getIntArray() { return (int[])this.obj; } + + /** + * Creates a memory block pointing to the memory used by the int array. + */ + public static IntArrayMemoryBlock fromIntArray(final int[] array) { + return new IntArrayMemoryBlock(array, Platform.INT_ARRAY_OFFSET, array.length*4); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/LongArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/LongArrayMemoryBlock.java new file mode 100644 index 0000000000000..1b164452b1f54 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/LongArrayMemoryBlock.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + */ +public class LongArrayMemoryBlock extends MemoryLocation implements MemoryBlock { + + private final long size; + + /** + * Optional page number; used when this MemoryBlock represents a page allocated by a + * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, + * which lives in a different package. + */ + private int pageNumber = -1; + + public LongArrayMemoryBlock(long[] obj, long offset, long size) { + super(obj, offset); + this.size = size; + } + + /** + * Returns the size of the memory block. + */ + public long size() { + return size; + } + + @Override + public void setPageNumber(int aPageNum) { + this.pageNumber = aPageNum; + } + + @Override + public int getPageNumber() { + return this.pageNumber; + } + + public long[] getLongArray() { return (long[])this.obj; } + + /** + * Creates a memory block pointing to the memory used by the long array. + */ + public static LongArrayMemoryBlock fromLongArray(final long[] array) { + return new LongArrayMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length*8); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java new file mode 100644 index 0000000000000..868f85c4757f0 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java @@ -0,0 +1,41 @@ +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; + +public class OffHeapMemoryBlock implements MemoryBlock { + private Object directBuffer; + private final long address; + private final long length; + private int pageNumber = -1; + + public OffHeapMemoryBlock(Object aDirectBuffer, long address, long size) { + this.address = address; + this.length = size; + this.directBuffer = aDirectBuffer; + } + + @Override + public Object getBaseObject() { + return null; + } + + @Override + public long getBaseOffset() { + return this.address; + } + + @Override + public long size() { + return this.length; + } + + @Override + public void setPageNumber(int aPageNum) { + this.pageNumber = aPageNum; + } + + @Override + public int getPageNumber() { + return this.pageNumber; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/MemoryBlockHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/MemoryBlockHolder.java new file mode 100644 index 0000000000000..a9c4de315e847 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/MemoryBlockHolder.java @@ -0,0 +1,68 @@ +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.deploy.SparkHadoopUtil; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import 
org.apache.spark.unsafe.memory.MemoryBlock; + +public class MemoryBlockHolder { + private UnsafeRow row; + private final int fixedSize; + private MemoryBlock block; + private MemoryAllocator alloc; + + public int cursor = 0; + + public MemoryBlockHolder( UnsafeRow aRow ) { this( aRow, 64 ); } + public MemoryBlockHolder( UnsafeRow aRow, int aSize ) { + this.row = aRow; + + MemoryAllocator anAllocator; + if (SparkHadoopUtil.get().conf().getBoolean("spark.memory.offHeap.enabled", false)) + anAllocator = MemoryAllocator.UNSAFE; + else + anAllocator = MemoryAllocator.HEAP; + + this.alloc = anAllocator; + + this.fixedSize = UnsafeRow.calculateBitSetWidthInBytes(this.row.numFields()) + 8 * this.row.numFields(); + int totSize = this.fixedSize + aSize; + this.block = this.alloc.allocate(totSize < 64 ? 64 : totSize); + this.row.pointTo(this.block, this.block.getBaseOffset(), totSize); + } + + public MemoryBlock getBaseObject() { return this.block; } + + public long getBaseOffset() { return this.block.getBaseOffset(); } + + /** + * Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer. + */ + public void grow(int neededSize) { + final int length = totalSize() + neededSize; + if (this.block.size() < length) { + // This will not happen frequently, because the buffer is re-used. + //TODO: implement reallocate() + MemoryBlock tmp = this.alloc.allocate( 2*length ); + Platform.copyMemory( + this.block, + this.block.getBaseOffset(), + tmp, + tmp.getBaseOffset(), + totalSize()); + this.block = tmp; + if (this.row != null) { + this.row.pointTo(this.block, this.block.getBaseOffset(), length * 2); + } + } + } + + public void reset() { + cursor = this.fixedSize; + } + + public int totalSize() { + return this.cursor; + } +} From e0619db7043b3adb61265c424d7dae0e2a00f1fc Mon Sep 17 00:00:00 2001 From: Zotov Yuriy Date: Thu, 3 Mar 2016 15:18:37 +0300 Subject: [PATCH 2/4] changes to unsafe subsystem --- .../org/apache/spark/unsafe/Platform.java | 153 ++++++++++++++---- .../spark/unsafe/array/ByteArrayMethods.java | 23 ++- .../apache/spark/unsafe/array/LongArray.java | 12 +- .../spark/unsafe/bitset/BitSetMethods.java | 11 +- .../spark/unsafe/hash/Murmur3_x86_32.java | 43 ++++- .../unsafe/memory/HeapMemoryAllocator.java | 6 +- .../spark/unsafe/memory/MemoryAllocator.java | 4 +- .../spark/unsafe/memory/MemoryBlock.java | 59 +------ .../unsafe/memory/UnsafeMemoryAllocator.java | 48 +++++- .../apache/spark/unsafe/types/ByteArray.java | 3 +- .../apache/spark/unsafe/types/UTF8String.java | 39 +++-- .../spark/unsafe/array/LongArraySuite.java | 21 ++- 12 files changed, 295 insertions(+), 127 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 18761bfd222a2..3914c54266d57 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -19,6 +19,7 @@ import java.lang.reflect.Field; +import org.apache.spark.unsafe.memory.MemoryBlock; import sun.misc.Unsafe; public final class Platform { @@ -37,68 +38,92 @@ public final class Platform { public static final int DOUBLE_ARRAY_OFFSET; - public static int getInt(Object object, long offset) { + public static int getInt(MemoryBlock object, long offset) { + return _UNSAFE.getInt(object.getBaseObject(), offset); + } + + public static int getInt(byte[] object, long offset) { return _UNSAFE.getInt(object, offset); } - public static void putInt(Object 
object, long offset, int value) { + public static void putInt(MemoryBlock object, long offset, int value) { + _UNSAFE.putInt(object.getBaseObject(), offset, value); + } + + public static void putInt(byte[] object, long offset, int value) { _UNSAFE.putInt(object, offset, value); } - public static boolean getBoolean(Object object, long offset) { - return _UNSAFE.getBoolean(object, offset); + public static boolean getBoolean(MemoryBlock object, long offset) { + return _UNSAFE.getBoolean(object.getBaseObject(), offset); } - public static void putBoolean(Object object, long offset, boolean value) { - _UNSAFE.putBoolean(object, offset, value); + public static void putBoolean(MemoryBlock object, long offset, boolean value) { + _UNSAFE.putBoolean(object.getBaseObject(), offset, value); } - public static byte getByte(Object object, long offset) { + public static byte getByte(MemoryBlock object, long offset) { + return _UNSAFE.getByte(object.getBaseObject(), offset); + } + + public static byte getByte(byte[] object, long offset) { return _UNSAFE.getByte(object, offset); } - public static void putByte(Object object, long offset, byte value) { - _UNSAFE.putByte(object, offset, value); + public static void putByte(MemoryBlock object, long offset, byte value) { + _UNSAFE.putByte(object.getBaseObject(), offset, value); + } + + public static short getShort(MemoryBlock object, long offset) { + return _UNSAFE.getShort(object.getBaseObject(), offset); } - public static short getShort(Object object, long offset) { - return _UNSAFE.getShort(object, offset); + public static void putShort(MemoryBlock object, long offset, short value) { + _UNSAFE.putShort(object.getBaseObject(), offset, value); } - public static void putShort(Object object, long offset, short value) { - _UNSAFE.putShort(object, offset, value); + public static long getLong(MemoryBlock object, long offset) { + return _UNSAFE.getLong(object.getBaseObject(), offset); } - public static long getLong(Object object, long offset) { + public static long getLong(byte[] object, long offset) { return _UNSAFE.getLong(object, offset); } - public static void putLong(Object object, long offset, long value) { - _UNSAFE.putLong(object, offset, value); + public static void putLong(MemoryBlock object, long offset, long value) { + _UNSAFE.putLong(object.getBaseObject(), offset, value); + } + + public static float getFloat(MemoryBlock object, long offset) { + return _UNSAFE.getFloat(object.getBaseObject(), offset); } - public static float getFloat(Object object, long offset) { + public static float getFloat(byte[] object, long offset) { return _UNSAFE.getFloat(object, offset); } - public static void putFloat(Object object, long offset, float value) { - _UNSAFE.putFloat(object, offset, value); + public static void putFloat(MemoryBlock object, long offset, float value) { + _UNSAFE.putFloat(object.getBaseObject(), offset, value); } - public static double getDouble(Object object, long offset) { + public static double getDouble(MemoryBlock object, long offset) { + return _UNSAFE.getDouble(object.getBaseObject(), offset); + } + + public static double getDouble(byte[] object, long offset) { return _UNSAFE.getDouble(object, offset); } - public static void putDouble(Object object, long offset, double value) { - _UNSAFE.putDouble(object, offset, value); + public static void putDouble(MemoryBlock object, long offset, double value) { + _UNSAFE.putDouble(object.getBaseObject(), offset, value); } - public static Object getObjectVolatile(Object object, long offset) { - return 
_UNSAFE.getObjectVolatile(object, offset); + public static Object getObjectVolatile(MemoryBlock object, long offset) { + return _UNSAFE.getObjectVolatile(object.getBaseObject(), offset); } - public static void putObjectVolatile(Object object, long offset, Object value) { - _UNSAFE.putObjectVolatile(object, offset, value); + public static void putObjectVolatile(MemoryBlock object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object.getBaseObject(), offset, value); } public static long allocateMemory(long size) { @@ -111,7 +136,7 @@ public static void freeMemory(long address) { public static long reallocateMemory(long address, long oldSize, long newSize) { long newMemory = _UNSAFE.allocateMemory(newSize); - copyMemory(null, address, null, newMemory, oldSize); + copyMemory0(null, address, null, newMemory, oldSize); freeMemory(address); return newMemory; } @@ -120,8 +145,8 @@ public static void setMemory(long address, byte value, long size) { _UNSAFE.setMemory(address, size, value); } - public static void copyMemory( - Object src, long srcOffset, Object dst, long dstOffset, long length) { + static void copyMemory0( + Object src, long srcOffset, Object dst, long dstOffset, long length) { // Check if dstOffset is before or after srcOffset to determine if we should copy // forward or backwards. This is necessary in case src and dst overlap. if (dstOffset < srcOffset) { @@ -146,6 +171,76 @@ public static void copyMemory( } } + public static void copyMemory( + MemoryBlock src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { + Platform.copyMemory0(src.getBaseObject(), srcOffset, dst.getBaseObject(), dstOffset, length); + } + + public static void copyMemory( + byte[] src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { + Platform.copyMemory0(src, srcOffset, dst.getBaseObject(), dstOffset, length); + } + + public static void copyMemory( + short[] src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { + Platform.copyMemory0(src, srcOffset, dst.getBaseObject(), dstOffset, length); + } + + public static void copyMemory( + int[] src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { + Platform.copyMemory0(src, srcOffset, dst.getBaseObject(), dstOffset, length); + } + + public static void copyMemory( + long[] src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { + Platform.copyMemory0(src, srcOffset, dst.getBaseObject(), dstOffset, length); + } + + public static void copyMemory( + float[] src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { + Platform.copyMemory0(src, srcOffset, dst.getBaseObject(), dstOffset, length); + } + + public static void copyMemory( + double[] src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { + Platform.copyMemory0(src, srcOffset, dst.getBaseObject(), dstOffset, length); + } + + public static void copyMemory( + MemoryBlock src, long srcOffset, byte[] dst, long dstOffset, long length) { + Platform.copyMemory0(src.getBaseObject(), srcOffset, dst, dstOffset, length); + } + + public static void copyMemory( + byte[] src, long srcOffset, byte[] dst, long dstOffset, long length) { + Platform.copyMemory0(src, srcOffset, dst, dstOffset, length); + } + + public static void copyMemory( + byte[] src, long srcOffset, short[] dst, long dstOffset, long length) { + Platform.copyMemory0(src, srcOffset, dst, dstOffset, length); + } + + public static void copyMemory( + byte[] src, long srcOffset, int[] dst, long dstOffset, long length) { + Platform.copyMemory0(src, srcOffset, dst, 
dstOffset, length); + } + + public static void copyMemory( + byte[] src, long srcOffset, long[] dst, long dstOffset, long length) { + Platform.copyMemory0(src, srcOffset, dst, dstOffset, length); + } + + public static void copyMemory( + byte[] src, long srcOffset, float[] dst, long dstOffset, long length) { + Platform.copyMemory0(src, srcOffset, dst, dstOffset, length); + } + + public static void copyMemory( + byte[] src, long srcOffset, double[] dst, long dstOffset, long length) { + Platform.copyMemory0(src, srcOffset, dst, dstOffset, length); + } + /** * Raises an exception bypassing compiler checks for checked exceptions. */ diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index cf42877bf9fd4..628439d877a68 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -18,6 +18,8 @@ package org.apache.spark.unsafe.array; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; public class ByteArrayMethods { @@ -45,7 +47,7 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { * @return true if the arrays are equal, false otherwise */ public static boolean arrayEquals( - Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { + MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, final long length) { int i = 0; while (i <= length - 8) { if (Platform.getLong(leftBase, leftOffset + i) != @@ -63,4 +65,23 @@ public static boolean arrayEquals( } return true; } + + public static boolean arrayEquals( + byte[] leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, final long length) { + ByteArrayMemoryBlock bleft = ByteArrayMemoryBlock.fromByteArray(leftBase); + return ByteArrayMethods.arrayEquals(bleft, leftOffset, rightBase, rightOffset, length); + } + + public static boolean arrayEquals( + MemoryBlock leftBase, long leftOffset, byte[] rightBase, long rightOffset, final long length) { + ByteArrayMemoryBlock bright = ByteArrayMemoryBlock.fromByteArray(rightBase); + return ByteArrayMethods.arrayEquals(leftBase, leftOffset, bright, rightOffset, length); + } + + public static boolean arrayEquals( + byte[] leftBase, long leftOffset, byte[] rightBase, long rightOffset, final long length) { + ByteArrayMemoryBlock bleft = ByteArrayMemoryBlock.fromByteArray(leftBase); + ByteArrayMemoryBlock bright = ByteArrayMemoryBlock.fromByteArray(rightBase); + return ByteArrayMethods.arrayEquals(bleft, leftOffset, bright, rightOffset, length); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 1a3cdff638264..169ec454b2300 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -33,7 +33,6 @@ public final class LongArray { private static final long WIDTH = 8; private final MemoryBlock memory; - private final Object baseObj; private final long baseOffset; private final long length; @@ -41,7 +40,6 @@ public final class LongArray { public LongArray(MemoryBlock memory) { assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements"; 
this.memory = memory; - this.baseObj = memory.getBaseObject(); this.baseOffset = memory.getBaseOffset(); this.length = memory.size() / WIDTH; } @@ -50,10 +48,6 @@ public MemoryBlock memoryBlock() { return memory; } - public Object getBaseObject() { - return baseObj; - } - public long getBaseOffset() { return baseOffset; } @@ -70,7 +64,7 @@ public long size() { */ public void zeroOut() { for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) { - Platform.putLong(baseObj, off, 0); + Platform.putLong(memory, off, 0); } } @@ -80,7 +74,7 @@ public void zeroOut() { public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - Platform.putLong(baseObj, baseOffset + index * WIDTH, value); + Platform.putLong(memory, baseOffset + index * WIDTH, value); } /** @@ -89,6 +83,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return Platform.getLong(baseObj, baseOffset + index * WIDTH); + return Platform.getLong(memory, baseOffset + index * WIDTH); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java index 7857bf66a72ad..da51bc48016eb 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe.bitset; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; /** * Methods for working with fixed-size uncompressed bitsets. @@ -37,7 +38,7 @@ private BitSetMethods() { /** * Sets the bit at the specified index to {@code true}. */ - public static void set(Object baseObject, long baseOffset, int index) { + public static void set(MemoryBlock baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; @@ -48,7 +49,7 @@ public static void set(Object baseObject, long baseOffset, int index) { /** * Sets the bit at the specified index to {@code false}. */ - public static void unset(Object baseObject, long baseOffset, int index) { + public static void unset(MemoryBlock baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; @@ -59,7 +60,7 @@ public static void unset(Object baseObject, long baseOffset, int index) { /** * Returns {@code true} if the bit is set at the specified index. */ - public static boolean isSet(Object baseObject, long baseOffset, int index) { + public static boolean isSet(MemoryBlock baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; @@ -70,7 +71,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) { /** * Returns {@code true} if any bit is set. 
*/ - public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInWords) { + public static boolean anySet(MemoryBlock baseObject, long baseOffset, long bitSetWidthInWords) { long addr = baseOffset; for (int i = 0; i < bitSetWidthInWords; i++, addr += WORD_SIZE) { if (Platform.getLong(baseObject, addr) != 0) { @@ -98,7 +99,7 @@ public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidt * @return the index of the next set bit, or -1 if there is no such bit */ public static int nextSetBit( - Object baseObject, + MemoryBlock baseObject, long baseOffset, int fromIndex, int bitsetSizeInWords) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 5e7ee480cafd1..e585158819f9d 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe.hash; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -48,18 +49,41 @@ public static int hashInt(int input, int seed) { return fmix(h1, 4); } - public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { + public int hashUnsafeWords(byte[] base, long offset, int lengthInBytes) { return hashUnsafeWords(base, offset, lengthInBytes, seed); } - public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + public int hashUnsafeWords(MemoryBlock base, long offset, int lengthInBytes) { + return hashUnsafeWords(base, offset, lengthInBytes, seed); + } + + public static int hashUnsafeWords(byte[] base, long offset, int lengthInBytes, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); return fmix(h1, lengthInBytes); } - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + public static int hashUnsafeWords(MemoryBlock base, long offset, int lengthInBytes, int seed) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. 
+ assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; + int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); + return fmix(h1, lengthInBytes); + } + + public static int hashUnsafeBytes(byte[] base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + for (int i = lengthAligned; i < lengthInBytes; i++) { + int halfWord = Platform.getByte(base, offset + i); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return fmix(h1, lengthInBytes); + } + + public static int hashUnsafeBytes(MemoryBlock base, long offset, int lengthInBytes, int seed) { assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; int h1 = hashBytesByInt(base, offset, lengthAligned, seed); @@ -71,7 +95,18 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, i return fmix(h1, lengthInBytes); } - private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { + private static int hashBytesByInt(byte[] base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes % 4 == 0); + int h1 = seed; + for (int i = 0; i < lengthInBytes; i += 4) { + int halfWord = Platform.getInt(base, offset + i); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return h1; + } + + private static int hashBytesByInt(MemoryBlock base, long offset, int lengthInBytes, int seed) { assert (lengthInBytes % 4 == 0); int h1 = seed; for (int i = 0; i < lengthInBytes; i += 4) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 09847cec9c4ca..40a2832881e51 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -64,11 +64,15 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { } } long[] array = new long[(int) ((size + 7) / 8)]; - return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + return new LongArrayMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); } @Override public void free(MemoryBlock memory) { + if(memory.getBaseObject() == null ) { + throw new IllegalArgumentException("cannot manage off-heap memory block"); + } + final long size = memory.size(); if (shouldPool(size)) { synchronized (this) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java index 5192f68c862cf..1ee95e234700d 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -27,7 +27,7 @@ public interface MemoryAllocator { void free(MemoryBlock memory); - MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + UnsafeMemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); - MemoryAllocator HEAP = new HeapMemoryAllocator(); + HeapMemoryAllocator HEAP = new HeapMemoryAllocator(); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index e3e79471154df..2ee887c4d247e 100644 --- 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -1,56 +1,9 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - package org.apache.spark.unsafe.memory; -import javax.annotation.Nullable; - -import org.apache.spark.unsafe.Platform; - -/** - * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. - */ -public class MemoryBlock extends MemoryLocation { - - private final long length; - - /** - * Optional page number; used when this MemoryBlock represents a page allocated by a - * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, - * which lives in a different package. - */ - public int pageNumber = -1; - - public MemoryBlock(@Nullable Object obj, long offset, long length) { - super(obj, offset); - this.length = length; - } - - /** - * Returns the size of the memory block. - */ - public long size() { - return length; - } - - /** - * Creates a memory block pointing to the memory used by the long array. - */ - public static MemoryBlock fromLongArray(final long[] array) { - return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8); - } +public interface MemoryBlock { + Object getBaseObject(); + long getBaseOffset(); + long size(); + void setPageNumber(int aPageNum); + int getPageNumber(); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 98ce711176e43..162cc03ccfb83 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -19,21 +19,57 @@ import org.apache.spark.unsafe.Platform; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; + /** * A simple {@link MemoryAllocator} that uses {@code Unsafe} to allocate off-heap memory. 
*/ public class UnsafeMemoryAllocator implements MemoryAllocator { + private static Method bufAddrMethod; + static { + try { + Class cb = UnsafeMemoryAllocator.class.getClassLoader().loadClass("java.nio.DirectByteBuffer"); + bufAddrMethod = cb.getMethod("address"); + bufAddrMethod.setAccessible(true); + } + catch(Exception ex) { + throw new RuntimeException(ex.getMessage(), ex); + } + } + @Override - public MemoryBlock allocate(long size) throws OutOfMemoryError { - long address = Platform.allocateMemory(size); - return new MemoryBlock(null, address, size); + public OffHeapMemoryBlock allocate(long size) throws OutOfMemoryError { + try { + Object b = ByteBuffer.allocateDirect((int)size); + long addr = (long)bufAddrMethod.invoke(b); + return new OffHeapMemoryBlock(b, addr, size); + } catch (IllegalAccessException e) { + throw new RuntimeException(e.getMessage(), e); + } catch (InvocationTargetException e) { + Throwable tex = e.getTargetException(); + if( tex instanceof OutOfMemoryError) { + throw (OutOfMemoryError) tex; + } + else { + throw new RuntimeException(e.getMessage(), e); + } + } } @Override public void free(MemoryBlock memory) { - assert (memory.obj == null) : - "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; - Platform.freeMemory(memory.offset); + // DirectByteBuffers are deallocated automatically by JVM when they become + // unreachable much like normal Objects in heap + } + + public OffHeapMemoryBlock reallocate(OffHeapMemoryBlock aBlock, long anOldSize, long aNewSize) { + OffHeapMemoryBlock nb = this.allocate(aNewSize); + if( aBlock.getBaseOffset() != 0 ) + Platform.copyMemory(aBlock, aBlock.getBaseOffset(), nb, nb.getBaseOffset(), anOldSize); + + return nb; } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 3ced2094f5e6b..d16fb97fe567d 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe.types; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; import java.util.Arrays; @@ -30,7 +31,7 @@ public final class ByteArray { * offset. The target memory address must already been allocated, and have enough space to * hold all the bytes in this string. 
*/ - public static void writeToMemory(byte[] src, Object target, long targetOffset) { + public static void writeToMemory(byte[] src, MemoryBlock target, long targetOffset) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET, target, targetOffset, src.length); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 87706d0b68388..82ff4fe01b502 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -32,6 +32,8 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import static org.apache.spark.unsafe.Platform.*; @@ -48,11 +50,11 @@ public final class UTF8String implements Comparable, Externalizable, // These are only updated by readExternal() or read() @Nonnull - private Object base; + private MemoryBlock base; private long offset; private int numBytes; - public Object getBaseObject() { return base; } + public MemoryBlock getBaseObject() { return base; } public long getBaseOffset() { return offset; } private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, @@ -96,7 +98,7 @@ public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { /** * Creates an UTF8String from given address (base and offset) and length. */ - public static UTF8String fromAddress(Object base, long offset, int numBytes) { + public static UTF8String fromAddress(MemoryBlock base, long offset, int numBytes) { return new UTF8String(base, offset, numBytes); } @@ -124,7 +126,14 @@ public static UTF8String blankString(int length) { return fromBytes(spaces); } - protected UTF8String(Object base, long offset, int numBytes) { + protected UTF8String(byte[] bytes, long offset, int numBytes) { + ByteArrayMemoryBlock b = ByteArrayMemoryBlock.fromByteArray(bytes); + this.base = b; + this.offset = offset; + this.numBytes = numBytes; + } + + protected UTF8String(MemoryBlock base, long offset, int numBytes) { this.base = base; this.offset = offset; this.numBytes = numBytes; @@ -132,7 +141,7 @@ protected UTF8String(Object base, long offset, int numBytes) { // for serialization public UTF8String() { - this(null, 0, 0); + this((MemoryBlock)null, 0, 0); } /** @@ -140,7 +149,11 @@ public UTF8String() { * The target memory address must already been allocated, and have enough space to hold all the * bytes in this string. 
*/ - public void writeToMemory(Object target, long targetOffset) { + public void writeToMemory(byte[] target, long targetOffset) { + Platform.copyMemory(base, offset, target, targetOffset, numBytes); + } + + public void writeToMemory(MemoryBlock target, long targetOffset) { Platform.copyMemory(base, offset, target, targetOffset, numBytes); } @@ -228,9 +241,9 @@ public long getPrefix() { */ public byte[] getBytes() { // avoid copy if `base` is `byte[]` - if (offset == BYTE_ARRAY_OFFSET && base instanceof byte[] - && ((byte[]) base).length == numBytes) { - return (byte[]) base; + if (offset == BYTE_ARRAY_OFFSET && base instanceof ByteArrayMemoryBlock + && (((ByteArrayMemoryBlock) base).getByteArray()).length == numBytes) { + return ((ByteArrayMemoryBlock) base).getByteArray(); } else { byte[] bytes = new byte[numBytes]; copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); @@ -1001,8 +1014,8 @@ public void writeExternal(ObjectOutput out) throws IOException { public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { offset = BYTE_ARRAY_OFFSET; numBytes = in.readInt(); - base = new byte[numBytes]; - in.readFully((byte[]) base); + base = ByteArrayMemoryBlock.fromByteArray(new byte[numBytes]); + in.readFully( ((ByteArrayMemoryBlock)base).getByteArray() ); } @Override @@ -1016,8 +1029,8 @@ public void write(Kryo kryo, Output out) { public void read(Kryo kryo, Input in) { this.offset = BYTE_ARRAY_OFFSET; this.numBytes = in.readInt(); - this.base = new byte[numBytes]; - in.read((byte[]) base); + base = ByteArrayMemoryBlock.fromByteArray(new byte[numBytes]); + in.read( ((ByteArrayMemoryBlock)base).getByteArray() ); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java index fb8e53b3348f3..e1fff28bfe1d7 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -17,17 +17,32 @@ package org.apache.spark.unsafe.array; +import org.apache.spark.unsafe.memory.LongArrayMemoryBlock; +import org.apache.spark.unsafe.memory.UnsafeMemoryAllocator; import org.junit.Assert; import org.junit.Test; -import org.apache.spark.unsafe.memory.MemoryBlock; - public class LongArraySuite { @Test public void basicTest() { long[] bytes = new long[2]; - LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes)); + LongArray arr = new LongArray(LongArrayMemoryBlock.fromLongArray(bytes)); + arr.set(0, 1L); + arr.set(1, 2L); + arr.set(1, 3L); + Assert.assertEquals(2, arr.size()); + Assert.assertEquals(1L, arr.get(0)); + Assert.assertEquals(3L, arr.get(1)); + + arr.zeroOut(); + Assert.assertEquals(0L, arr.get(0)); + Assert.assertEquals(0L, arr.get(1)); + } + + @Test + public void offheapTest() { + LongArray arr = new LongArray( UnsafeMemoryAllocator.UNSAFE.allocate(2*8) ); arr.set(0, 1L); arr.set(1, 2L); arr.set(1, 3L); From f9f277699e54b0133b77abba0773b3166884ea85 Mon Sep 17 00:00:00 2001 From: Zotov Yuriy Date: Thu, 3 Mar 2016 15:25:11 +0300 Subject: [PATCH 3/4] changes to spark-core --- .../spark/memory/TaskMemoryManager.java | 30 ++--- .../shuffle/sort/ShuffleExternalSorter.java | 6 +- .../shuffle/sort/ShuffleInMemorySorter.java | 4 +- .../shuffle/sort/ShuffleSortDataFormat.java | 8 +- .../shuffle/sort/UnsafeShuffleWriter.java | 4 +- .../spark/unsafe/map/BytesToBytesMap.java | 41 +++--- .../unsafe/sort/RecordComparator.java | 6 +- 
.../unsafe/sort/UnsafeExternalSorter.java | 18 +-- .../unsafe/sort/UnsafeInMemorySorter.java | 13 +- .../unsafe/sort/UnsafeSortDataFormat.java | 8 +- .../unsafe/sort/UnsafeSorterIterator.java | 4 +- .../unsafe/sort/UnsafeSorterSpillMerger.java | 4 +- .../unsafe/sort/UnsafeSorterSpillReader.java | 14 +- .../unsafe/sort/UnsafeSorterSpillWriter.java | 3 +- .../spark/memory/TaskMemoryManagerSuite.java | 4 +- .../sort/ShuffleInMemorySorterSuite.java | 4 +- .../map/AbstractBytesToBytesMapSuite.java | 123 +++++++++--------- .../sort/UnsafeExternalSorterSuite.java | 42 +++--- .../sort/UnsafeInMemorySorterSuite.java | 8 +- 19 files changed, 181 insertions(+), 163 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 8757dff36f159..43a294541ccb4 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -275,7 +275,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { // this could trigger spilling to free some pages. return allocatePage(size, consumer); } - page.pageNumber = pageNumber; + page.setPageNumber(pageNumber); pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired); @@ -287,15 +287,15 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ public void freePage(MemoryBlock page, MemoryConsumer consumer) { - assert (page.pageNumber != -1) : + assert (page.getPageNumber() != -1) : "Called freePage() on memory that wasn't allocated with allocatePage()"; - assert(allocatedPages.get(page.pageNumber)); - pageTable[page.pageNumber] = null; + assert(allocatedPages.get(page.getPageNumber())); + pageTable[page.getPageNumber()] = null; synchronized (this) { - allocatedPages.clear(page.pageNumber); + allocatedPages.clear(page.getPageNumber()); } if (logger.isTraceEnabled()) { - logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + logger.trace("Freed page number {} ({} bytes)", page.getPageNumber(), page.size()); } long pageSize = page.size(); memoryManager.tungstenMemoryAllocator().free(page); @@ -319,7 +319,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { // relative to the page's base offset; this relative offset will fit in 51 bits. 
offsetInPage -= page.getBaseOffset(); } - return encodePageNumberAndOffset(page.pageNumber, offsetInPage); + return encodePageNumberAndOffset(page.getPageNumber(), offsetInPage); } @VisibleForTesting @@ -341,17 +341,17 @@ private static long decodeOffset(long pagePlusOffsetAddress) { * Get the page associated with an address encoded by * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} */ - public Object getPage(long pagePlusOffsetAddress) { + public MemoryBlock getPage(long pagePlusOffsetAddress) { + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + final MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + if (tungstenMemoryMode == MemoryMode.ON_HEAP) { - final int pageNumber = decodePageNumber(pagePlusOffsetAddress); - assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); - final MemoryBlock page = pageTable[pageNumber]; - assert (page != null); assert (page.getBaseObject() != null); - return page.getBaseObject(); - } else { - return null; } + + return page; } /** diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index f97e76d7ed0d9..655a3bba81c69 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -189,7 +189,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); - final Object recordPage = taskMemoryManager.getPage(recordPointer); + final MemoryBlock recordPage = taskMemoryManager.getPage(recordPointer); final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage); long recordReadPosition = recordOffsetInPage + 4; // skip over record length @@ -352,7 +352,7 @@ private void acquireNewPageIfNecessary(int required) { /** * Write a record to the shuffle sorter. 
*/ - public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId) + public void insertRecord(MemoryBlock recordBase, long recordOffset, int length, int partitionId) throws IOException { // for tests @@ -367,7 +367,7 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p acquireNewPageIfNecessary(required); assert(currentPage != null); - final Object base = currentPage.getBaseObject(); + final MemoryBlock base = currentPage; final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); Platform.putInt(base, pageCursor, length); pageCursor += 4; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 2381cff61f069..c5890ed26f380 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -76,9 +76,9 @@ public void reset() { public void expandPointerArray(LongArray newArray) { assert(newArray.size() > array.size()); Platform.copyMemory( - array.getBaseObject(), + array.memoryBlock(), array.getBaseOffset(), - newArray.getBaseObject(), + newArray.memoryBlock(), newArray.getBaseOffset(), array.size() * 8L ); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index 8f4e3229976dc..539de5f00d6bd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -19,7 +19,7 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; -import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.LongArrayMemoryBlock; import org.apache.spark.util.collection.SortDataFormat; final class ShuffleSortDataFormat extends SortDataFormat { @@ -60,9 +60,9 @@ public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { @Override public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { Platform.copyMemory( - src.getBaseObject(), + src.memoryBlock(), src.getBaseOffset() + srcPos * 8, - dst.getBaseObject(), + dst.memoryBlock(), dst.getBaseOffset() + dstPos * 8, length * 8 ); @@ -71,7 +71,7 @@ public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int @Override public LongArray allocate(int length) { // This buffer is used temporary (usually small), so it's fine to allocated from JVM heap. 
- return new LongArray(MemoryBlock.fromLongArray(new long[length])); + return new LongArray(LongArrayMemoryBlock.fromLongArray(new long[length])); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 3f4402bd3a652..79348154172b1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -22,6 +22,7 @@ import java.nio.channels.FileChannel; import java.util.Iterator; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; @@ -232,8 +233,9 @@ void insertRecordIntoSorter(Product2 record) throws IOException { final int serializedRecordSize = serBuffer.size(); assert (serializedRecordSize > 0); + ByteArrayMemoryBlock serBytes = ByteArrayMemoryBlock.fromByteArray(serBuffer.getBuf()); sorter.insertRecord( - serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + serBytes, Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index b55a322a1b413..0c9c16a51c4b1 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -146,7 +146,7 @@ public final class BytesToBytesMap extends MemoryConsumer { private int mask; /** - * Return value of {@link BytesToBytesMap#lookup(Object, long, int)}. + * Return value of {@link BytesToBytesMap#lookup(MemoryBlock, long, int)}. */ private final Location loc; @@ -227,7 +227,7 @@ public final class MapIterator implements Iterator { private MemoryBlock currentPage = null; private int recordsInPage = 0; - private Object pageBaseObject; + private MemoryBlock pageBaseObject; private long offsetInPage; // If this iterator destructive or not. When it is true, it frees each page as it moves onto @@ -254,7 +254,7 @@ private void advanceToNextPage() { } if (dataPages.size() > nextIdx) { currentPage = dataPages.get(nextIdx); - pageBaseObject = currentPage.getBaseObject(); + pageBaseObject = currentPage; offsetInPage = currentPage.getBaseOffset(); recordsInPage = Platform.getInt(pageBaseObject, offsetInPage); offsetInPage += 4; @@ -347,7 +347,7 @@ public long spill(long numBytes) throws IOException { break; } - Object base = block.getBaseObject(); + MemoryBlock base = block; long offset = block.getBaseOffset(); int numRecords = Platform.getInt(base, offset); offset += 4; @@ -413,7 +413,7 @@ public MapIterator destructiveIterator() { * * This function always return the same {@link Location} instance to avoid object allocation. */ - public Location lookup(Object keyBase, long keyOffset, int keyLength) { + public Location lookup(MemoryBlock keyBase, long keyOffset, int keyLength) { safeLookup(keyBase, keyOffset, keyLength, loc, Murmur3_x86_32.hashUnsafeWords(keyBase, keyOffset, keyLength, 42)); return loc; @@ -425,7 +425,7 @@ public Location lookup(Object keyBase, long keyOffset, int keyLength) { * * This function always return the same {@link Location} instance to avoid object allocation. 
*/ - public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) { + public Location lookup(MemoryBlock keyBase, long keyOffset, int keyLength, int hash) { safeLookup(keyBase, keyOffset, keyLength, loc, hash); return loc; } @@ -435,7 +435,7 @@ public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) * * This is a thread-safe version of `lookup`, could be used by multiple threads. */ - public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc, int hash) { + public void safeLookup(MemoryBlock keyBase, long keyOffset, int keyLength, Location loc, int hash) { assert(longArray != null); if (enablePerfMetrics) { @@ -480,7 +480,7 @@ public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location l } /** - * Handle returned by {@link BytesToBytesMap#lookup(Object, long, int)} function. + * Handle returned by {@link BytesToBytesMap#lookup(MemoryBlock, long, int)} function. */ public final class Location { /** An index into the hash map's Long array */ @@ -489,11 +489,11 @@ public final class Location { private boolean isDefined; /** * The hashcode of the most recent key passed to - * {@link BytesToBytesMap#lookup(Object, long, int, int)}. Caching this hashcode here allows us + * {@link BytesToBytesMap#lookup(MemoryBlock, long, int, int)}. Caching this hashcode here allows us * to avoid re-hashing the key when storing a value for that key. */ private int keyHashcode; - private Object baseObject; // the base object for key and value + private MemoryBlock baseObject; // the base object for key and value private long keyOffset; private int keyLength; private long valueOffset; @@ -505,12 +505,13 @@ public final class Location { @Nullable private MemoryBlock memoryPage; private void updateAddressesAndSizes(long fullKeyAddress) { + MemoryBlock page = taskMemoryManager.getPage(fullKeyAddress); updateAddressesAndSizes( - taskMemoryManager.getPage(fullKeyAddress), + page, taskMemoryManager.getOffsetInPage(fullKeyAddress)); } - private void updateAddressesAndSizes(final Object base, long offset) { + private void updateAddressesAndSizes(final MemoryBlock base, long offset) { baseObject = base; final int totalLength = Platform.getInt(base, offset); offset += 4; @@ -536,14 +537,14 @@ private Location with(int pos, int keyHashcode, boolean isDefined) { private Location with(MemoryBlock page, long offsetInPage) { this.isDefined = true; this.memoryPage = page; - updateAddressesAndSizes(page.getBaseObject(), offsetInPage); + updateAddressesAndSizes(page, offsetInPage); return this; } /** * This is only used for spilling */ - private Location with(Object base, long offset, int length) { + private Location with(MemoryBlock base, long offset, int length) { this.isDefined = true; this.memoryPage = null; baseObject = base; @@ -572,7 +573,7 @@ public boolean isDefined() { /** * Returns the base object for key. */ - public Object getKeyBase() { + public MemoryBlock getKeyBase() { assert (isDefined); return baseObject; } @@ -588,7 +589,7 @@ public long getKeyOffset() { /** * Returns the base object for value. */ - public Object getValueBase() { + public MemoryBlock getValueBase() { assert (isDefined); return baseObject; } @@ -652,8 +653,8 @@ public int getValueLength() { * @return true if the put() was successful and false if the put() failed because memory could * not be acquired. 
*/ - public boolean putNewKey(Object keyBase, long keyOffset, int keyLength, - Object valueBase, long valueOffset, int valueLength) { + public boolean putNewKey(MemoryBlock keyBase, long keyOffset, int keyLength, + MemoryBlock valueBase, long valueOffset, int valueLength) { assert (!isDefined) : "Can only set value once for a key"; assert (keyLength % 8 == 0); assert (valueLength % 8 == 0); @@ -679,7 +680,7 @@ public boolean putNewKey(Object keyBase, long keyOffset, int keyLength, } // --- Append the key and value data to the current data page -------------------------------- - final Object base = currentPage.getBaseObject(); + final MemoryBlock base = currentPage; long offset = currentPage.getBaseOffset() + pageCursor; final long recordOffset = offset; Platform.putInt(base, offset, keyLength + valueLength + 4); @@ -723,7 +724,7 @@ private boolean acquireNewPage(long required) { return false; } dataPages.add(currentPage); - Platform.putInt(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0); + Platform.putInt(currentPage, currentPage.getBaseOffset(), 0); pageCursor = 4; return true; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java index 09e4258792204..6d0ffc9187f0b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java @@ -17,6 +17,8 @@ package org.apache.spark.util.collection.unsafe.sort; +import org.apache.spark.unsafe.memory.MemoryBlock; + /** * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte * prefix, this may simply return 0. @@ -30,8 +32,8 @@ public abstract class RecordComparator { * equal to, or greater than the second. */ public abstract int compare( - Object leftBaseObject, + MemoryBlock leftBaseObject, long leftBaseOffset, - Object rightBaseObject, + MemoryBlock rightBaseObject, long rightBaseOffset); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 9236bd2c04fd9..7b41e3ad1a416 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -187,7 +187,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator(); while (sortedRecords.hasNext()) { sortedRecords.loadNext(); - final Object baseObject = sortedRecords.getBaseObject(); + final MemoryBlock baseObject = sortedRecords.getBaseObject(); final long baseOffset = sortedRecords.getBaseOffset(); final int recordLength = sortedRecords.getRecordLength(); spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); @@ -325,7 +325,7 @@ private void acquireNewPageIfNecessary(int required) { /** * Write a record to the sorter. 
*/ - public void insertRecord(Object recordBase, long recordOffset, int length, long prefix) + public void insertRecord(MemoryBlock recordBase, long recordOffset, int length, long prefix) throws IOException { growPointerArrayIfNecessary(); @@ -333,7 +333,7 @@ public void insertRecord(Object recordBase, long recordOffset, int length, long final int required = length + 4; acquireNewPageIfNecessary(required); - final Object base = currentPage.getBaseObject(); + final MemoryBlock base = currentPage; final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); Platform.putInt(base, pageCursor, length); pageCursor += 4; @@ -351,15 +351,15 @@ public void insertRecord(Object recordBase, long recordOffset, int length, long * * record length = key length + value length + 4 */ - public void insertKVRecord(Object keyBase, long keyOffset, int keyLen, - Object valueBase, long valueOffset, int valueLen, long prefix) + public void insertKVRecord(MemoryBlock keyBase, long keyOffset, int keyLen, + MemoryBlock valueBase, long valueOffset, int valueLen, long prefix) throws IOException { growPointerArrayIfNecessary(); final int required = keyLen + valueLen + 4 + 4; acquireNewPageIfNecessary(required); - final Object base = currentPage.getBaseObject(); + final MemoryBlock base = currentPage; final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); Platform.putInt(base, pageCursor, keyLen + valueLen + 4); pageCursor += 4; @@ -445,7 +445,7 @@ public long spill() throws IOException { new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords); while (inMemIterator.hasNext()) { inMemIterator.loadNext(); - final Object baseObject = inMemIterator.getBaseObject(); + final MemoryBlock baseObject = inMemIterator.getBaseObject(); final long baseOffset = inMemIterator.getBaseOffset(); final int recordLength = inMemIterator.getRecordLength(); spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix()); @@ -503,7 +503,7 @@ public void loadNext() throws IOException { } @Override - public Object getBaseObject() { + public MemoryBlock getBaseObject() { return upstream.getBaseObject(); } @@ -588,7 +588,7 @@ public void loadNext() throws IOException { } @Override - public Object getBaseObject() { return current.getBaseObject(); } + public MemoryBlock getBaseObject() { return current.getBaseObject(); } @Override public long getBaseOffset() { return current.getBaseOffset(); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index cea0f0a0c6c11..ef06d080f86e0 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -25,6 +25,7 @@ import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.Sorter; /** @@ -55,9 +56,9 @@ private static final class SortComparator implements Comparator arr.length) { - arr = new byte[recordLength]; - baseObject = arr; + if (recordLength > baseObject.getByteArray().length) { + baseObject = ByteArrayMemoryBlock.fromByteArray(new byte[recordLength]); } - ByteStreams.readFully(in, arr, 0, recordLength); + 
ByteStreams.readFully(in, baseObject.getByteArray(), 0, recordLength); numRecordsRemaining--; if (numRecordsRemaining == 0) { close(); @@ -87,7 +87,7 @@ public void loadNext() throws IOException { } @Override - public Object getBaseObject() { + public MemoryBlock getBaseObject() { return baseObject; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 234e21140a1dd..e49bd47bc7c5d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -20,6 +20,7 @@ import java.io.File; import java.io.IOException; +import org.apache.spark.unsafe.memory.MemoryBlock; import scala.Tuple2; import org.apache.spark.executor.ShuffleWriteMetrics; @@ -100,7 +101,7 @@ private void writeIntToBuffer(int v, int offset) throws IOException { * @param keyPrefix a sort key prefix */ public void write( - Object baseObject, + MemoryBlock baseObject, long baseOffset, int recordLength, long keyPrefix) throws IOException { diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index 127789b632b44..aa6a2c94e4053 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -50,7 +50,7 @@ public void encodePageNumberAndOffsetOffHeap() { // encode. This test exercises that corner-case: final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10); final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset); - Assert.assertEquals(null, manager.getPage(encodedAddress)); + Assert.assertEquals(null, manager.getPage(encodedAddress).getBaseObject()); Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress)); } @@ -60,7 +60,7 @@ public void encodePageNumberAndOffsetOnHeap() { new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); final MemoryBlock dataPage = manager.allocatePage(256, null); final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); - Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); + Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress).getBaseObject()); Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index b4fa33f32a3fd..5488c1ea9cab5 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -38,7 +38,7 @@ public class ShuffleInMemorySorterSuite { final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager); - private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { + private static String getStringFromDataPage(MemoryBlock baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength); return new 
String(strBytes); @@ -68,7 +68,7 @@ public void testBasicSorting() throws Exception { final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); - final Object baseObject = dataPage.getBaseObject(); + final MemoryBlock baseObject = dataPage; final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4); final HashPartitioner hashPartitioner = new HashPartitioner(4); diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index d8af2b336dd4d..0086394a5947f 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -24,6 +24,10 @@ import java.nio.ByteBuffer; import java.util.*; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.LongArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import scala.Tuple2; import scala.Tuple2$; import scala.runtime.AbstractFunction1; @@ -141,7 +145,7 @@ public void tearDown() { protected abstract boolean useOffHeapMemoryAllocator(); - private static byte[] getByteArray(Object base, long offset, int size) { + private static byte[] getByteArray(MemoryBlock base, long offset, int size) { final byte[] arr = new byte[size]; Platform.copyMemory(base, offset, arr, Platform.BYTE_ARRAY_OFFSET, size); return arr; @@ -161,7 +165,7 @@ private byte[] getRandomByteArray(int numWords) { */ private static boolean arrayEquals( byte[] expected, - Object base, + MemoryBlock base, long offset, long actualLengthBytes) { return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals( @@ -180,7 +184,7 @@ public void emptyMap() { Assert.assertEquals(0, map.numElements()); final int keyLengthInWords = 10; final int keyLengthInBytes = keyLengthInWords * 8; - final byte[] key = getRandomByteArray(keyLengthInWords); + final MemoryBlock key = ByteArrayMemoryBlock.fromByteArray(getRandomByteArray(keyLengthInWords)); Assert.assertFalse(map.lookup(key, Platform.BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); Assert.assertFalse(map.iterator().hasNext()); } finally { @@ -193,8 +197,8 @@ public void setAndRetrieveAKey() { BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES); final int recordLengthWords = 10; final int recordLengthBytes = recordLengthWords * 8; - final byte[] keyData = getRandomByteArray(recordLengthWords); - final byte[] valueData = getRandomByteArray(recordLengthWords); + final ByteArrayMemoryBlock keyData = ByteArrayMemoryBlock.fromByteArray(getRandomByteArray(recordLengthWords)); + final ByteArrayMemoryBlock valueData = ByteArrayMemoryBlock.fromByteArray(getRandomByteArray(recordLengthWords)); try { final BytesToBytesMap.Location loc = map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes); @@ -211,9 +215,9 @@ public void setAndRetrieveAKey() { // reflect the result of this store without us having to call lookup() again on the same key. 
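// Note on the assertions below: getKeyBase()/getValueBase() now hand back a MemoryBlock rather
// than a bare Object, so the expected bytes are unwrapped with keyData.getByteArray() and
// valueData.getByteArray() before being compared to what the getByteArray(MemoryBlock, long, int)
// helper copies out of the map.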
Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); Assert.assertEquals(recordLengthBytes, loc.getValueLength()); - Assert.assertArrayEquals(keyData, + Assert.assertArrayEquals(keyData.getByteArray(), getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes)); - Assert.assertArrayEquals(valueData, + Assert.assertArrayEquals(valueData.getByteArray(), getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes)); // After calling lookup() the location should still point to the correct data. @@ -221,9 +225,9 @@ public void setAndRetrieveAKey() { map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); Assert.assertEquals(recordLengthBytes, loc.getValueLength()); - Assert.assertArrayEquals(keyData, + Assert.assertArrayEquals(keyData.getByteArray(), getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes)); - Assert.assertArrayEquals(valueData, + Assert.assertArrayEquals(valueData.getByteArray(), getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes)); try { @@ -250,25 +254,26 @@ private void iteratorTestBase(boolean destructive) throws Exception { try { for (long i = 0; i < size; i++) { final long[] value = new long[] { i }; + LongArrayMemoryBlock bvalue = LongArrayMemoryBlock.fromLongArray(value); final BytesToBytesMap.Location loc = - map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8); + map.lookup(bvalue, Platform.LONG_ARRAY_OFFSET, 8); Assert.assertFalse(loc.isDefined()); // Ensure that we store some zero-length keys if (i % 5 == 0) { Assert.assertTrue(loc.putNewKey( - null, + new OffHeapMemoryBlock(null, 0, 0), Platform.LONG_ARRAY_OFFSET, 0, - value, + bvalue, Platform.LONG_ARRAY_OFFSET, 8 )); } else { Assert.assertTrue(loc.putNewKey( - value, + bvalue, Platform.LONG_ARRAY_OFFSET, 8, - value, + bvalue, Platform.LONG_ARRAY_OFFSET, 8 )); @@ -336,8 +341,8 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { // handling branch in iterator(). 
try { for (int i = 0; i < NUM_ENTRIES; i++) { - final long[] key = new long[] { i, i, i }; // 3 * 8 = 24 bytes - final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes + final LongArrayMemoryBlock key = LongArrayMemoryBlock.fromLongArray(new long[] { i, i, i }); // 3 * 8 = 24 bytes + final LongArrayMemoryBlock value = LongArrayMemoryBlock.fromLongArray(new long[] { i, i, i, i, i }); // 5 * 8 = 40 bytes final BytesToBytesMap.Location loc = map.lookup( key, Platform.LONG_ARRAY_OFFSET, @@ -357,8 +362,8 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES); final Iterator iter = map.iterator(); - final long[] key = new long[KEY_LENGTH / 8]; - final long[] value = new long[VALUE_LENGTH / 8]; + final LongArrayMemoryBlock key = LongArrayMemoryBlock.fromLongArray(new long[KEY_LENGTH / 8]); + final LongArrayMemoryBlock value = LongArrayMemoryBlock.fromLongArray(new long[VALUE_LENGTH / 8]); while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); Assert.assertTrue(loc.isDefined()); @@ -378,13 +383,13 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { Platform.LONG_ARRAY_OFFSET, VALUE_LENGTH ); - for (long j : key) { - Assert.assertEquals(key[0], j); + for (long j : key.getLongArray()) { + Assert.assertEquals(key.getLongArray()[0], j); } - for (long j : value) { - Assert.assertEquals(key[0], j); + for (long j : value.getLongArray()) { + Assert.assertEquals(key.getLongArray()[0], j); } - valuesSeen.set((int) key[0]); + valuesSeen.set((int) key.getLongArray()[0]); } Assert.assertEquals(NUM_ENTRIES, valuesSeen.cardinality()); } finally { @@ -402,45 +407,45 @@ public void randomizedStressTest() { try { // Fill the map to 90% full so that we can trigger probing for (int i = 0; i < size * 0.9; i++) { - final byte[] key = getRandomByteArray(rand.nextInt(256) + 1); - final byte[] value = getRandomByteArray(rand.nextInt(512) + 1); - if (!expected.containsKey(ByteBuffer.wrap(key))) { - expected.put(ByteBuffer.wrap(key), value); + final ByteArrayMemoryBlock key = ByteArrayMemoryBlock.fromByteArray(getRandomByteArray(rand.nextInt(256) + 1)); + final ByteArrayMemoryBlock value = ByteArrayMemoryBlock.fromByteArray(getRandomByteArray(rand.nextInt(512) + 1)); + if (!expected.containsKey(ByteBuffer.wrap(key.getByteArray()))) { + expected.put(ByteBuffer.wrap(key.getByteArray()), value.getByteArray()); final BytesToBytesMap.Location loc = map.lookup( key, Platform.BYTE_ARRAY_OFFSET, - key.length + key.getByteArray().length ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, Platform.BYTE_ARRAY_OFFSET, - key.length, + key.getByteArray().length, value, Platform.BYTE_ARRAY_OFFSET, - value.length + value.getByteArray().length )); // After calling putNewKey, the following should be true, even before calling // lookup(): Assert.assertTrue(loc.isDefined()); - Assert.assertEquals(key.length, loc.getKeyLength()); - Assert.assertEquals(value.length, loc.getValueLength()); - Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length)); + Assert.assertEquals(key.getByteArray().length, loc.getKeyLength()); + Assert.assertEquals(value.getByteArray().length, loc.getValueLength()); + Assert.assertTrue(arrayEquals(key.getByteArray(), loc.getKeyBase(), loc.getKeyOffset(), key.getByteArray().length)); Assert.assertTrue( - arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length)); + arrayEquals(value.getByteArray(), 
loc.getValueBase(), loc.getValueOffset(), value.getByteArray().length)); } } for (Map.Entry entry : expected.entrySet()) { - final byte[] key = JavaUtils.bufferToArray(entry.getKey()); - final byte[] value = entry.getValue(); + final ByteArrayMemoryBlock key = ByteArrayMemoryBlock.fromByteArray(JavaUtils.bufferToArray(entry.getKey())); + final ByteArrayMemoryBlock value = ByteArrayMemoryBlock.fromByteArray(entry.getValue()); final BytesToBytesMap.Location loc = - map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); + map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.getByteArray().length); Assert.assertTrue(loc.isDefined()); Assert.assertTrue( - arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength())); + arrayEquals(key.getByteArray(), loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength())); Assert.assertTrue( - arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength())); + arrayEquals(value.getByteArray(), loc.getValueBase(), loc.getValueOffset(), loc.getValueLength())); } } finally { map.free(); @@ -456,44 +461,44 @@ public void randomizedTestWithRecordsLargerThanPageSize() { final Map expected = new HashMap(); try { for (int i = 0; i < 1000; i++) { - final byte[] key = getRandomByteArray(rand.nextInt(128)); - final byte[] value = getRandomByteArray(rand.nextInt(128)); - if (!expected.containsKey(ByteBuffer.wrap(key))) { - expected.put(ByteBuffer.wrap(key), value); + final ByteArrayMemoryBlock key = ByteArrayMemoryBlock.fromByteArray(getRandomByteArray(rand.nextInt(128))); + final ByteArrayMemoryBlock value = ByteArrayMemoryBlock.fromByteArray(getRandomByteArray(rand.nextInt(128))); + if (!expected.containsKey(ByteBuffer.wrap(key.getByteArray()))) { + expected.put(ByteBuffer.wrap(key.getByteArray()), value.getByteArray()); final BytesToBytesMap.Location loc = map.lookup( key, Platform.BYTE_ARRAY_OFFSET, - key.length + key.getByteArray().length ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, Platform.BYTE_ARRAY_OFFSET, - key.length, + key.getByteArray().length, value, Platform.BYTE_ARRAY_OFFSET, - value.length + value.getByteArray().length )); // After calling putNewKey, the following should be true, even before calling // lookup(): Assert.assertTrue(loc.isDefined()); - Assert.assertEquals(key.length, loc.getKeyLength()); - Assert.assertEquals(value.length, loc.getValueLength()); - Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length)); + Assert.assertEquals(key.getByteArray().length, loc.getKeyLength()); + Assert.assertEquals(value.getByteArray().length, loc.getValueLength()); + Assert.assertTrue(arrayEquals(key.getByteArray(), loc.getKeyBase(), loc.getKeyOffset(), key.getByteArray().length)); Assert.assertTrue( - arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length)); + arrayEquals(value.getByteArray(), loc.getValueBase(), loc.getValueOffset(), value.getByteArray().length)); } } for (Map.Entry entry : expected.entrySet()) { - final byte[] key = JavaUtils.bufferToArray(entry.getKey()); - final byte[] value = entry.getValue(); + final ByteArrayMemoryBlock key = ByteArrayMemoryBlock.fromByteArray(JavaUtils.bufferToArray(entry.getKey())); + final ByteArrayMemoryBlock value = ByteArrayMemoryBlock.fromByteArray(entry.getValue()); final BytesToBytesMap.Location loc = - map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); + map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.getByteArray().length); Assert.assertTrue(loc.isDefined()); Assert.assertTrue( - 
arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength())); + arrayEquals(key.getByteArray(), loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength())); Assert.assertTrue( - arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength())); + arrayEquals(value.getByteArray(), loc.getValueBase(), loc.getValueOffset(), loc.getValueLength())); } } finally { map.free(); @@ -505,7 +510,7 @@ public void failureToAllocateFirstPage() { memoryManager.limit(1024); // longArray BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES); try { - final long[] emptyArray = new long[0]; + final LongArrayMemoryBlock emptyArray = LongArrayMemoryBlock.fromLongArray(new long[0]); final BytesToBytesMap.Location loc = map.lookup(emptyArray, Platform.LONG_ARRAY_OFFSET, 0); Assert.assertFalse(loc.isDefined()); @@ -527,7 +532,7 @@ public void failureToGrow() { if (i > 0) { memoryManager.limit(0); } - final long[] arr = new long[]{i}; + final LongArrayMemoryBlock arr = LongArrayMemoryBlock.fromLongArray(new long[]{i}); final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); success = loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); @@ -548,7 +553,7 @@ public void spillInIterator() throws IOException { try { int i; for (i = 0; i < 1024; i++) { - final long[] arr = new long[]{i}; + final LongArrayMemoryBlock arr = LongArrayMemoryBlock.fromLongArray(new long[]{i}); final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); } @@ -615,7 +620,7 @@ public void testPeakMemoryUsed() { long newPeakMemory; try { for (long i = 0; i < numRecordsPerPage * 10; i++) { - final long[] value = new long[]{i}; + final LongArrayMemoryBlock value = LongArrayMemoryBlock.fromLongArray(new long[]{i}); map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).putNewKey( value, Platform.LONG_ARRAY_OFFSET, diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 492fe49ba4c4f..1bff52fc01053 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -25,6 +25,7 @@ import java.util.LinkedList; import java.util.UUID; +import org.apache.spark.unsafe.memory.*; import scala.Tuple2; import scala.Tuple2$; import scala.runtime.AbstractFunction1; @@ -72,9 +73,9 @@ public int compare(long prefix1, long prefix2) { final RecordComparator recordComparator = new RecordComparator() { @Override public int compare( - Object leftBaseObject, + MemoryBlock leftBaseObject, long leftBaseOffset, - Object rightBaseObject, + MemoryBlock rightBaseObject, long rightBaseOffset) { return 0; } @@ -157,7 +158,7 @@ private void assertSpillFilesWereCleanedUp() { } private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { - final int[] arr = new int[]{ value }; + final IntArrayMemoryBlock arr = IntArrayMemoryBlock.fromIntArray(new int[]{ value }); sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value); } @@ -165,7 +166,7 @@ private static void insertRecord( UnsafeExternalSorter sorter, int[] record, long prefix) throws IOException { - sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, 
prefix); + sorter.insertRecord(IntArrayMemoryBlock.fromIntArray(record), Platform.INT_ARRAY_OFFSET, record.length * 4, prefix); } private UnsafeExternalSorter newSorter() throws IOException { @@ -206,13 +207,14 @@ public void testSortingOnlyByPrefix() throws Exception { @Test public void testSortingEmptyArrays() throws Exception { final UnsafeExternalSorter sorter = newSorter(); - sorter.insertRecord(null, 0, 0, 0); - sorter.insertRecord(null, 0, 0, 0); + OffHeapMemoryBlock b = new OffHeapMemoryBlock(null, 0, 0); + sorter.insertRecord(b, 0, 0, 0); + sorter.insertRecord(b, 0, 0, 0); sorter.spill(); - sorter.insertRecord(null, 0, 0, 0); + sorter.insertRecord(b, 0, 0, 0); sorter.spill(); - sorter.insertRecord(null, 0, 0, 0); - sorter.insertRecord(null, 0, 0, 0); + sorter.insertRecord(b, 0, 0, 0); + sorter.insertRecord(b, 0, 0, 0); UnsafeSorterIterator iter = sorter.getSortedIterator(); @@ -259,9 +261,9 @@ public void spillingOccursInResponseToMemoryPressure() throws Exception { @Test public void testFillingPage() throws Exception { final UnsafeExternalSorter sorter = newSorter(); - byte[] record = new byte[16]; + ByteArrayMemoryBlock record = ByteArrayMemoryBlock.fromByteArray(new byte[16]); while (sorter.getNumberOfAllocatedPages() < 2) { - sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0); + sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.getByteArray().length, 0); } sorter.cleanupResources(); assertSpillFilesWereCleanedUp(); @@ -316,11 +318,11 @@ public void sortingRecordsThatExceedPageSize() throws Exception { @Test public void forcedSpillingWithReadIterator() throws Exception { final UnsafeExternalSorter sorter = newSorter(); - long[] record = new long[100]; - int recordSize = record.length * 8; + LongArrayMemoryBlock record = LongArrayMemoryBlock.fromLongArray(new long[100]); + int recordSize = record.getLongArray().length * 8; int n = (int) pageSizeBytes / recordSize * 3; for (int i = 0; i < n; i++) { - record[0] = (long) i; + record.getLongArray()[0] = (long) i; sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0); } assertTrue(sorter.getNumberOfAllocatedPages() >= 2); @@ -348,11 +350,11 @@ public void forcedSpillingWithReadIterator() throws Exception { @Test public void forcedSpillingWithNotReadIterator() throws Exception { final UnsafeExternalSorter sorter = newSorter(); - long[] record = new long[100]; - int recordSize = record.length * 8; + LongArrayMemoryBlock record = LongArrayMemoryBlock.fromLongArray(new long[100]); + int recordSize = record.getLongArray().length * 8; int n = (int) pageSizeBytes / recordSize * 3; for (int i = 0; i < n; i++) { - record[0] = (long) i; + record.getLongArray()[0] = (long) i; sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0); } assertTrue(sorter.getNumberOfAllocatedPages() >= 2); @@ -379,12 +381,12 @@ public void forcedSpillingWithoutComparator() throws Exception { null, /* initialSize */ 1024, pageSizeBytes); - long[] record = new long[100]; - int recordSize = record.length * 8; + LongArrayMemoryBlock record = LongArrayMemoryBlock.fromLongArray(new long[100]); + int recordSize = record.getLongArray().length * 8; int n = (int) pageSizeBytes / recordSize * 3; int batch = n / 4; for (int i = 0; i < n; i++) { - record[0] = (long) i; + record.getLongArray()[0] = (long) i; sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0); if (i % batch == batch - 1) { sorter.spill(); diff --git 
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index ff41768df1d8f..eeaf8dbda98ce 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -38,7 +38,7 @@ public class UnsafeInMemorySorterSuite { - private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { + private static String getStringFromDataPage(MemoryBlock baseObject, long baseOffset, int length) { final byte[] strBytes = new byte[length]; Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, length); return new String(strBytes); @@ -75,7 +75,7 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); - final Object baseObject = dataPage.getBaseObject(); + final MemoryBlock baseObject = dataPage; // Write the records into the data page: long position = dataPage.getBaseOffset(); for (String str : dataToSort) { @@ -91,9 +91,9 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { final RecordComparator recordComparator = new RecordComparator() { @Override public int compare( - Object leftBaseObject, + MemoryBlock leftBaseObject, long leftBaseOffset, - Object rightBaseObject, + MemoryBlock rightBaseObject, long rightBaseOffset) { return 0; } From 4a4e62efb730a18612176d3ceeb20970baad77d7 Mon Sep 17 00:00:00 2001 From: Zotov Yuriy Date: Thu, 3 Mar 2016 15:26:04 +0300 Subject: [PATCH 4/4] spark.unsafe.Platform interface changed, BufferHolder is replaced by MemoryBlockHolder, OffHeapMemoryBlock uses DirectByteBuffer for off-heap memory allocation, UnsafeRow and others hold memory in MemoryBlocks instead of indefinite Objects --- .../catalyst/expressions/UnsafeArrayData.java | 18 +- .../catalyst/expressions/UnsafeMapData.java | 16 +- .../sql/catalyst/expressions/UnsafeRow.java | 77 ++++----- .../expressions/codegen/BufferHolder.java | 80 --------- .../codegen/UnsafeArrayWriter.java | 36 ++-- .../expressions/codegen/UnsafeRowWriter.java | 62 +++---- .../execution/UnsafeExternalRowSorter.java | 3 +- .../codegen/GenerateUnsafeProjection.scala | 9 +- .../codegen/GenerateUnsafeRowJoiner.scala | 14 +- .../GenerateUnsafeRowJoinerBitsetSuite.scala | 3 +- .../UnsafeFixedWidthAggregationMap.java | 10 +- .../sql/execution/UnsafeKVExternalSorter.java | 6 +- .../parquet/UnsafeRowParquetRecordReader.java | 4 +- .../vectorized/OffHeapColumnVector.java | 154 +++++++++--------- .../sql/execution/UnsafeRowSerializer.scala | 17 +- .../sql/execution/columnar/ColumnType.scala | 7 +- .../columnar/GenerateColumnAccessor.scala | 4 +- .../datasources/text/DefaultSource.scala | 4 +- .../sql/execution/joins/HashedRelation.scala | 22 +-- .../org/apache/spark/sql/UnsafeRowSuite.scala | 22 +-- .../BenchmarkWholeStageCodegen.scala | 23 +-- .../vectorized/ColumnarBatchBenchmark.scala | 17 +- .../vectorized/ColumnarBatchSuite.scala | 20 ++- 23 files changed, 286 insertions(+), 342 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 648625b2cc5d2..abcde9cd517ac 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -26,6 +26,8 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -49,7 +51,7 @@ // todo: there is a lof of duplicated code between UnsafeRow and UnsafeArrayData. public class UnsafeArrayData extends ArrayData { - private Object baseObject; + private MemoryBlock baseObject; private long baseOffset; // The number of elements in this array @@ -101,7 +103,7 @@ public UnsafeArrayData() { } * @param baseOffset the offset within the base object * @param sizeInBytes the size of this array's backing data, in bytes */ - public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { + public void pointTo(MemoryBlock baseObject, long baseOffset, int sizeInBytes) { // Read the number of elements from the first 4 bytes. final int numElements = Platform.getInt(baseObject, baseOffset); assert numElements >= 0 : "numElements (" + numElements + ") should >= 0"; @@ -314,7 +316,11 @@ public boolean equals(Object other) { return false; } - public void writeToMemory(Object target, long targetOffset) { + public void writeToMemory(byte[] target, long targetOffset) { + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); + } + + public void writeToMemory(MemoryBlock target, long targetOffset) { Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } @@ -330,10 +336,10 @@ public void writeTo(ByteBuffer buffer) { @Override public UnsafeArrayData copy() { UnsafeArrayData arrayCopy = new UnsafeArrayData(); - final byte[] arrayDataCopy = new byte[sizeInBytes]; + ByteArrayMemoryBlock newBlock = ByteArrayMemoryBlock.fromByteArray(new byte[sizeInBytes]); Platform.copyMemory( - baseObject, baseOffset, arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); - arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); + baseObject, baseOffset, newBlock, newBlock.getBaseOffset(), sizeInBytes); + arrayCopy.pointTo(newBlock, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); return arrayCopy; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java index 651eb1ff0c561..a62d12cc68d5c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; /** * An Unsafe implementation of Map which is backed by raw memory instead of Java objects. @@ -32,7 +34,7 @@ // TODO: Use a more efficient format which doesn't depend on unsafe array. 
public class UnsafeMapData extends MapData { - private Object baseObject; + private MemoryBlock baseObject; private long baseOffset; // The size of this map's backing data, in bytes. @@ -40,7 +42,7 @@ public class UnsafeMapData extends MapData { // 4 + key array numBytes + value array numBytes. private int sizeInBytes; - public Object getBaseObject() { return baseObject; } + public MemoryBlock getBaseObject() { return baseObject; } public long getBaseOffset() { return baseOffset; } public int getSizeInBytes() { return sizeInBytes; } @@ -64,7 +66,7 @@ public UnsafeMapData() { * @param baseOffset the offset within the base object * @param sizeInBytes the size of this map's backing data, in bytes */ - public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { + public void pointTo(MemoryBlock baseObject, long baseOffset, int sizeInBytes) { // Read the numBytes of key array from the first 4 bytes. final int keyArraySize = Platform.getInt(baseObject, baseOffset); final int valueArraySize = sizeInBytes - keyArraySize - 4; @@ -96,7 +98,11 @@ public UnsafeArrayData valueArray() { return values; } - public void writeToMemory(Object target, long targetOffset) { + public void writeToMemory(byte[] target, long targetOffset) { + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); + } + + public void writeToMemory(MemoryBlock target, long targetOffset) { Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } @@ -112,7 +118,7 @@ public void writeTo(ByteBuffer buffer) { @Override public UnsafeMapData copy() { UnsafeMapData mapCopy = new UnsafeMapData(); - final byte[] mapDataCopy = new byte[sizeInBytes]; + final ByteArrayMemoryBlock mapDataCopy = ByteArrayMemoryBlock.fromByteArray(new byte[sizeInBytes]); Platform.copyMemory( baseObject, baseOffset, mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); mapCopy.pointTo(mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index a88bcbfdb7ccb..eebfc918b49e8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -36,6 +36,8 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -111,7 +113,7 @@ public static boolean isMutable(DataType dt) { // Private fields and methods ////////////////////////////////////////////////////////////////////////////// - private Object baseObject; + private MemoryBlock baseObject; private long baseOffset; /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ @@ -150,35 +152,19 @@ public UnsafeRow(int numFields) { // for serializer public UnsafeRow() {} - public Object getBaseObject() { return baseObject; } + public MemoryBlock getBaseObject() { return baseObject; } public long getBaseOffset() { return baseOffset; } public int getSizeInBytes() { return sizeInBytes; } @Override public int numFields() { return numFields; } - /** - * Update this UnsafeRow to point to different backing 
data. - * - * @param baseObject the base object - * @param baseOffset the offset within the base object - * @param sizeInBytes the size of this row's backing data, in bytes - */ - public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { + public void pointTo( MemoryBlock aBlock, long anOffset, int aSizeInBytes ) { assert numFields >= 0 : "numFields (" + numFields + ") should >= 0"; - this.baseObject = baseObject; - this.baseOffset = baseOffset; - this.sizeInBytes = sizeInBytes; - } + this.baseObject = aBlock; + this.baseOffset = anOffset; + this.sizeInBytes = aSizeInBytes; - /** - * Update this UnsafeRow to point to the underlying byte array. - * - * @param buf byte array to point to - * @param sizeInBytes the number of bytes valid in the byte array - */ - public void pointTo(byte[] buf, int sizeInBytes) { - pointTo(buf, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); } public void setTotalSize(int sizeInBytes) { @@ -500,15 +486,15 @@ public UnsafeMapData getMap(int ordinal) { @Override public UnsafeRow copy() { UnsafeRow rowCopy = new UnsafeRow(numFields); - final byte[] rowDataCopy = new byte[sizeInBytes]; + ByteArrayMemoryBlock block = ByteArrayMemoryBlock.fromByteArray(new byte[sizeInBytes]); Platform.copyMemory( baseObject, baseOffset, - rowDataCopy, - Platform.BYTE_ARRAY_OFFSET, + block, + block.getBaseOffset(), sizeInBytes ); - rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); + rowCopy.pointTo(block, block.getBaseOffset(), sizeInBytes); return rowCopy; } @@ -518,7 +504,8 @@ public UnsafeRow copy() { */ public static UnsafeRow createFromByteArray(int numBytes, int numFields) { final UnsafeRow row = new UnsafeRow(numFields); - row.pointTo(new byte[numBytes], numBytes); + ByteArrayMemoryBlock block = ByteArrayMemoryBlock.fromByteArray(new byte[numBytes]); + row.pointTo(block, block.getBaseOffset(), numBytes); return row; } @@ -528,13 +515,15 @@ public static UnsafeRow createFromByteArray(int numBytes, int numFields) { */ public void copyFrom(UnsafeRow row) { // copyFrom is only available for UnsafeRow created from byte array. - assert (baseObject instanceof byte[]) && baseOffset == Platform.BYTE_ARRAY_OFFSET; + assert (baseObject instanceof ByteArrayMemoryBlock); if (row.sizeInBytes > this.sizeInBytes) { // resize the underlying byte[] if it's not large enough. - this.baseObject = new byte[row.sizeInBytes]; + this.baseObject = ByteArrayMemoryBlock.fromByteArray(new byte[row.sizeInBytes]); } Platform.copyMemory( - row.baseObject, row.baseOffset, this.baseObject, this.baseOffset, row.sizeInBytes); + row.baseObject, row.baseOffset, + this.baseObject, this.baseOffset, + row.sizeInBytes ); // update the sizeInBytes. this.sizeInBytes = row.sizeInBytes; } @@ -548,9 +537,9 @@ public void copyFrom(UnsafeRow row) { * buffer will not be used and may be null. */ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException { - if (baseObject instanceof byte[]) { - int offsetInByteArray = (int) (Platform.BYTE_ARRAY_OFFSET - baseOffset); - out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes); + if (baseObject instanceof ByteArrayMemoryBlock) { + int startIndex = (int) (baseOffset - baseObject.getBaseOffset()); + out.write((byte[])baseObject.getBaseObject(), startIndex, sizeInBytes); } else { int dataRemaining = sizeInBytes; long rowReadPosition = baseOffset; @@ -585,9 +574,11 @@ public boolean equals(Object other) { * Returns the underlying bytes for this UnsafeRow. 
*/ public byte[] getBytes() { - if (baseObject instanceof byte[] && baseOffset == Platform.BYTE_ARRAY_OFFSET - && (((byte[]) baseObject).length == sizeInBytes)) { - return (byte[]) baseObject; + if ( baseObject instanceof ByteArrayMemoryBlock + && baseOffset == Platform.BYTE_ARRAY_OFFSET + && (((byte[]) baseObject.getBaseObject()).length == sizeInBytes) ) + { + return (byte[]) baseObject.getBaseObject(); } else { byte[] bytes = new byte[sizeInBytes]; Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); @@ -617,7 +608,11 @@ public boolean anyNull() { * The target memory address must already been allocated, and have enough space to hold all the * bytes in this string. */ - public void writeToMemory(Object target, long targetOffset) { + public void writeToMemory(byte[] target, long targetOffset) { + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); + } + + public void writeToMemory(MemoryBlock target, long targetOffset) { Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } @@ -665,8 +660,8 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept this.sizeInBytes = in.readInt(); this.numFields = in.readInt(); this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); - this.baseObject = new byte[sizeInBytes]; - in.readFully((byte[]) baseObject); + this.baseObject = ByteArrayMemoryBlock.fromByteArray(new byte[sizeInBytes]); + in.readFully((byte[]) baseObject.getBaseObject()); } @Override @@ -683,7 +678,7 @@ public void read(Kryo kryo, Input in) { this.sizeInBytes = in.readInt(); this.numFields = in.readInt(); this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); - this.baseObject = new byte[sizeInBytes]; - in.read((byte[]) baseObject); + this.baseObject = ByteArrayMemoryBlock.fromByteArray(new byte[sizeInBytes]); + in.read((byte[]) baseObject.getBaseObject()); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java deleted file mode 100644 index af61e2011f400..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.codegen; - -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.unsafe.Platform; - -/** - * A helper class to manage the data buffer for an unsafe row. The data buffer can grow and - * automatically re-point the unsafe row to it. - * - * This class can be used to build a one-pass unsafe row writing program, i.e. 
data will be written - * to the data buffer directly and no extra copy is needed. There should be only one instance of - * this class per writing program, so that the memory segment/data buffer can be reused. Note that - * for each incoming record, we should call `reset` of BufferHolder instance before write the record - * and reuse the data buffer. - * - * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update - * the size of the result row, after writing a record to the buffer. However, we can skip this step - * if the fields of row are all fixed-length, as the size of result row is also fixed. - */ -public class BufferHolder { - public byte[] buffer; - public int cursor = Platform.BYTE_ARRAY_OFFSET; - private final UnsafeRow row; - private final int fixedSize; - - public BufferHolder(UnsafeRow row) { - this(row, 64); - } - - public BufferHolder(UnsafeRow row, int initialSize) { - this.fixedSize = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()) + 8 * row.numFields(); - this.buffer = new byte[fixedSize + initialSize]; - this.row = row; - this.row.pointTo(buffer, buffer.length); - } - - /** - * Grows the buffer by at least neededSize and points the row to the buffer. - */ - public void grow(int neededSize) { - final int length = totalSize() + neededSize; - if (buffer.length < length) { - // This will not happen frequently, because the buffer is re-used. - final byte[] tmp = new byte[length * 2]; - Platform.copyMemory( - buffer, - Platform.BYTE_ARRAY_OFFSET, - tmp, - Platform.BYTE_ARRAY_OFFSET, - totalSize()); - buffer = tmp; - row.pointTo(buffer, buffer.length); - } - } - - public void reset() { - cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; - } - - public int totalSize() { - return cursor - Platform.BYTE_ARRAY_OFFSET; - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 7dd932d1981b7..70e3b9f8cc716 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -28,12 +28,12 @@ */ public class UnsafeArrayWriter { - private BufferHolder holder; + private MemoryBlockHolder holder; // The offset of the global buffer where we start to write this array. private int startingOffset; - public void initialize(BufferHolder holder, int numElements, int fixedElementSize) { + public void initialize(MemoryBlockHolder holder, int numElements, int fixedElementSize) { // We need 4 bytes to store numElements and 4 bytes each element to store offset. final int fixedSize = 4 + 4 * numElements; @@ -41,7 +41,7 @@ public void initialize(BufferHolder holder, int numElements, int fixedElementSiz this.startingOffset = holder.cursor; holder.grow(fixedSize); - Platform.putInt(holder.buffer, holder.cursor, numElements); + Platform.putInt(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, numElements); holder.cursor += fixedSize; // Grows the global buffer ahead for fixed size data. @@ -55,40 +55,40 @@ private long getElementOffset(int ordinal) { public void setNullAt(int ordinal) { final int relativeOffset = holder.cursor - startingOffset; // Writes negative offset value to represent null element. 
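// Recurring rewrite in these writer classes: every Platform.put*/copyMemory call now targets
// holder.getBaseObject() and adds holder.getBaseOffset() to the holder's cursor or field offset,
// instead of writing straight into the old holder.buffer byte array.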
- Platform.putInt(holder.buffer, getElementOffset(ordinal), -relativeOffset); + Platform.putInt(holder.getBaseObject(), holder.getBaseOffset() + getElementOffset(ordinal), -relativeOffset); } public void setOffset(int ordinal) { final int relativeOffset = holder.cursor - startingOffset; - Platform.putInt(holder.buffer, getElementOffset(ordinal), relativeOffset); + Platform.putInt(holder.getBaseObject(), holder.getBaseOffset() + getElementOffset(ordinal), relativeOffset); } public void write(int ordinal, boolean value) { - Platform.putBoolean(holder.buffer, holder.cursor, value); + Platform.putBoolean(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, value); setOffset(ordinal); holder.cursor += 1; } public void write(int ordinal, byte value) { - Platform.putByte(holder.buffer, holder.cursor, value); + Platform.putByte(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, value); setOffset(ordinal); holder.cursor += 1; } public void write(int ordinal, short value) { - Platform.putShort(holder.buffer, holder.cursor, value); + Platform.putShort(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, value); setOffset(ordinal); holder.cursor += 2; } public void write(int ordinal, int value) { - Platform.putInt(holder.buffer, holder.cursor, value); + Platform.putInt(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, value); setOffset(ordinal); holder.cursor += 4; } public void write(int ordinal, long value) { - Platform.putLong(holder.buffer, holder.cursor, value); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, value); setOffset(ordinal); holder.cursor += 8; } @@ -97,7 +97,7 @@ public void write(int ordinal, float value) { if (Float.isNaN(value)) { value = Float.NaN; } - Platform.putFloat(holder.buffer, holder.cursor, value); + Platform.putFloat(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, value); setOffset(ordinal); holder.cursor += 4; } @@ -106,7 +106,7 @@ public void write(int ordinal, double value) { if (Double.isNaN(value)) { value = Double.NaN; } - Platform.putDouble(holder.buffer, holder.cursor, value); + Platform.putDouble(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, value); setOffset(ordinal); holder.cursor += 8; } @@ -115,7 +115,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // make sure Decimal object has the same scale as DecimalType if (input.changePrecision(precision, scale)) { if (precision <= Decimal.MAX_LONG_DIGITS()) { - Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong()); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, input.toUnscaledLong()); setOffset(ordinal); holder.cursor += 8; } else { @@ -125,7 +125,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); + bytes, Platform.BYTE_ARRAY_OFFSET, holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, bytes.length); setOffset(ordinal); holder.cursor += bytes.length; } @@ -141,7 +141,7 @@ public void write(int ordinal, UTF8String input) { holder.grow(numBytes); // Write the bytes to the variable length portion. 
- input.writeToMemory(holder.buffer, holder.cursor); + input.writeToMemory(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor); setOffset(ordinal); @@ -155,7 +155,7 @@ public void write(int ordinal, byte[] input) { // Write the bytes to the variable length portion. Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, input.length); + input, Platform.BYTE_ARRAY_OFFSET, holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, input.length); setOffset(ordinal); @@ -168,8 +168,8 @@ public void write(int ordinal, CalendarInterval input) { holder.grow(16); // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, input.months); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor + 8, input.microseconds); setOffset(ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 4776617043878..0cf60e6f62903 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -40,13 +40,13 @@ */ public class UnsafeRowWriter { - private final BufferHolder holder; + private final MemoryBlockHolder holder; // The offset of the global buffer where we start to write this row. private int startingOffset; private final int nullBitsSize; private final int fixedSize; - public UnsafeRowWriter(BufferHolder holder, int numFields) { + public UnsafeRowWriter(MemoryBlockHolder holder, int numFields) { this.holder = holder; this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields); this.fixedSize = nullBitsSize + 8 * numFields; @@ -72,25 +72,25 @@ public void reset() { */ public void zeroOutNullBytes() { for (int i = 0; i < nullBitsSize; i += 8) { - Platform.putLong(holder.buffer, startingOffset + i, 0L); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + startingOffset + i, 0L); } } private void zeroOutPaddingBytes(int numBytes) { if ((numBytes & 0x07) > 0) { - Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor + ((numBytes >> 3) << 3), 0L); } } - public BufferHolder holder() { return holder; } + public MemoryBlockHolder holder() { return holder; } public boolean isNullAt(int ordinal) { - return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); + return BitSetMethods.isSet(holder.getBaseObject(), holder.getBaseOffset() + startingOffset, ordinal); } public void setNullAt(int ordinal) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); - Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L); + BitSetMethods.set(holder.getBaseObject(), holder.getBaseOffset() + startingOffset, ordinal); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + getFieldOffset(ordinal), 0L); } public long getFieldOffset(int ordinal) { @@ -106,7 +106,7 @@ public void setOffsetAndSize(int ordinal, long currentCursor, long size) { final long fieldOffset = getFieldOffset(ordinal); final long offsetAndSize = (relativeOffset << 32) | 
size; - Platform.putLong(holder.buffer, fieldOffset, offsetAndSize); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + fieldOffset, offsetAndSize); } // Do word alignment for this row and grow the row buffer if needed. @@ -119,7 +119,7 @@ public void alignToWords(int numBytes) { holder.grow(paddingBytes); for (int i = 0; i < paddingBytes; i++) { - Platform.putByte(holder.buffer, holder.cursor, (byte) 0); + Platform.putByte(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, (byte) 0); holder.cursor++; } } @@ -127,30 +127,30 @@ public void alignToWords(int numBytes) { public void write(int ordinal, boolean value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putBoolean(holder.buffer, offset, value); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + offset, 0L); + Platform.putBoolean(holder.getBaseObject(), holder.getBaseOffset() + offset, value); } public void write(int ordinal, byte value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putByte(holder.buffer, offset, value); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + offset, 0L); + Platform.putByte(holder.getBaseObject(), holder.getBaseOffset() + offset, value); } public void write(int ordinal, short value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putShort(holder.buffer, offset, value); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + offset, 0L); + Platform.putShort(holder.getBaseObject(), holder.getBaseOffset() + offset, value); } public void write(int ordinal, int value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putInt(holder.buffer, offset, value); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + offset, 0L); + Platform.putInt(holder.getBaseObject(), holder.getBaseOffset() + offset, value); } public void write(int ordinal, long value) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), value); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + getFieldOffset(ordinal), value); } public void write(int ordinal, float value) { @@ -158,22 +158,22 @@ public void write(int ordinal, float value) { value = Float.NaN; } final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putFloat(holder.buffer, offset, value); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + offset, 0L); + Platform.putFloat(holder.getBaseObject(), holder.getBaseOffset() + offset, value); } public void write(int ordinal, double value) { if (Double.isNaN(value)) { value = Double.NaN; } - Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value); + Platform.putDouble(holder.getBaseObject(), holder.getBaseOffset() + getFieldOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { if (precision <= Decimal.MAX_LONG_DIGITS()) { // make sure Decimal object has the same scale as DecimalType if (input.changePrecision(precision, scale)) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong()); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + getFieldOffset(ordinal), input.toUnscaledLong()); } else { setNullAt(ordinal); } @@ -182,13 +182,13 @@ public void write(int ordinal, Decimal input, int precision, int scale) { holder.grow(16); // zero-out the 
bytes - Platform.putLong(holder.buffer, holder.cursor, 0L); - Platform.putLong(holder.buffer, holder.cursor + 8, 0L); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, 0L); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor + 8, 0L); // Make sure Decimal object has the same scale as DecimalType. // Note that we may pass in null Decimal object to set null for it. if (input == null || !input.changePrecision(precision, scale)) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); + BitSetMethods.set(holder.getBaseObject(), holder.getBaseOffset() + startingOffset, ordinal); // keep the offset for future update setOffsetAndSize(ordinal, 0L); } else { @@ -197,7 +197,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); + bytes, Platform.BYTE_ARRAY_OFFSET, holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, bytes.length); setOffsetAndSize(ordinal, bytes.length); } @@ -216,7 +216,7 @@ public void write(int ordinal, UTF8String input) { zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); + input.writeToMemory(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor); setOffsetAndSize(ordinal, numBytes); @@ -238,7 +238,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { // Write the bytes to the variable length portion. Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset, - holder.buffer, holder.cursor, numBytes); + holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, numBytes); setOffsetAndSize(ordinal, numBytes); @@ -251,8 +251,8 @@ public void write(int ordinal, CalendarInterval input) { holder.grow(16); // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor, input.months); + Platform.putLong(holder.getBaseObject(), holder.getBaseOffset() + holder.cursor + 8, input.microseconds); setOffsetAndSize(ordinal, 16); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 0ad0f4976c77a..eb86cf36bc999 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util.AbstractScalaRowIterator; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; import org.apache.spark.util.collection.unsafe.sort.RecordComparator; import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; @@ -179,7 +180,7 @@ public RowComparator(Ordering ordering, int numFields) { } @Override - public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + public int compare(MemoryBlock baseObj1, long baseOff1, MemoryBlock baseObj2, long baseOff2) { // TODO: Why are the sizes -1? 
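(Aside, not part of the patch: UnsafeRowWriter.setOffsetAndSize above packs a field's relative offset and its byte length into one 8-byte word as (relativeOffset << 32) | size, and readers later split the word back apart. A small worked example of that encoding, plain Java with illustrative values only:)

    class OffsetAndSizeSketch {
      public static void main(String[] args) {
        long relativeOffset = 48;   // bytes from the start of the row
        long size = 20;             // length of the variable-length field
        long word = (relativeOffset << 32) | size;

        int decodedOffset = (int) (word >> 32);  // high 32 bits
        int decodedSize = (int) word;            // low 32 bits
        System.out.println(decodedOffset + " " + decodedSize);  // prints 48 20
      }
    }
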
row1.pointTo(baseObj1, baseOff1, -1); row2.pointTo(baseObj2, baseOff2, -1); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 6aa9cbf08bdb9..ac9411499d9e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -270,7 +270,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)} // Write the numBytes of key array into the first 4 bytes. - Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor); + Platform.putInt($bufferHolder.getBaseObject(), + $bufferHolder.getBaseOffset() + $tmpCursor - 4, + $bufferHolder.cursor - $tmpCursor); ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)} } @@ -287,7 +289,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro final int $sizeInBytes = $input.getSizeInBytes(); // grow the global buffer before writing data. $bufferHolder.grow($sizeInBytes); - $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor); + $input.writeToMemory($bufferHolder.getBaseObject(), + $bufferHolder.getBaseOffset() + $bufferHolder.cursor); $bufferHolder.cursor += $sizeInBytes; """ } @@ -309,7 +312,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});") val holder = ctx.freshName("holder") - val holderClass = classOf[BufferHolder].getName + val holderClass = classOf[MemoryBlockHolder].getName ctx.addMutableState(holderClass, holder, s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index b1ffbaa3e94ec..10284310ddc89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -163,7 +163,8 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U |} | |class SpecificUnsafeRowJoiner extends ${classOf[UnsafeRowJoiner].getName} { - | private byte[] buf = new byte[64]; + | org.apache.spark.unsafe.memory.ByteArrayMemoryBlock + | buf = org.apache.spark.unsafe.memory.ByteArrayMemoryBlock.fromByteArray(new byte[64]); | private UnsafeRow out = new UnsafeRow(${schema1.size + schema2.size}); | | public UnsafeRow join(UnsafeRow row1, UnsafeRow row2) { @@ -171,13 +172,14 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | // row2: ${schema2.size}, $bitset2Words words in bitset | // output: ${schema1.size + schema2.size} fields, $outputBitsetWords words in bitset | final int sizeInBytes = row1.getSizeInBytes() + row2.getSizeInBytes() - $sizeReduction; - | if (sizeInBytes > buf.length) { - | buf = new byte[sizeInBytes]; + | if (sizeInBytes > buf.size()) { + | buf = org.apache.spark.unsafe.memory.ByteArrayMemoryBlock.fromByteArray( + | new byte[sizeInBytes] ); | } | - | 
final java.lang.Object obj1 = row1.getBaseObject(); + | final org.apache.spark.unsafe.memory.MemoryBlock obj1 = row1.getBaseObject(); | final long offset1 = row1.getBaseOffset(); - | final java.lang.Object obj2 = row2.getBaseObject(); + | final org.apache.spark.unsafe.memory.MemoryBlock obj2 = row2.getBaseObject(); | final long offset2 = row2.getBaseOffset(); | | $copyBitset @@ -187,7 +189,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | $copyVariableLengthRow2 | $updateOffset | - | out.pointTo(buf, sizeInBytes); + | out.pointTo(buf, buf.getBaseOffset(), sizeInBytes); | | return out; | } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala index f8342214d9ae0..3e9774b8ae0c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock /** * A test suite for the bitset portion of the row concatenation. @@ -95,7 +96,7 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { // Allocate a larger buffer than needed and point the UnsafeRow to somewhere in the middle. // This way we can test the joiner when the input UnsafeRows are not the entire arrays. val offset = numFields * 8 - val buf = new Array[Byte](sizeInBytes + offset) + val buf = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](sizeInBytes + offset)) row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, sizeInBytes) row } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 57e8218f3b93a..17ef5f969352c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -29,6 +29,8 @@ import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.MemoryLocation; /** * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. @@ -41,7 +43,7 @@ public final class UnsafeFixedWidthAggregationMap { * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the * map, we copy this buffer and use it as the value. 
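(Aside, not part of the patch: the generated row joiner above replaces its byte[] buf with a ByteArrayMemoryBlock, so the capacity check moves from buf.length to buf.size() and growing the buffer means wrapping a fresh array. A minimal sketch of that pattern; the class and method names below are illustrative only:)

    import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;

    class GrowableJoinBufferSketch {
      private ByteArrayMemoryBlock buf =
          ByteArrayMemoryBlock.fromByteArray(new byte[64]);

      // Mirrors the generated joiner's growth check: byte[].length becomes size()
      // on the block, and regrowing means wrapping a freshly allocated array.
      ByteArrayMemoryBlock ensureCapacity(int sizeInBytes) {
        if (sizeInBytes > buf.size()) {
          buf = ByteArrayMemoryBlock.fromByteArray(new byte[sizeInBytes]);
        }
        return buf;
      }
    }
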
*/ - private final byte[] emptyAggregationBuffer; + private final MemoryBlock emptyAggregationBuffer; private final StructType aggregationBufferSchema; @@ -106,7 +108,7 @@ public UnsafeFixedWidthAggregationMap( // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); - this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); + this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBaseObject(); } /** @@ -139,8 +141,8 @@ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) { key.getBaseOffset(), key.getSizeInBytes(), emptyAggregationBuffer, - Platform.BYTE_ARRAY_OFFSET, - emptyAggregationBuffer.length + emptyAggregationBuffer.getBaseOffset(), + (int)emptyAggregationBuffer.size() ); if (!putSucceeded) { return null; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 51e10b0e936b9..3c54af6d1880b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -97,7 +97,7 @@ public UnsafeKVExternalSorter( UnsafeRow row = new UnsafeRow(numKeyFields); while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); - final Object baseObject = loc.getKeyBase(); + final MemoryBlock baseObject = loc.getKeyBase(); final long baseOffset = loc.getKeyOffset(); // Get encoded memory address @@ -206,7 +206,7 @@ public KVComparator(BaseOrdering ordering, int numKeyFields) { } @Override - public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + public int compare(MemoryBlock baseObj1, long baseOff1, MemoryBlock baseObj2, long baseOff2) { // Note that since ordering doesn't need the total length of the record, we just pass -1 // into the row. 
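(Aside, not part of the patch: with emptyAggregationBuffer held as a MemoryBlock, consumers such as putNewKey receive the usual (baseObject, baseOffset, size) triple straight from the block. A minimal sketch of a consumer built on that triple, using only the patch's ByteArrayMemoryBlock and the stock Platform.copyMemory; copyOut is an illustrative helper, not an API from this patch:)

    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;

    class CopyOutSketch {
      // Copies the contents of a block into a fresh byte[], driven by the same
      // (baseObject, baseOffset, size) triple that putNewKey now receives.
      static byte[] copyOut(ByteArrayMemoryBlock block) {
        byte[] out = new byte[(int) block.size()];
        Platform.copyMemory(
            block.getBaseObject(), block.getBaseOffset(),
            out, Platform.BYTE_ARRAY_OFFSET,
            block.size());
        return out;
      }
    }
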
row1.pointTo(baseObj1, baseOff1 + 4, -1); @@ -230,7 +230,7 @@ public boolean next() throws IOException { if (underlying.hasNext()) { underlying.loadNext(); - Object baseObj = underlying.getBaseObject(); + MemoryBlock baseObj = underlying.getBaseObject(); long recordOffset = underlying.getBaseOffset(); int recordLen = underlying.getRecordLength(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 57dbd7c2ff56f..df4fbbc213808 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -39,7 +39,7 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; +import org.apache.spark.sql.catalyst.expressions.codegen.MemoryBlockHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.execution.vectorized.ColumnarBatch; @@ -281,7 +281,7 @@ private void initializeInternal() throws IOException { for (int i = 0; i < rows.length; ++i) { rows[i] = new UnsafeRow(requestedSchema.getFieldCount()); - BufferHolder holder = new BufferHolder(rows[i], numVarLenFields * DEFAULT_VAR_LEN_SIZE); + MemoryBlockHolder holder = new MemoryBlockHolder(rows[i], numVarLenFields * DEFAULT_VAR_LEN_SIZE); rowWriters[i] = new UnsafeRowWriter(holder, requestedSchema.getFieldCount()); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index b06b7f2457b54..fb07c1003aa29 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -23,6 +23,8 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; /** * Column data backed using offheap memory. @@ -30,22 +32,22 @@ public final class OffHeapColumnVector extends ColumnVector { // The data stored in these two allocations need to maintain binary compatible. We can // directly pass this buffer to external components. - private long nulls; - private long data; + private OffHeapMemoryBlock nulls; + private OffHeapMemoryBlock data; // Set iff the type is array. 
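(Aside, not part of the patch: OffHeapColumnVector now stores OffHeapMemoryBlock handles instead of raw long addresses. A minimal allocate/write/read/free sketch using the existing MemoryAllocator.UNSAFE API; it assumes the block type exposes getBaseObject()/getBaseOffset() as the concrete classes in this patch do, and that the base object is null for off-heap blocks:)

    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.unsafe.memory.MemoryAllocator;
    import org.apache.spark.unsafe.memory.MemoryBlock;

    class OffHeapSketch {
      public static void main(String[] args) {
        // Allocate room for 4 ints off-heap; getBaseObject() is null for off-heap
        // blocks, so the base offset is the absolute native address.
        MemoryBlock block = MemoryAllocator.UNSAFE.allocate(4 * 4);
        Platform.putInt(block.getBaseObject(), block.getBaseOffset() + 4 * 2, 7);
        int v = Platform.getInt(block.getBaseObject(), block.getBaseOffset() + 4 * 2);
        System.out.println(v);  // prints 7
        MemoryAllocator.UNSAFE.free(block);  // off-heap memory is not GC-managed
      }
    }
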
- private long lengthData; - private long offsetData; + private OffHeapMemoryBlock lengthData; + private OffHeapMemoryBlock offsetData; protected OffHeapColumnVector(int capacity, DataType type) { super(capacity, type, MemoryMode.OFF_HEAP); if (!ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN)) { throw new NotImplementedException("Only little endian is supported."); } - nulls = 0; - data = 0; - lengthData = 0; - offsetData = 0; + nulls = new OffHeapMemoryBlock(null, 0, 0); + data = new OffHeapMemoryBlock(null, 0, 0); + lengthData = new OffHeapMemoryBlock(null, 0, 0); + offsetData = new OffHeapMemoryBlock(null, 0, 0); reserveInternal(capacity); reset(); @@ -53,24 +55,20 @@ protected OffHeapColumnVector(int capacity, DataType type) { @Override public final long valuesNativeAddress() { - return data; + return data.getBaseOffset(); } @Override public long nullsNativeAddress() { - return nulls; + return nulls.getBaseOffset(); } @Override public final void close() { - Platform.freeMemory(nulls); - Platform.freeMemory(data); - Platform.freeMemory(lengthData); - Platform.freeMemory(offsetData); - nulls = 0; - data = 0; - lengthData = 0; - offsetData = 0; + nulls = new OffHeapMemoryBlock(null, 0, 0); + data = new OffHeapMemoryBlock(null, 0, 0); + lengthData = new OffHeapMemoryBlock(null, 0, 0); + offsetData = new OffHeapMemoryBlock(null, 0, 0); } // @@ -79,21 +77,21 @@ public final void close() { @Override public final void putNotNull(int rowId) { - Platform.putByte(null, nulls + rowId, (byte) 0); + Platform.putByte(nulls, nulls.getBaseOffset() + rowId, (byte) 0); } @Override public final void putNull(int rowId) { - Platform.putByte(null, nulls + rowId, (byte) 1); + Platform.putByte(nulls, nulls.getBaseOffset() + rowId, (byte) 1); ++numNulls; anyNullsSet = true; } @Override public final void putNulls(int rowId, int count) { - long offset = nulls + rowId; + long offset = nulls.getBaseOffset() + rowId; for (int i = 0; i < count; ++i, ++offset) { - Platform.putByte(null, offset, (byte) 1); + Platform.putByte(nulls, offset, (byte) 1); } anyNullsSet = true; numNulls += count; @@ -102,15 +100,15 @@ public final void putNulls(int rowId, int count) { @Override public final void putNotNulls(int rowId, int count) { if (!anyNullsSet) return; - long offset = nulls + rowId; + long offset = nulls.getBaseOffset() + rowId; for (int i = 0; i < count; ++i, ++offset) { - Platform.putByte(null, offset, (byte) 0); + Platform.putByte(nulls, offset, (byte) 0); } } @Override public final boolean getIsNull(int rowId) { - return Platform.getByte(null, nulls + rowId) == 1; + return Platform.getByte(nulls, nulls.getBaseOffset() + rowId) == 1; } // @@ -119,19 +117,21 @@ public final boolean getIsNull(int rowId) { @Override public final void putBoolean(int rowId, boolean value) { - Platform.putByte(null, data + rowId, (byte)((value) ? 1 : 0)); + Platform.putByte(data, data.getBaseOffset() + rowId, (byte)((value) ? 1 : 0)); } @Override public final void putBooleans(int rowId, int count, boolean value) { byte v = (byte)((value) ? 
1 : 0); for (int i = 0; i < count; ++i) { - Platform.putByte(null, data + rowId + i, v); + Platform.putByte(data, data.getBaseOffset() + rowId + i, v); } } @Override - public final boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; } + public final boolean getBoolean(int rowId) { + return Platform.getByte(data, data.getBaseOffset() + rowId) == 1; + } // // APIs dealing with Bytes @@ -139,26 +139,26 @@ public final void putBooleans(int rowId, int count, boolean value) { @Override public final void putByte(int rowId, byte value) { - Platform.putByte(null, data + rowId, value); + Platform.putByte(data, data.getBaseOffset() + rowId, value); } @Override public final void putBytes(int rowId, int count, byte value) { for (int i = 0; i < count; ++i) { - Platform.putByte(null, data + rowId + i, value); + Platform.putByte(data, data.getBaseOffset() + rowId + i, value); } } @Override public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { - Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, null, data + rowId, count); + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, data, data.getBaseOffset() + rowId, count); } @Override public final byte getByte(int rowId) { if (dictionary == null) { - return Platform.getByte(null, data + rowId); + return Platform.getByte(data, data.getBaseOffset() + rowId); } else { return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); } @@ -170,27 +170,27 @@ public final byte getByte(int rowId) { @Override public final void putShort(int rowId, short value) { - Platform.putShort(null, data + 2 * rowId, value); + Platform.putShort(data, data.getBaseOffset() + 2 * rowId, value); } @Override public final void putShorts(int rowId, int count, short value) { - long offset = data + 2 * rowId; + long offset = data.getBaseOffset() + 2 * rowId; for (int i = 0; i < count; ++i, offset += 4) { - Platform.putShort(null, offset, value); + Platform.putShort(data, offset, value); } } @Override public final void putShorts(int rowId, int count, short[] src, int srcIndex) { Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2, - null, data + 2 * rowId, count * 2); + data, data.getBaseOffset() + 2 * rowId, count * 2); } @Override public final short getShort(int rowId) { if (dictionary == null) { - return Platform.getShort(null, data + 2 * rowId); + return Platform.getShort(data, data.getBaseOffset() + 2 * rowId); } else { return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); } @@ -202,33 +202,33 @@ public final short getShort(int rowId) { @Override public final void putInt(int rowId, int value) { - Platform.putInt(null, data + 4 * rowId, value); + Platform.putInt(data, data.getBaseOffset() + 4 * rowId, value); } @Override public final void putInts(int rowId, int count, int value) { - long offset = data + 4 * rowId; + long offset = data.getBaseOffset() + 4 * rowId; for (int i = 0; i < count; ++i, offset += 4) { - Platform.putInt(null, offset, value); + Platform.putInt(data, offset, value); } } @Override public final void putInts(int rowId, int count, int[] src, int srcIndex) { Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + data, data.getBaseOffset() + 4 * rowId, count * 4); } @Override public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 4 * rowId, count * 4); + data, data.getBaseOffset() + 4 * rowId, count * 
4); } @Override public final int getInt(int rowId) { if (dictionary == null) { - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(data, data.getBaseOffset() + 4 * rowId); } else { return dictionary.decodeToInt(dictionaryIds.getInt(rowId)); } @@ -240,33 +240,33 @@ public final int getInt(int rowId) { @Override public final void putLong(int rowId, long value) { - Platform.putLong(null, data + 8 * rowId, value); + Platform.putLong(data, data.getBaseOffset() + 8 * rowId, value); } @Override public final void putLongs(int rowId, int count, long value) { - long offset = data + 8 * rowId; + long offset = data.getBaseOffset() + 8 * rowId; for (int i = 0; i < count; ++i, offset += 8) { - Platform.putLong(null, offset, value); + Platform.putLong(data, offset, value); } } @Override public final void putLongs(int rowId, int count, long[] src, int srcIndex) { Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + data, data.getBaseOffset() + 8 * rowId, count * 8); } @Override public final void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 8 * rowId, count * 8); + data, data.getBaseOffset() + 8 * rowId, count * 8); } @Override public final long getLong(int rowId) { if (dictionary == null) { - return Platform.getLong(null, data + 8 * rowId); + return Platform.getLong(data, data.getBaseOffset() + 8 * rowId); } else { return dictionary.decodeToLong(dictionaryIds.getInt(rowId)); } @@ -278,33 +278,33 @@ public final long getLong(int rowId) { @Override public final void putFloat(int rowId, float value) { - Platform.putFloat(null, data + rowId * 4, value); + Platform.putFloat(data, data.getBaseOffset() + rowId * 4, value); } @Override public final void putFloats(int rowId, int count, float value) { - long offset = data + 4 * rowId; + long offset = data.getBaseOffset() + 4 * rowId; for (int i = 0; i < count; ++i, offset += 4) { - Platform.putFloat(null, offset, value); + Platform.putFloat(data, offset, value); } } @Override public final void putFloats(int rowId, int count, float[] src, int srcIndex) { Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + data, data.getBaseOffset() + 4 * rowId, count * 4); } @Override public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + data, data.getBaseOffset() + rowId * 4, count * 4); } @Override public final float getFloat(int rowId) { if (dictionary == null) { - return Platform.getFloat(null, data + rowId * 4); + return Platform.getFloat(data, data.getBaseOffset() + rowId * 4); } else { return dictionary.decodeToFloat(dictionaryIds.getInt(rowId)); } @@ -317,33 +317,33 @@ public final float getFloat(int rowId) { @Override public final void putDouble(int rowId, double value) { - Platform.putDouble(null, data + rowId * 8, value); + Platform.putDouble(data, data.getBaseOffset() + rowId * 8, value); } @Override public final void putDoubles(int rowId, int count, double value) { - long offset = data + 8 * rowId; + long offset = data.getBaseOffset() + 8 * rowId; for (int i = 0; i < count; ++i, offset += 8) { - Platform.putDouble(null, offset, value); + Platform.putDouble(data, offset, value); } } @Override public final void putDoubles(int rowId, int count, double[] src, int srcIndex) { Platform.copyMemory(src, 
Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + data, data.getBaseOffset() + 8 * rowId, count * 8); } @Override public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + data, data.getBaseOffset() + rowId * 8, count * 8); } @Override public final double getDouble(int rowId) { if (dictionary == null) { - return Platform.getDouble(null, data + rowId * 8); + return Platform.getDouble(data, data.getBaseOffset() + rowId * 8); } else { return dictionary.decodeToDouble(dictionaryIds.getInt(rowId)); } @@ -355,26 +355,26 @@ public final double getDouble(int rowId) { @Override public final void putArray(int rowId, int offset, int length) { assert(offset >= 0 && offset + length <= childColumns[0].capacity); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, offset); + Platform.putInt(lengthData, lengthData.getBaseOffset() + 4 * rowId, length); + Platform.putInt(offsetData, offsetData.getBaseOffset() + 4 * rowId, offset); } @Override public final int getArrayLength(int rowId) { - return Platform.getInt(null, lengthData + 4 * rowId); + return Platform.getInt(lengthData, lengthData.getBaseOffset() + 4 * rowId); } @Override public final int getArrayOffset(int rowId) { - return Platform.getInt(null, offsetData + 4 * rowId); + return Platform.getInt(offsetData, offsetData.getBaseOffset() + 4 * rowId); } // APIs dealing with ByteArrays @Override public final int putByteArray(int rowId, byte[] value, int offset, int length) { int result = arrayData().appendBytes(length, value, offset); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, result); + Platform.putInt(lengthData, lengthData.getBaseOffset() + 4 * rowId, length); + Platform.putInt(offsetData, offsetData.getBaseOffset() + 4 * rowId, result); return result; } @@ -382,7 +382,7 @@ public final int putByteArray(int rowId, byte[] value, int offset, int length) { public final void loadBytes(ColumnVector.Array array) { if (array.tmpByteArray.length < array.length) array.tmpByteArray = new byte[array.length]; Platform.copyMemory( - null, data + array.offset, array.tmpByteArray, Platform.BYTE_ARRAY_OFFSET, array.length); + data, data.getBaseOffset() + array.offset, array.tmpByteArray, Platform.BYTE_ARRAY_OFFSET, array.length); array.byteArray = array.tmpByteArray; array.byteArrayOffset = 0; } @@ -395,27 +395,25 @@ public final void reserve(int requiredCapacity) { // Split out the slow path. 
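(Aside, not part of the patch: reserveInternal below grows its buffers through MemoryAllocator.UNSAFE.reallocate(old, oldSize, newSize), which is not part of the pre-existing allocator interface and is presumably added elsewhere in this change. A plausible implementation, sketched under that assumption only, allocates a new block, copies the bytes in use, and frees the old block:)

    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.unsafe.memory.MemoryAllocator;
    import org.apache.spark.unsafe.memory.MemoryBlock;

    class ReallocateSketch {
      // Hypothetical stand-in for the reallocate(old, oldSize, newSize) call used below.
      static MemoryBlock reallocate(MemoryBlock old, long usedBytes, long newSize) {
        MemoryBlock fresh = MemoryAllocator.UNSAFE.allocate(newSize);
        if (usedBytes > 0) {
          Platform.copyMemory(
              old.getBaseObject(), old.getBaseOffset(),
              fresh.getBaseObject(), fresh.getBaseOffset(),
              usedBytes);
        }
        if (old.size() > 0) {
          MemoryAllocator.UNSAFE.free(old);  // only free blocks that were actually allocated
        }
        return fresh;
      }
    }
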
private final void reserveInternal(int newCapacity) { if (this.resultArray != null) { - this.lengthData = - Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4); - this.offsetData = - Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4); + this.lengthData = MemoryAllocator.UNSAFE.reallocate(this.lengthData, elementsAppended * 4, newCapacity * 4); + this.offsetData = MemoryAllocator.UNSAFE.reallocate(this.offsetData, elementsAppended * 4, newCapacity * 4); } else if (type instanceof ByteType || type instanceof BooleanType) { - this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity); + this.data = MemoryAllocator.UNSAFE.reallocate(this.data, elementsAppended, newCapacity); } else if (type instanceof ShortType) { - this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); + this.data = MemoryAllocator.UNSAFE.reallocate(this.data, elementsAppended * 2, newCapacity * 2); } else if (type instanceof IntegerType || type instanceof FloatType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); + this.data = MemoryAllocator.UNSAFE.reallocate(this.data, elementsAppended * 4, newCapacity * 4); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8); + this.data = MemoryAllocator.UNSAFE.reallocate(this.data, elementsAppended * 8, newCapacity * 8); } else if (resultStruct != null) { // Nothing to store. } else { throw new RuntimeException("Unhandled " + type); } - this.nulls = Platform.reallocateMemory(nulls, elementsAppended, newCapacity); - Platform.setMemory(nulls + elementsAppended, (byte)0, newCapacity - elementsAppended); + this.nulls = MemoryAllocator.UNSAFE.reallocate(this.nulls, elementsAppended, newCapacity); + Platform.setMemory(nulls.getBaseOffset() + elementsAppended, (byte)0, newCapacity - elementsAppended); capacity = newCapacity; } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index a23ebec95333b..8e2f0cc99a690 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -27,6 +27,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock /** * Serializer for serializing [[UnsafeRow]]s during shuffle. 
Since UnsafeRows are already stored as @@ -93,7 +94,9 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst new DeserializationStream { private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in)) // 1024 is a default buffer size; this buffer will grow to accommodate larger rows - private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) +// private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) + private[this] var rowBuffer: ByteArrayMemoryBlock = ByteArrayMemoryBlock.fromByteArray( + new Array[Byte](1024)) private[this] var row: UnsafeRow = new UnsafeRow(numFields) private[this] var rowTuple: (Int, UnsafeRow) = (0, row) private[this] val EOF: Int = -1 @@ -113,10 +116,10 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def hasNext: Boolean = rowSize != EOF override def next(): (Int, UnsafeRow) = { - if (rowBuffer.length < rowSize) { - rowBuffer = new Array[Byte](rowSize) + if (rowBuffer.size() < rowSize) { + rowBuffer = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](rowSize)) } - ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) + ByteStreams.readFully(dIn, rowBuffer.getByteArray, 0, rowSize) row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize) rowSize = readSize() if (rowSize == EOF) { // We are returning the last row in this stream @@ -148,10 +151,10 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def readValue[T: ClassTag](): T = { val rowSize = dIn.readInt() - if (rowBuffer.length < rowSize) { - rowBuffer = new Array[Byte](rowSize) + if (rowBuffer.size() < rowSize) { + rowBuffer = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](rowSize)) } - ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) + ByteStreams.readFully(dIn, rowBuffer.getByteArray, 0, rowSize) row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize) row.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index 3ec01185c4328..b0d1c9751a30d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock import org.apache.spark.unsafe.types.UTF8String @@ -577,7 +578,7 @@ private[columnar] case class STRUCT(dataType: StructType) buffer.position(cursor + sizeInBytes) val unsafeRow = new UnsafeRow(numOfFields) unsafeRow.pointTo( - buffer.array(), + ByteArrayMemoryBlock.fromByteArray(buffer.array()), Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, sizeInBytes) unsafeRow @@ -616,7 +617,7 @@ private[columnar] case class ARRAY(dataType: ArrayType) buffer.position(cursor + numBytes) val array = new UnsafeArrayData array.pointTo( - buffer.array(), + ByteArrayMemoryBlock.fromByteArray(buffer.array()), Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, numBytes) array @@ -654,7 +655,7 @@ private[columnar] case class MAP(dataType: MapType) buffer.position(cursor + numBytes) val map = new UnsafeMapData map.pointTo( - buffer.array(), + ByteArrayMemoryBlock.fromByteArray(buffer.array()), Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + 
cursor, numBytes) map diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 738b9a35d1c9d..08a595ff7a8b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -119,7 +119,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera import java.nio.ByteOrder; import scala.collection.Iterator; import org.apache.spark.sql.types.DataType; - import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; + import org.apache.spark.sql.catalyst.expressions.codegen.MemoryBlockHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.sql.execution.columnar.MutableUnsafeRow; @@ -132,7 +132,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; private UnsafeRow unsafeRow = new UnsafeRow($numFields); - private BufferHolder bufferHolder = new BufferHolder(unsafeRow); + private MemoryBlockHolder bufferHolder = new MemoryBlockHolder(unsafeRow); private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields); private MutableUnsafeRow mutableRow = null; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 8f3f6335e4282..7723431224749 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -30,7 +30,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.{MemoryBlockHolder, UnsafeRowWriter} import org.apache.spark.sql.execution.datasources.{CompressionCodecs, PartitionSpec} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} @@ -99,7 +99,7 @@ private[sql] class TextRelation( conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) .mapPartitions { iter => val unsafeRow = new UnsafeRow(1) - val bufferHolder = new BufferHolder(unsafeRow) + val bufferHolder = new MemoryBlockHolder(unsafeRow) val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) iter.map { case (_, line) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 9a3cdaf697e2d..249b82cf46211 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.execution.local.LocalNode import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap +import 
org.apache.spark.unsafe.memory.{ByteArrayMemoryBlock, MemoryBlock, MemoryLocation} import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils} import org.apache.spark.util.collection.CompactBuffer @@ -308,7 +309,7 @@ private[joins] final class UnsafeHashedRelation( out.writeInt(binaryMap.numElements()) var buffer = new Array[Byte](64) - def write(base: Object, offset: Long, length: Int): Unit = { + def write(base: MemoryBlock, offset: Long, length: Int): Unit = { if (buffer.length < length) { buffer = new Array[Byte](length) } @@ -392,19 +393,19 @@ private[joins] final class UnsafeHashedRelation( pageSizeBytes) var i = 0 - var keyBuffer = new Array[Byte](1024) - var valuesBuffer = new Array[Byte](1024) + var keyBuffer = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](1024)) + var valuesBuffer = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](1024)) while (i < nKeys) { val keySize = in.readInt() val valuesSize = in.readInt() - if (keySize > keyBuffer.length) { - keyBuffer = new Array[Byte](keySize) + if (keySize > keyBuffer.size()) { + keyBuffer = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](keySize)) } - in.readFully(keyBuffer, 0, keySize) - if (valuesSize > valuesBuffer.length) { - valuesBuffer = new Array[Byte](valuesSize) + in.readFully(keyBuffer.getByteArray, 0, keySize) + if (valuesSize > valuesBuffer.size()) { + valuesBuffer = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](valuesSize)) } - in.readFully(valuesBuffer, 0, valuesSize) + in.readFully(valuesBuffer.getByteArray, 0, valuesSize) // put it into binary map val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize) @@ -547,7 +548,8 @@ private[joins] final class LongArrayRelation( val idx = (key - start).toInt if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) { val result = new UnsafeRow(numFields) - result.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx)) + result.pointTo( ByteArrayMemoryBlock.fromByteArray(bytes), + Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx) ) result } else { null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index a32763db054f3..7e2c040f4170e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -26,35 +26,35 @@ import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform -import org.apache.spark.unsafe.memory.MemoryAllocator +import org.apache.spark.unsafe.memory.{OffHeapMemoryBlock, ByteArrayMemoryBlock, MemoryAllocator} import org.apache.spark.unsafe.types.UTF8String class UnsafeRowSuite extends SparkFunSuite { test("UnsafeRow Java serialization") { // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data - val data = new Array[Byte](1024) + val data = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](1024)) val row = new UnsafeRow(1) - row.pointTo(data, 16) + row.pointTo(data, Platform.BYTE_ARRAY_OFFSET, 16) row.setLong(0, 19285) val ser = new JavaSerializer(new SparkConf).newInstance() val row1 = ser.deserialize[UnsafeRow](ser.serialize(row)) assert(row1.getLong(0) == 19285) - assert(row1.getBaseObject().asInstanceOf[Array[Byte]].length == 16) + assert(row1.getBaseObject().asInstanceOf[ByteArrayMemoryBlock].getByteArray.length == 16) } 
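(Aside, not part of the patch: both UnsafeRowSerializer and UnsafeHashedRelation above keep a growable ByteArrayMemoryBlock read buffer, regrowing it whenever a record is larger than size() and handing getByteArray() to the java.io read calls. A minimal sketch of that pattern; readRecord is an illustrative helper, not an API from this patch:)

    import java.io.DataInputStream;
    import java.io.IOException;
    import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;

    class GrowableReadBuffer {
      private ByteArrayMemoryBlock buf = ByteArrayMemoryBlock.fromByteArray(new byte[1024]);

      // Reads one length-prefixed record, regrowing the block when needed, and
      // returns the block whose first `len` bytes hold the record.
      ByteArrayMemoryBlock readRecord(DataInputStream in) throws IOException {
        int len = in.readInt();
        if (buf.size() < len) {
          buf = ByteArrayMemoryBlock.fromByteArray(new byte[len]);
        }
        in.readFully(buf.getByteArray(), 0, len);
        return buf;
      }
    }
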
test("UnsafeRow Kryo serialization") { // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data - val data = new Array[Byte](1024) + val data = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](1024)) val row = new UnsafeRow(1) - row.pointTo(data, 16) + row.pointTo(data, Platform.BYTE_ARRAY_OFFSET, 16) row.setLong(0, 19285) val ser = new KryoSerializer(new SparkConf).newInstance() val row1 = ser.deserialize[UnsafeRow](ser.serialize(row)) assert(row1.getLong(0) == 19285) - assert(row1.getBaseObject().asInstanceOf[Array[Byte]].length == 16) + assert(row1.getBaseObject().asInstanceOf[ByteArrayMemoryBlock].getByteArray.length == 16) } test("bitset width calculation") { @@ -70,7 +70,7 @@ class UnsafeRowSuite extends SparkFunSuite { val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123) val arrayBackedUnsafeRow: UnsafeRow = UnsafeProjection.create(Array[DataType](StringType, StringType, IntegerType)).apply(row) - assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]]) +// assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]]) val (bytesFromArrayBackedRow, field0StringFromArrayBackedRow): (Array[Byte], String) = { val baos = new ByteArrayOutputStream() arrayBackedUnsafeRow.writeToStream(baos, null) @@ -82,17 +82,17 @@ class UnsafeRowSuite extends SparkFunSuite { Platform.copyMemory( arrayBackedUnsafeRow.getBaseObject, arrayBackedUnsafeRow.getBaseOffset, - offheapRowPage.getBaseObject, + offheapRowPage, offheapRowPage.getBaseOffset, arrayBackedUnsafeRow.getSizeInBytes ) val offheapUnsafeRow: UnsafeRow = new UnsafeRow(3) offheapUnsafeRow.pointTo( - offheapRowPage.getBaseObject, + offheapRowPage, offheapRowPage.getBaseOffset, arrayBackedUnsafeRow.getSizeInBytes ) - assert(offheapUnsafeRow.getBaseObject === null) + assert(offheapUnsafeRow.getBaseObject.isInstanceOf[OffHeapMemoryBlock]) val baos = new ByteArrayOutputStream() val writeBuffer = new Array[Byte](1024) offheapUnsafeRow.writeToStream(baos, writeBuffer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 2d3e34d0e1292..76c4e25d4e6f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.types.IntegerType import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.map.BytesToBytesMap +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock import org.apache.spark.util.Benchmark /** @@ -259,7 +260,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.addCase("hash") { iter => var i = 0 - val keyBytes = new Array[Byte](16) + val keyBytes = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](16)) val key = new UnsafeRow(1) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) var s = 0 @@ -274,7 +275,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.addCase("fast hash") { iter => var i = 0 - val keyBytes = new Array[Byte](16) + val keyBytes = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](16)) val key = new UnsafeRow(1) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) var s = 0 @@ -288,8 +289,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.addCase("arrayEqual") { iter => var i = 0 
- val keyBytes = new Array[Byte](16) - val valueBytes = new Array[Byte](16) + val keyBytes = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](16)) + val valueBytes = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](16)) val key = new UnsafeRow(1) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) val value = new UnsafeRow(1) @@ -307,8 +308,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.addCase("Java HashMap (Long)") { iter => var i = 0 - val keyBytes = new Array[Byte](16) - val valueBytes = new Array[Byte](16) + val keyBytes = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](16)) + val valueBytes = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](16)) val value = new UnsafeRow(1) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) value.setInt(0, 555) @@ -330,7 +331,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.addCase("Java HashMap (two ints) ") { iter => var i = 0 - val valueBytes = new Array[Byte](16) + val valueBytes = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](16)) val value = new UnsafeRow(1) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) value.setInt(0, 555) @@ -354,8 +355,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.addCase("Java HashMap (UnsafeRow)") { iter => var i = 0 - val keyBytes = new Array[Byte](16) - val valueBytes = new Array[Byte](16) + val keyBytes = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](16)) + val valueBytes = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](16)) val key = new UnsafeRow(1) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) val value = new UnsafeRow(1) @@ -390,8 +391,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { 1), 0) val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L<<20) - val keyBytes = new Array[Byte](16) - val valueBytes = new Array[Byte](16) + val keyBytes = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](16)) + val valueBytes = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](16)) val key = new UnsafeRow(1) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) val value = new UnsafeRow(1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 97638a66ab473..422c80f65f5c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.vectorized.ColumnVector import org.apache.spark.sql.types.{BinaryType, IntegerType} import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.memory.{ByteArrayMemoryBlock, MemoryBlock, MemoryAllocator} import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.BitSet @@ -117,20 +118,20 @@ object ColumnarBatchBenchmark { // Using unsafe memory val unsafeBuffer = { i: Int => - val data: Long = Platform.allocateMemory(count * 4) + val data = MemoryAllocator.UNSAFE.allocate(count * 4) var sum = 0L for (n <- 0L until iters) { - var ptr = data + var ptr = data.getBaseOffset var i = 0 while (i < count) { - Platform.putInt(null, ptr, i) + Platform.putInt(data, ptr, i) ptr += 4 i += 1 } - ptr = data + ptr = data.getBaseOffset i = 0 while (i < count) { - sum += Platform.getInt(null, ptr) + sum += Platform.getInt(data, ptr) ptr 
+= 4 i += 1 } @@ -183,14 +184,14 @@ object ColumnarBatchBenchmark { var addr = col.valuesNativeAddress() var i = 0 while (i < count) { - Platform.putInt(null, addr, i) + Platform.putInt(null.asInstanceOf[Array[Byte]], addr, i) addr += 4 i += 1 } i = 0 addr = col.valuesNativeAddress() while (i < count) { - sum += Platform.getInt(null, addr) + sum += Platform.getInt(null.asInstanceOf[Array[Byte]], addr) addr += 4 i += 1 } @@ -200,7 +201,7 @@ object ColumnarBatchBenchmark { // Access by going through a batch of unsafe rows. val unsafeRowOnheap = { i: Int => - val buffer = new Array[Byte](count * 16) + val buffer = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](count * 16)) var sum = 0L for (n <- 0L until iters) { val row = new UnsafeRow(1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index b3c3e66fbcbd5..998abe66516ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.memory.{ByteArrayMemoryBlock, MemoryBlock} import org.apache.spark.unsafe.types.CalendarInterval class ColumnarBatchSuite extends SparkFunSuite { @@ -69,7 +70,8 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == column.getIsNull(v._2)) if (memMode == MemoryMode.OFF_HEAP) { val addr = column.nullsNativeAddress() - assert(v._1 == (Platform.getByte(null, addr + v._2) == 1), "index=" + v._2) + assert(v._1 == (Platform.getByte(null.asInstanceOf[Array[Byte]], addr + v._2) == 1), + "index=" + v._2) } } column.close @@ -109,7 +111,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == column.getByte(v._2), "MemoryMode" + memMode) if (memMode == MemoryMode.OFF_HEAP) { val addr = column.valuesNativeAddress() - assert(v._1 == Platform.getByte(null, addr + v._2)) + assert(v._1 == Platform.getByte(null.asInstanceOf[Array[Byte]], addr + v._2)) } } }} @@ -176,7 +178,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == column.getInt(v._2), "Seed = " + seed + " Mem Mode=" + memMode) if (memMode == MemoryMode.OFF_HEAP) { val addr = column.valuesNativeAddress() - assert(v._1 == Platform.getInt(null, addr + 4 * v._2)) + assert(v._1 == Platform.getInt(null.asInstanceOf[Array[Byte]], addr + 4 * v._2)) } } column.close @@ -247,7 +249,7 @@ class ColumnarBatchSuite extends SparkFunSuite { " Seed = " + seed + " MemMode=" + memMode) if (memMode == MemoryMode.OFF_HEAP) { val addr = column.valuesNativeAddress() - assert(v._1 == Platform.getLong(null, addr + 8 * v._2)) + assert(v._1 == Platform.getLong(null.asInstanceOf[Array[Byte]], addr + 8 * v._2)) } } }} @@ -274,17 +276,17 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += 5.0 idx += 3 - val buffer = new Array[Byte](16) + val buffer = ByteArrayMemoryBlock.fromByteArray(new Array[Byte](16)) Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET, 2.234) Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET + 8, 1.123) - column.putDoubles(idx, 1, buffer, 8) - column.putDoubles(idx + 1, 1, buffer, 0) + column.putDoubles(idx, 1, buffer.getByteArray, 8) + column.putDoubles(idx + 1, 1, buffer.getByteArray, 0) reference += 1.123 
      reference += 2.234
      idx += 2
-      column.putDoubles(idx, 2, buffer, 0)
+      column.putDoubles(idx, 2, buffer.getByteArray, 0)
       reference += 2.234
       reference += 1.123
       idx += 2
@@ -313,7 +315,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
         assert(v._1 == column.getDouble(v._2), "Seed = " + seed + " MemMode=" + memMode)
         if (memMode == MemoryMode.OFF_HEAP) {
           val addr = column.valuesNativeAddress()
-          assert(v._1 == Platform.getDouble(null, addr + 8 * v._2))
+          assert(v._1 == Platform.getDouble(null.asInstanceOf[Array[Byte]], addr + 8 * v._2))
         }
       }
       column.close
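(Aside, not part of the patch: the suite above packs two doubles into a 16-byte block and then passes the backing array to putDoubles, so offsets 0 and 8 carry the two values. The test writes through the block directly, which assumes the patched Platform accepts a MemoryBlock; the sketch below shows the same layout using the stock Object-based accessors on the backing array:)

    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;

    class DoubleLayoutSketch {
      public static void main(String[] args) {
        // Two doubles in a 16-byte block: one at offset 0, one at offset 8.
        ByteArrayMemoryBlock buffer = ByteArrayMemoryBlock.fromByteArray(new byte[16]);
        byte[] raw = buffer.getByteArray();
        Platform.putDouble(raw, Platform.BYTE_ARRAY_OFFSET, 2.234);
        Platform.putDouble(raw, Platform.BYTE_ARRAY_OFFSET + 8, 1.123);
        System.out.println(Platform.getDouble(raw, Platform.BYTE_ARRAY_OFFSET));      // 2.234
        System.out.println(Platform.getDouble(raw, Platform.BYTE_ARRAY_OFFSET + 8));  // 1.123
      }
    }
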